From e87670ee48c896ba029a11ad590234e00260f875 Mon Sep 17 00:00:00 2001
From: ameerj <aj662@drexel.edu>
Date: Wed, 25 Nov 2020 00:33:20 -0500
Subject: [PATCH] Refactor MaxwellToSpirvComparison. Use Common::BitCast

Co-Authored-By: Rodrigo Locatti <reinuseslisp@airmail.cc>
---
 .../renderer_vulkan/fixed_pipeline_state.cpp  |  5 +-
 .../renderer_vulkan/vk_pipeline_cache.cpp     |  4 +-
 .../renderer_vulkan/vk_shader_decompiler.cpp  | 56 ++++++++++---------
 3 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/src/video_core/renderer_vulkan/fixed_pipeline_state.cpp b/src/video_core/renderer_vulkan/fixed_pipeline_state.cpp
index 1928283001..fffae528ea 100644
--- a/src/video_core/renderer_vulkan/fixed_pipeline_state.cpp
+++ b/src/video_core/renderer_vulkan/fixed_pipeline_state.cpp
@@ -8,6 +8,7 @@
 
 #include <boost/functional/hash.hpp>
 
+#include "common/bit_cast.h"
 #include "common/cityhash.h"
 #include "common/common_types.h"
 #include "video_core/renderer_vulkan/fixed_pipeline_state.h"
@@ -64,9 +65,9 @@ void FixedPipelineState::Fill(const Maxwell& regs, bool has_extended_dynamic_sta
     const auto test_func =
         regs.alpha_test_enabled == 1 ? regs.alpha_test_func : Maxwell::ComparisonOp::Always;
     alpha_test_func.Assign(PackComparisonOp(test_func));
-    std::memcpy(&alpha_test_ref, &regs.alpha_test_ref, sizeof(u32)); // TODO: C++20 std::bit_cast
+    alpha_test_ref = Common::BitCast<u32>(regs.alpha_test_ref);
 
-    std::memcpy(&point_size, &regs.point_size, sizeof(point_size)); // TODO: C++20 std::bit_cast
+    point_size = Common::BitCast<u32>(regs.point_size);
 
     for (std::size_t index = 0; index < Maxwell::NumVertexArrays; ++index) {
         binding_divisors[index] =
diff --git a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
index a66a841fbc..f9efe526df 100644
--- a/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
+++ b/src/video_core/renderer_vulkan/vk_pipeline_cache.cpp
@@ -7,6 +7,7 @@
 #include <memory>
 #include <vector>
 
+#include "common/bit_cast.h"
 #include "common/microprofile.h"
 #include "core/core.h"
 #include "core/memory.h"
@@ -347,8 +348,7 @@ VKPipelineCache::DecompileShaders(const FixedPipelineState& fixed_state) {
     // Alpha test
     specialization.alpha_test_func =
         FixedPipelineState::UnpackComparisonOp(fixed_state.alpha_test_func.Value());
-    // memcpy from u32 to float TODO: C++20 std::bit_cast
-    std::memcpy(&specialization.alpha_test_ref, &fixed_state.alpha_test_ref, sizeof(float));
+    specialization.alpha_test_ref = Common::BitCast<float>(fixed_state.alpha_test_ref);
 
     SPIRVProgram program;
     std::vector<VkDescriptorSetLayoutBinding> bindings;
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 81550bc96a..d6685cd122 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -2075,47 +2075,49 @@ private:
         return {};
     }
 
-    void AlphaTest(Id pointer) {
-        const Id true_label = OpLabel();
-        const Id skip_label = OpLabel();
-        const Id alpha_reference = Constant(t_float, specialization.alpha_test_ref);
-        const Id alpha_value = OpLoad(t_float, pointer);
-        Id condition;
+    Id MaxwellToSpirvComparison(Maxwell::ComparisonOp compare_op, Id operand_1, Id operand_2) {
         using Compare = Maxwell::ComparisonOp;
-        switch (specialization.alpha_test_func) {
+        switch (compare_op) {
         case Compare::NeverOld:
-            condition = v_false; // Never true
-            break;
+            return v_false; // Never let the test pass
         case Compare::LessOld:
-            condition = OpFOrdLessThan(t_bool, alpha_reference, alpha_value);
-            break;
+            return OpFOrdLessThan(t_bool, operand_1, operand_2);
         case Compare::EqualOld:
-            condition = OpFOrdEqual(t_bool, alpha_reference, alpha_value);
-            break;
+            // Note: not accurate when tested against a unit test
+            // TODO: confirm if used by games
+            return OpFOrdEqual(t_bool, operand_1, operand_2);
         case Compare::LessEqualOld:
-            condition = OpFOrdLessThanEqual(t_bool, alpha_reference, alpha_value);
-            break;
+            return OpFOrdLessThanEqual(t_bool, operand_1, operand_2);
         case Compare::GreaterOld:
-            // Note: requires "Equal" to properly work for ssbu. perhaps a precision issue
-            condition = OpFOrdGreaterThanEqual(t_bool, alpha_reference, alpha_value);
-            break;
+            return OpFOrdGreaterThan(t_bool, operand_1, operand_2);
         case Compare::NotEqualOld:
             // Note: not accurate when tested against a unit test
             // TODO: confirm if used by games
-            condition = OpFOrdNotEqual(t_bool, alpha_reference, alpha_value);
-            break;
+            return OpFOrdNotEqual(t_bool, operand_1, operand_2);
         case Compare::GreaterEqualOld:
-            condition = OpFOrdGreaterThanEqual(t_bool, alpha_reference, alpha_value);
-            break;
-        case Compare::AlwaysOld:
-            return;
+            return OpFOrdGreaterThanEqual(t_bool, operand_1, operand_2);
         default:
             UNREACHABLE();
         }
-        OpBranchConditional(condition, true_label, skip_label);
-        AddLabel(true_label);
+    }
+
+    void AlphaTest(Id pointer) {
+        if (specialization.alpha_test_func == Maxwell::ComparisonOp::AlwaysOld) {
+            return;
+        }
+
+        const Id true_label = OpLabel();
+        const Id discard_label = OpLabel();
+        const Id alpha_reference = Constant(t_float, specialization.alpha_test_ref);
+        const Id alpha_value = OpLoad(t_float, pointer);
+
+        const Id condition =
+            MaxwellToSpirvComparison(specialization.alpha_test_func, alpha_value, alpha_reference);
+
+        OpBranchConditional(condition, true_label, discard_label);
+        AddLabel(discard_label);
         OpKill();
-        AddLabel(skip_label);
+        AddLabel(true_label);
     }
 
     void PreExit() {