From bee81887998070d213c984d4ae4e5ae35de6fd96 Mon Sep 17 00:00:00 2001
From: FernandoS27 <fsahmkow27@gmail.com>
Date: Thu, 1 Apr 2021 07:42:58 +0200
Subject: [PATCH] shader: Fold composite extract

---
 .../ir_opt/constant_propagation_pass.cpp      | 62 +++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 15e16956e9..8999c3a3d9 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -368,6 +368,50 @@ void FoldBranchConditional(IR::Inst& inst) {
     }
 }
 
+std::optional<IR::Value> FoldCompositeExtractImpl(IR::Value inst_value, IR::Opcode insert,
+                                                  IR::Opcode construct, u32 first_index) {
+    IR::Inst* const inst{inst_value.InstRecursive()};
+    if (inst->Opcode() == construct) {
+        return inst->Arg(first_index);
+    }
+
+    if (inst->Opcode() != insert) {
+        return std::nullopt;
+    }
+
+    IR::Value value_index{inst->Arg(2)};
+    if (!value_index.IsImmediate()) {
+        return std::nullopt;
+    }
+
+    const u32 second_index = value_index.U32();
+    if (first_index != second_index) {
+        IR::Value value_composite{inst->Arg(0)};
+        if (value_composite.IsImmediate()) {
+            return std::nullopt;
+        }
+        return FoldCompositeExtractImpl(value_composite, insert, construct, first_index);
+    }
+    return inst->Arg(1);
+}
+
+void FoldCompositeExtract(IR::Inst& inst, IR::Opcode construct, IR::Opcode insert) {
+    const IR::Value value_1{inst.Arg(0)};
+    const IR::Value value_2{inst.Arg(1)};
+    if (value_1.IsImmediate()) {
+        return;
+    }
+    if (!value_2.IsImmediate()) {
+        return;
+    }
+    const u32 first_index = value_2.U32();
+    auto result = FoldCompositeExtractImpl(value_1, insert, construct, first_index);
+    if (!result) {
+        return;
+    }
+    inst.ReplaceUsesWith(*result);
+}
+
 void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
     switch (inst.Opcode()) {
     case IR::Opcode::GetRegister:
@@ -458,6 +502,24 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
         return;
     case IR::Opcode::BranchConditional:
         return FoldBranchConditional(inst);
+    case IR::Opcode::CompositeExtractF32x2:
+        return FoldCompositeExtract(inst, IR::Opcode::CompositeConstructF32x2,
+                                    IR::Opcode::CompositeInsertF32x2);
+    case IR::Opcode::CompositeExtractF32x3:
+        return FoldCompositeExtract(inst, IR::Opcode::CompositeConstructF32x3,
+                                    IR::Opcode::CompositeInsertF32x3);
+    case IR::Opcode::CompositeExtractF32x4:
+        return FoldCompositeExtract(inst, IR::Opcode::CompositeConstructF32x4,
+                                    IR::Opcode::CompositeInsertF32x4);
+    case IR::Opcode::CompositeExtractF16x2:
+        return FoldCompositeExtract(inst, IR::Opcode::CompositeConstructF16x2,
+                                    IR::Opcode::CompositeInsertF16x2);
+    case IR::Opcode::CompositeExtractF16x3:
+        return FoldCompositeExtract(inst, IR::Opcode::CompositeConstructF16x3,
+                                    IR::Opcode::CompositeInsertF16x3);
+    case IR::Opcode::CompositeExtractF16x4:
+        return FoldCompositeExtract(inst, IR::Opcode::CompositeConstructF16x4,
+                                    IR::Opcode::CompositeInsertF16x4);
     default:
         break;
     }