From 19a8f03ad59abacb3fae60f7ba6d5d358e779c29 Mon Sep 17 00:00:00 2001
From: bunnei <bunneidev@gmail.com>
Date: Thu, 10 Dec 2020 15:32:52 -0800
Subject: [PATCH] hle: service: nvflinger: Refactor locking and interfaces.

---
 src/core/hle/service/nvflinger/nvflinger.cpp | 23 +++++-----
 src/core/hle/service/nvflinger/nvflinger.h   |  7 +--
 src/core/hle/service/vi/vi.cpp               | 46 +++++++-------------
 3 files changed, 31 insertions(+), 45 deletions(-)

diff --git a/src/core/hle/service/nvflinger/nvflinger.cpp b/src/core/hle/service/nvflinger/nvflinger.cpp
index 44aa2bdaea..a7a679df1d 100644
--- a/src/core/hle/service/nvflinger/nvflinger.cpp
+++ b/src/core/hle/service/nvflinger/nvflinger.cpp
@@ -104,6 +104,8 @@ void NVFlinger::SetNVDrvInstance(std::shared_ptr<Nvidia::Module> instance) {
 }
 
 std::optional<u64> NVFlinger::OpenDisplay(std::string_view name) {
+    const auto guard = Lock();
+
     LOG_DEBUG(Service, "Opening \"{}\" display", name);
 
     // TODO(Subv): Currently we only support the Default display.
@@ -121,6 +123,7 @@ std::optional<u64> NVFlinger::OpenDisplay(std::string_view name) {
 }
 
 std::optional<u64> NVFlinger::CreateLayer(u64 display_id) {
+    const auto guard = Lock();
     auto* const display = FindDisplay(display_id);
 
     if (display == nullptr) {
@@ -135,12 +138,15 @@ std::optional<u64> NVFlinger::CreateLayer(u64 display_id) {
 }
 
 void NVFlinger::CloseLayer(u64 layer_id) {
+    const auto guard = Lock();
+
     for (auto& display : displays) {
         display.CloseLayer(layer_id);
     }
 }
 
 std::optional<u32> NVFlinger::FindBufferQueueId(u64 display_id, u64 layer_id) const {
+    const auto guard = Lock();
     const auto* const layer = FindLayer(display_id, layer_id);
 
     if (layer == nullptr) {
@@ -151,6 +157,7 @@ std::optional<u32> NVFlinger::FindBufferQueueId(u64 display_id, u64 layer_id) co
 }
 
 std::shared_ptr<Kernel::ReadableEvent> NVFlinger::FindVsyncEvent(u64 display_id) const {
+    const auto guard = Lock();
     auto* const display = FindDisplay(display_id);
 
     if (display == nullptr) {
@@ -160,20 +167,16 @@ std::shared_ptr<Kernel::ReadableEvent> NVFlinger::FindVsyncEvent(u64 display_id)
     return display->GetVSyncEvent();
 }
 
-BufferQueue& NVFlinger::FindBufferQueue(u32 id) {
+BufferQueue* NVFlinger::FindBufferQueue(u32 id) {
+    const auto guard = Lock();
     const auto itr = std::find_if(buffer_queues.begin(), buffer_queues.end(),
                                   [id](const auto& queue) { return queue.GetId() == id; });
 
-    ASSERT(itr != buffer_queues.end());
-    return *itr;
-}
+    if (itr == buffer_queues.end()) {
+        return nullptr;
+    }
 
-const BufferQueue& NVFlinger::FindBufferQueue(u32 id) const {
-    const auto itr = std::find_if(buffer_queues.begin(), buffer_queues.end(),
-                                  [id](const auto& queue) { return queue.GetId() == id; });
-
-    ASSERT(itr != buffer_queues.end());
-    return *itr;
+    return &*itr;
 }
 
 VI::Display* NVFlinger::FindDisplay(u64 display_id) {
diff --git a/src/core/hle/service/nvflinger/nvflinger.h b/src/core/hle/service/nvflinger/nvflinger.h
index 1ebe949c08..ce1347d6da 100644
--- a/src/core/hle/service/nvflinger/nvflinger.h
+++ b/src/core/hle/service/nvflinger/nvflinger.h
@@ -75,10 +75,7 @@ public:
     [[nodiscard]] std::shared_ptr<Kernel::ReadableEvent> FindVsyncEvent(u64 display_id) const;
 
     /// Obtains a buffer queue identified by the ID.
-    [[nodiscard]] BufferQueue& FindBufferQueue(u32 id);
-
-    /// Obtains a buffer queue identified by the ID.
-    [[nodiscard]] const BufferQueue& FindBufferQueue(u32 id) const;
+    [[nodiscard]] BufferQueue* FindBufferQueue(u32 id);
 
     /// Performs a composition request to the emulated nvidia GPU and triggers the vsync events when
     /// finished.
@@ -86,11 +83,11 @@ public:
 
     [[nodiscard]] s64 GetNextTicks() const;
 
+private:
     [[nodiscard]] std::unique_lock<std::mutex> Lock() const {
         return std::unique_lock{*guard};
     }
 
-private:
     /// Finds the display identified by the specified ID.
     [[nodiscard]] VI::Display* FindDisplay(u64 display_id);
 
diff --git a/src/core/hle/service/vi/vi.cpp b/src/core/hle/service/vi/vi.cpp
index 93a113b930..ce0272e590 100644
--- a/src/core/hle/service/vi/vi.cpp
+++ b/src/core/hle/service/vi/vi.cpp
@@ -551,9 +551,9 @@ private:
             IGBPSetPreallocatedBufferRequestParcel request{ctx.ReadBuffer()};
 
             {
-                const auto guard = nv_flinger.Lock();
-                auto& buffer_queue = nv_flinger.FindBufferQueue(id);
-                buffer_queue.SetPreallocatedBuffer(request.data.slot, request.buffer_container.buffer);
+                auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
+                buffer_queue.SetPreallocatedBuffer(request.data.slot,
+                                                   request.buffer_container.buffer);
             }
 
             IGBPSetPreallocatedBufferResponseParcel response{};
@@ -568,11 +568,8 @@ private:
             std::optional<std::pair<u32, Service::Nvidia::MultiFence*>> result;
 
             while (!result) {
-                {
-                    const auto guard = nv_flinger.Lock();
-                    auto& buffer_queue = nv_flinger.FindBufferQueue(id);
-                    result = buffer_queue.DequeueBuffer(width, height);
-                }
+                auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
+                result = buffer_queue.DequeueBuffer(width, height);
 
                 if (result) {
                     // Buffer is available
@@ -586,8 +583,7 @@ private:
         case TransactionId::RequestBuffer: {
             IGBPRequestBufferRequestParcel request{ctx.ReadBuffer()};
 
-            const auto guard = nv_flinger.Lock();
-            auto& buffer_queue = nv_flinger.FindBufferQueue(id);
+            auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
             auto& buffer = buffer_queue.RequestBuffer(request.slot);
             IGBPRequestBufferResponseParcel response{buffer};
             ctx.WriteBuffer(response.Serialize());
@@ -597,13 +593,10 @@ private:
         case TransactionId::QueueBuffer: {
             IGBPQueueBufferRequestParcel request{ctx.ReadBuffer()};
 
-            {
-                const auto guard = nv_flinger.Lock();
-                auto& buffer_queue = nv_flinger.FindBufferQueue(id);
-                buffer_queue.QueueBuffer(request.data.slot, request.data.transform,
-                                         request.data.GetCropRect(), request.data.swap_interval,
-                                         request.data.multi_fence);
-            }
+            auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
+            buffer_queue.QueueBuffer(request.data.slot, request.data.transform,
+                                     request.data.GetCropRect(), request.data.swap_interval,
+                                     request.data.multi_fence);
 
             IGBPQueueBufferResponseParcel response{1280, 720};
             ctx.WriteBuffer(response.Serialize());
@@ -612,8 +605,7 @@ private:
         case TransactionId::Query: {
             IGBPQueryRequestParcel request{ctx.ReadBuffer()};
 
-            const auto guard = nv_flinger.Lock();
-            auto& buffer_queue = nv_flinger.FindBufferQueue(id);
+            auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
             const u32 value =
                 buffer_queue.Query(static_cast<NVFlinger::BufferQueue::QueryType>(request.type));
 
@@ -624,11 +616,8 @@ private:
         case TransactionId::CancelBuffer: {
             IGBPCancelBufferRequestParcel request{ctx.ReadBuffer()};
 
-            {
-                const auto guard = nv_flinger.Lock();
-                auto& buffer_queue = nv_flinger.FindBufferQueue(id);
-                buffer_queue.CancelBuffer(request.data.slot, request.data.multi_fence);
-            }
+            auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
+            buffer_queue.CancelBuffer(request.data.slot, request.data.multi_fence);
 
             IGBPCancelBufferResponseParcel response{};
             ctx.WriteBuffer(response.Serialize());
@@ -638,11 +627,8 @@ private:
             LOG_WARNING(Service_VI, "(STUBBED) called, transaction=Disconnect");
             const auto buffer = ctx.ReadBuffer();
 
-            {
-                const auto guard = nv_flinger.Lock();
-                auto& buffer_queue = nv_flinger.FindBufferQueue(id);
-                buffer_queue.Disconnect();
-            }
+            auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
+            buffer_queue.Disconnect();
 
             IGBPEmptyResponseParcel response{};
             ctx.WriteBuffer(response.Serialize());
@@ -691,7 +677,7 @@ private:
 
         LOG_WARNING(Service_VI, "(STUBBED) called id={}, unknown={:08X}", id, unknown);
 
-        const auto& buffer_queue = nv_flinger.FindBufferQueue(id);
+        const auto& buffer_queue = *nv_flinger.FindBufferQueue(id);
 
         // TODO(Subv): Find out what this actually is.
         IPC::ResponseBuilder rb{ctx, 2, 1};