From f17415d431777bb234714a3c6a97072872b2cc71 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Tue, 30 Jul 2019 00:21:46 -0300
Subject: [PATCH] shader_ir: Implement ST_S

This instruction writes to a memory buffer shared with threads within
the same work group. It is known as "shared" memory in GLSL.
---
 src/video_core/shader/decode/memory.cpp | 25 ++++++++++++++++---------
 src/video_core/shader/node.h            | 16 +++++++++++++++-
 src/video_core/shader/shader_ir.cpp     |  9 +++++++++
 src/video_core/shader/shader_ir.h       |  6 +++++-
 4 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/src/video_core/shader/decode/memory.cpp b/src/video_core/shader/decode/memory.cpp
index ed108bea8b..8f74fa7d89 100644
--- a/src/video_core/shader/decode/memory.cpp
+++ b/src/video_core/shader/decode/memory.cpp
@@ -35,7 +35,7 @@ u32 GetUniformTypeElementsCount(Tegra::Shader::UniformType uniform_type) {
         return 1;
     }
 }
-} // namespace
+} // Anonymous namespace
 
 u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
     const Instruction instr = {program_code[pc]};
@@ -209,27 +209,34 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
 
         break;
     }
-    case OpCode::Id::ST_L: {
+    case OpCode::Id::ST_L:
         LOG_DEBUG(HW_GPU, "ST_L cache management mode: {}",
                   static_cast<u64>(instr.st_l.cache_management.Value()));
-
-        const auto GetLmemAddr = [&](s32 offset) {
+        [[fallthrough]];
+    case OpCode::Id::ST_S: {
+        const auto GetAddress = [&](s32 offset) {
             ASSERT(offset % 4 == 0);
             const Node immediate = Immediate(static_cast<s32>(instr.smem_imm) + offset);
             return Operation(OperationCode::IAdd, NO_PRECISE, GetRegister(instr.gpr8), immediate);
         };
 
+        const auto set_memory = opcode->get().GetId() == OpCode::Id::ST_L
+                                    ? &ShaderIR::SetLocalMemory
+                                    : &ShaderIR::SetSharedMemory;
+
         switch (instr.ldst_sl.type.Value()) {
         case Tegra::Shader::StoreType::Bits128:
-            SetLocalMemory(bb, GetLmemAddr(12), GetRegister(instr.gpr0.Value() + 3));
-            SetLocalMemory(bb, GetLmemAddr(8), GetRegister(instr.gpr0.Value() + 2));
+            (this->*set_memory)(bb, GetAddress(12), GetRegister(instr.gpr0.Value() + 3));
+            (this->*set_memory)(bb, GetAddress(8), GetRegister(instr.gpr0.Value() + 2));
+            [[fallthrough]];
         case Tegra::Shader::StoreType::Bits64:
-            SetLocalMemory(bb, GetLmemAddr(4), GetRegister(instr.gpr0.Value() + 1));
+            (this->*set_memory)(bb, GetAddress(4), GetRegister(instr.gpr0.Value() + 1));
+            [[fallthrough]];
         case Tegra::Shader::StoreType::Bits32:
-            SetLocalMemory(bb, GetLmemAddr(0), GetRegister(instr.gpr0));
+            (this->*set_memory)(bb, GetAddress(0), GetRegister(instr.gpr0));
             break;
         default:
-            UNIMPLEMENTED_MSG("ST_L Unhandled type: {}",
+            UNIMPLEMENTED_MSG("{} unhandled type: {}", opcode->get().GetName(),
                               static_cast<u32>(instr.ldst_sl.type.Value()));
         }
         break;
diff --git a/src/video_core/shader/node.h b/src/video_core/shader/node.h
index 5db9313c48..e0d1979fa0 100644
--- a/src/video_core/shader/node.h
+++ b/src/video_core/shader/node.h
@@ -198,12 +198,13 @@ class PredicateNode;
 class AbufNode;
 class CbufNode;
 class LmemNode;
+class SmemNode;
 class GmemNode;
 class CommentNode;
 
 using NodeData =
     std::variant<OperationNode, ConditionalNode, GprNode, ImmediateNode, InternalFlagNode,
-                 PredicateNode, AbufNode, CbufNode, LmemNode, GmemNode, CommentNode>;
+                 PredicateNode, AbufNode, CbufNode, LmemNode, SmemNode, GmemNode, CommentNode>;
 using Node = std::shared_ptr<NodeData>;
 using Node4 = std::array<Node, 4>;
 using NodeBlock = std::vector<Node>;
@@ -536,6 +537,19 @@ private:
     Node address;
 };
 
+/// Shared memory node
+class SmemNode final {
+public:
+    explicit SmemNode(Node address) : address{std::move(address)} {}
+
+    const Node& GetAddress() const {
+        return address;
+    }
+
+private:
+    Node address;
+};
+
 /// Global memory node
 class GmemNode final {
 public:
diff --git a/src/video_core/shader/shader_ir.cpp b/src/video_core/shader/shader_ir.cpp
index 1e5c7f6605..bbbab0bcab 100644
--- a/src/video_core/shader/shader_ir.cpp
+++ b/src/video_core/shader/shader_ir.cpp
@@ -137,6 +137,10 @@ Node ShaderIR::GetLocalMemory(Node address) {
     return MakeNode<LmemNode>(std::move(address));
 }
 
+Node ShaderIR::GetSharedMemory(Node address) {
+    return MakeNode<SmemNode>(std::move(address));
+}
+
 Node ShaderIR::GetTemporary(u32 id) {
     return GetRegister(Register::ZeroIndex + 1 + id);
 }
@@ -378,6 +382,11 @@ void ShaderIR::SetLocalMemory(NodeBlock& bb, Node address, Node value) {
         Operation(OperationCode::Assign, GetLocalMemory(std::move(address)), std::move(value)));
 }
 
+void ShaderIR::SetSharedMemory(NodeBlock& bb, Node address, Node value) {
+    bb.push_back(
+        Operation(OperationCode::Assign, GetSharedMemory(std::move(address)), std::move(value)));
+}
+
 void ShaderIR::SetTemporary(NodeBlock& bb, u32 id, Node value) {
     SetRegister(bb, Register::ZeroIndex + 1 + id, std::move(value));
 }
diff --git a/src/video_core/shader/shader_ir.h b/src/video_core/shader/shader_ir.h
index bcc9b79b67..ab57388ed3 100644
--- a/src/video_core/shader/shader_ir.h
+++ b/src/video_core/shader/shader_ir.h
@@ -208,6 +208,8 @@ private:
     Node GetInternalFlag(InternalFlag flag, bool negated = false);
     /// Generates a node representing a local memory address
     Node GetLocalMemory(Node address);
+    /// Generates a node representing a shared memory address
+    Node GetSharedMemory(Node address);
     /// Generates a temporary, internally it uses a post-RZ register
     Node GetTemporary(u32 id);
 
@@ -217,8 +219,10 @@ private:
     void SetPredicate(NodeBlock& bb, u64 dest, Node src);
     /// Sets an internal flag. src value must be a bool-evaluated node
     void SetInternalFlag(NodeBlock& bb, InternalFlag flag, Node value);
-    /// Sets a local memory address. address and value must be a number-evaluated node
+    /// Sets a local memory address with a value.
     void SetLocalMemory(NodeBlock& bb, Node address, Node value);
+    /// Sets a shared memory address with a value.
+    void SetSharedMemory(NodeBlock& bb, Node address, Node value);
     /// Sets a temporary. Internally it uses a post-RZ register
     void SetTemporary(NodeBlock& bb, u32 id, Node value);