From edb5844240c339846d505735d2c2e1ad731f8be7 Mon Sep 17 00:00:00 2001
From: ameerj <52414509+ameerj@users.noreply.github.com>
Date: Fri, 17 Sep 2021 21:31:29 -0400
Subject: [PATCH] rescaling_pass: Fix and simplify shuffle/fragcoord pass

---
 .../ir_opt/rescaling_pass.cpp                 | 46 ++++++++-----------
 1 file changed, 20 insertions(+), 26 deletions(-)

diff --git a/src/shader_recompiler/ir_opt/rescaling_pass.cpp b/src/shader_recompiler/ir_opt/rescaling_pass.cpp
index 4d23b60c8a..8bbaa55e4c 100644
--- a/src/shader_recompiler/ir_opt/rescaling_pass.cpp
+++ b/src/shader_recompiler/ir_opt/rescaling_pass.cpp
@@ -14,45 +14,39 @@
 
 namespace Shader::Optimization {
 namespace {
-void VisitMark(const IR::Program& program, IR::Inst& inst) {
+void VisitMark(const IR::Program& program, const IR::Inst& inst) {
     const bool is_fragment_shader{program.stage == Stage::Fragment};
+    if (!is_fragment_shader) {
+        return;
+    }
     switch (inst.GetOpcode()) {
     case IR::Opcode::ShuffleIndex:
     case IR::Opcode::ShuffleUp:
     case IR::Opcode::ShuffleDown:
     case IR::Opcode::ShuffleButterfly: {
-        const auto try_mark = [is_fragment_shader](IR::Inst* op) {
-            const IR::Attribute attr{op->Arg(0).Attribute()};
+        const IR::Value shfl_arg{inst.Arg(0)};
+        if (shfl_arg.IsImmediate()) {
+            break;
+        }
+        const IR::Inst* const arg_inst{shfl_arg.InstRecursive()};
+        if (arg_inst->GetOpcode() != IR::Opcode::BitCastU32F32) {
+            break;
+        }
+        const IR::Value bitcast_arg{arg_inst->Arg(0)};
+        if (bitcast_arg.IsImmediate()) {
+            break;
+        }
+        IR::Inst* const bitcast_inst{bitcast_arg.InstRecursive()};
+        if (bitcast_inst->GetOpcode() == IR::Opcode::GetAttribute) {
+            const IR::Attribute attr{bitcast_inst->Arg(0).Attribute()};
             switch (attr) {
             case IR::Attribute::PositionX:
             case IR::Attribute::PositionY:
-                if (is_fragment_shader) {
-                    op->SetFlags<u32>(0xDEADBEEF);
-                }
+                bitcast_inst->SetFlags<u32>(0xDEADBEEF);
                 break;
             default:
                 break;
             }
-        };
-        const IR::Value param_1{inst.Arg(0)};
-        if (param_1.IsImmediate()) {
-            break;
-        }
-        IR::Inst* op_a{param_1.InstRecursive()};
-        if (op_a->GetOpcode() == IR::Opcode::GetAttribute) {
-            try_mark(op_a);
-            break;
-        }
-        if (op_a->GetOpcode() != IR::Opcode::BitCastF32U32) {
-            break;
-        }
-        const IR::Value param_2{op_a->Arg(0)};
-        if (param_2.IsImmediate()) {
-            break;
-        }
-        IR::Inst* op_b{param_2.InstRecursive()};
-        if (op_b->GetOpcode() == IR::Opcode::GetAttribute) {
-            try_mark(op_b);
         }
         break;
     }