From 420982864634a5e52cea42c43f8623f75483fbcc Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Wed, 21 Apr 2021 00:27:55 -0300
Subject: [PATCH] shader: Intrusively store register values in block for SSA
 pass

---
 .../frontend/ir/basic_block.h                 | 10 +++
 .../ir_opt/ssa_rewrite_pass.cpp               | 64 +++++++++++++------
 2 files changed, 53 insertions(+), 21 deletions(-)

diff --git a/src/shader_recompiler/frontend/ir/basic_block.h b/src/shader_recompiler/frontend/ir/basic_block.h
index 6a1d615d92..3a42307557 100644
--- a/src/shader_recompiler/frontend/ir/basic_block.h
+++ b/src/shader_recompiler/frontend/ir/basic_block.h
@@ -101,6 +101,13 @@ public:
         return branch_false;
     }
 
+    void SetSsaRegValue(IR::Reg reg, const Value& value) noexcept {
+        ssa_reg_values[RegIndex(reg)] = value;
+    }
+    const Value& SsaRegValue(IR::Reg reg) const noexcept {
+        return ssa_reg_values[RegIndex(reg)];
+    }
+
     [[nodiscard]] bool empty() const {
         return instructions.empty();
     }
@@ -182,6 +189,9 @@ private:
     /// Block immediate predecessors
     std::vector<Block*> imm_predecessors;
 
+    /// Intrusively store the value of a register in the block.
+    std::array<Value, NUM_REGS> ssa_reg_values;
+
     /// Intrusively stored host definition of this block.
     u32 definition{};
 };
diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
index ddd679e393..bb1a90004f 100644
--- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
+++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
@@ -57,39 +57,62 @@ using Variant = std::variant<IR::Reg, IR::Pred, ZeroFlagTag, SignFlagTag, CarryF
 using ValueMap = boost::container::flat_map<IR::Block*, IR::Value, std::less<IR::Block*>>;
 
 struct DefTable {
-    [[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept {
-        return regs[IR::RegIndex(variable)];
+    const IR::Value& Def(IR::Block* block, IR::Reg variable) noexcept {
+        return block->SsaRegValue(variable);
+    }
+    void SetDef(IR::Block* block, IR::Reg variable, const IR::Value& value) noexcept {
+        block->SetSsaRegValue(variable, value);
     }
 
-    [[nodiscard]] ValueMap& operator[](IR::Pred variable) noexcept {
-        return preds[IR::PredIndex(variable)];
+    const IR::Value& Def(IR::Block* block, IR::Pred variable) noexcept {
+        return preds[IR::PredIndex(variable)][block];
+    }
+    void SetDef(IR::Block* block, IR::Pred variable, const IR::Value& value) noexcept {
+        preds[IR::PredIndex(variable)].insert_or_assign(block, value);
     }
 
-    [[nodiscard]] ValueMap& operator[](GotoVariable goto_variable) {
-        return goto_vars[goto_variable.index];
+    const IR::Value& Def(IR::Block* block, GotoVariable variable) noexcept {
+        return goto_vars[variable.index][block];
+    }
+    void SetDef(IR::Block* block, GotoVariable variable, const IR::Value& value) noexcept {
+        goto_vars[variable.index].insert_or_assign(block, value);
     }
 
-    [[nodiscard]] ValueMap& operator[](IndirectBranchVariable) {
-        return indirect_branch_var;
+    const IR::Value& Def(IR::Block* block, IndirectBranchVariable) noexcept {
+        return indirect_branch_var[block];
+    }
+    void SetDef(IR::Block* block, IndirectBranchVariable, const IR::Value& value) noexcept {
+        indirect_branch_var.insert_or_assign(block, value);
     }
 
-    [[nodiscard]] ValueMap& operator[](ZeroFlagTag) noexcept {
-        return zero_flag;
+    const IR::Value& Def(IR::Block* block, ZeroFlagTag) noexcept {
+        return zero_flag[block];
+    }
+    void SetDef(IR::Block* block, ZeroFlagTag, const IR::Value& value) noexcept {
+        zero_flag.insert_or_assign(block, value);
     }
 
-    [[nodiscard]] ValueMap& operator[](SignFlagTag) noexcept {
-        return sign_flag;
+    const IR::Value& Def(IR::Block* block, SignFlagTag) noexcept {
+        return sign_flag[block];
+    }
+    void SetDef(IR::Block* block, SignFlagTag, const IR::Value& value) noexcept {
+        sign_flag.insert_or_assign(block, value);
     }
 
-    [[nodiscard]] ValueMap& operator[](CarryFlagTag) noexcept {
-        return carry_flag;
+    const IR::Value& Def(IR::Block* block, CarryFlagTag) noexcept {
+        return carry_flag[block];
+    }
+    void SetDef(IR::Block* block, CarryFlagTag, const IR::Value& value) noexcept {
+        carry_flag.insert_or_assign(block, value);
     }
 
-    [[nodiscard]] ValueMap& operator[](OverflowFlagTag) noexcept {
-        return overflow_flag;
+    const IR::Value& Def(IR::Block* block, OverflowFlagTag) noexcept {
+        return overflow_flag[block];
+    }
+    void SetDef(IR::Block* block, OverflowFlagTag, const IR::Value& value) noexcept {
+        overflow_flag.insert_or_assign(block, value);
     }
 
-    std::array<ValueMap, IR::NUM_USER_REGS> regs;
     std::array<ValueMap, IR::NUM_USER_PREDS> preds;
     boost::container::flat_map<u32, ValueMap> goto_vars;
     ValueMap indirect_branch_var;
@@ -143,7 +166,7 @@ class Pass {
 public:
     template <typename Type>
     void WriteVariable(Type variable, IR::Block* block, const IR::Value& value) {
-        current_def[variable].insert_or_assign(block, value);
+        current_def.SetDef(block, variable, value);
     }
 
     template <typename Type>
@@ -170,9 +193,8 @@ public:
             IR::Block* const block{stack.back().block};
             switch (stack.back().pc) {
             case Status::Start: {
-                const ValueMap& def{current_def[variable]};
-                if (const auto it{def.find(block)}; it != def.end()) {
-                    stack.back().result = it->second;
+                if (const IR::Value& def = current_def.Def(block, variable); !def.IsEmpty()) {
+                    stack.back().result = def;
                 } else if (!sealed_blocks.contains(block)) {
                     // Incomplete CFG
                     IR::Inst* phi{&*block->PrependNewInst(block->begin(), IR::Opcode::Phi)};