From 2a0d707ce1d4880dfcbd34c4d6572917a501f675 Mon Sep 17 00:00:00 2001
From: Liam <byteslice@airmail.cc>
Date: Thu, 11 Jan 2024 16:50:59 -0500
Subject: [PATCH] shader_recompiler: emulate 8-bit and 16-bit storage writes
 with cas loop

---
 .../backend/spirv/emit_spirv_memory.cpp       | 40 ++++++++++++---
 .../backend/spirv/spirv_emit_context.cpp      | 51 +++++++++++++++++++
 .../backend/spirv/spirv_emit_context.h        |  3 ++
 3 files changed, 86 insertions(+), 8 deletions(-)

diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
index 8693801c70..bdcbccfde9 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
@@ -65,6 +65,14 @@ void WriteStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value&
     WriteStorage(ctx, binding, offset, value, ctx.storage_types.U32, sizeof(u32),
                  &StorageDefinitions::U32, index_offset);
 }
+
+void WriteStorageByCasLoop(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
+                           Id value, Id bit_offset, Id bit_count) { // Bitfield store via CAS helper
+    const Id pointer{StoragePointer(ctx, binding, offset, ctx.storage_types.U32, sizeof(u32),
+                                    &StorageDefinitions::U32)}; // Requested with U32 element type
+    ctx.OpFunctionCall(ctx.TypeVoid(), ctx.write_storage_cas_loop_func, pointer, value, bit_offset,
+                       bit_count); // Helper merges bit_count bits of value at bit_offset
+}
 } // Anonymous namespace
 
 void EmitLoadGlobalU8(EmitContext&) {
@@ -219,26 +227,42 @@ Id EmitLoadStorage128(EmitContext& ctx, const IR::Value& binding, const IR::Valu
 
 void EmitWriteStorageU8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
                         Id value) {
-    WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.U8, value), ctx.storage_types.U8,
-                 sizeof(u8), &StorageDefinitions::U8);
+    if (ctx.profile.support_int8) { // Direct 8-bit store
+        WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.U8, value), ctx.storage_types.U8,
+                     sizeof(u8), &StorageDefinitions::U8);
+    } else { // Fallback when profile.support_int8 is false: use the 32-bit CAS-loop helper
+        WriteStorageByCasLoop(ctx, binding, offset, value, ctx.BitOffset8(offset), ctx.Const(8u));
+    }
 }
 
 void EmitWriteStorageS8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
                         Id value) {
-    WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.S8, value), ctx.storage_types.S8,
-                 sizeof(s8), &StorageDefinitions::S8);
+    if (ctx.profile.support_int8) { // Direct 8-bit store
+        WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.S8, value), ctx.storage_types.S8,
+                     sizeof(s8), &StorageDefinitions::S8);
+    } else { // Fallback when profile.support_int8 is false: use the 32-bit CAS-loop helper
+        WriteStorageByCasLoop(ctx, binding, offset, value, ctx.BitOffset8(offset), ctx.Const(8u));
+    }
 }
 
 void EmitWriteStorageU16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
                          Id value) {
-    WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.U16, value), ctx.storage_types.U16,
-                 sizeof(u16), &StorageDefinitions::U16);
+    if (ctx.profile.support_int16) { // Direct 16-bit store
+        WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.U16, value), ctx.storage_types.U16,
+                     sizeof(u16), &StorageDefinitions::U16);
+    } else { // Fallback when profile.support_int16 is false: use the 32-bit CAS-loop helper
+        WriteStorageByCasLoop(ctx, binding, offset, value, ctx.BitOffset16(offset), ctx.Const(16u));
+    }
 }
 
 void EmitWriteStorageS16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
                          Id value) {
-    WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.S16, value), ctx.storage_types.S16,
-                 sizeof(s16), &StorageDefinitions::S16);
+    if (ctx.profile.support_int16) { // Direct 16-bit store
+        WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.S16, value), ctx.storage_types.S16,
+                     sizeof(s16), &StorageDefinitions::S16);
+    } else { // Fallback when profile.support_int16 is false: use the 32-bit CAS-loop helper
+        WriteStorageByCasLoop(ctx, binding, offset, value, ctx.BitOffset16(offset), ctx.Const(16u));
+    }
 }
 
 void EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
index 0442adc831..a27f2f73ac 100644
--- a/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.cpp
@@ -480,6 +480,7 @@ EmitContext::EmitContext(const Profile& profile_, const RuntimeInfo& runtime_inf
     DefineTextures(program.info, texture_binding, bindings.texture_scaling_index);
     DefineImages(program.info, image_binding, bindings.image_scaling_index);
     DefineAttributeMemAccess(program.info);
+    DefineWriteStorageCasLoopFunction(program.info);
     DefineGlobalMemoryFunctions(program.info);
     DefineRescalingInput(program.info);
     DefineRenderArea(program.info);
@@ -877,6 +878,56 @@ void EmitContext::DefineAttributeMemAccess(const Info& info) {
     }
 }
 
+void EmitContext::DefineWriteStorageCasLoopFunction(const Info& info) { // SPIR-V CAS-loop helper
+    if (profile.support_int8 && profile.support_int16) { // 8-/16-bit stores are emitted directly
+        return;
+    }
+    if (!info.uses_int8 && !info.uses_int16) { // Info reports no int8/int16 use
+        return;
+    }
+    // NOTE(review): presumably required because the helper takes a StorageBuffer pointer argument.
+    AddCapability(spv::Capability::VariablePointersStorageBuffer);
+    // Signature: void(StorageBuffer U32* pointer, U32 value, U32 bit_offset, U32 bit_count).
+    const Id ptr_type{TypePointer(spv::StorageClass::StorageBuffer, U32[1])};
+    const Id func_type{TypeFunction(void_id, ptr_type, U32[1], U32[1], U32[1])};
+    const Id func{OpFunction(void_id, spv::FunctionControlMask::MaskNone, func_type)};
+    const Id pointer{OpFunctionParameter(ptr_type)};
+    const Id value{OpFunctionParameter(U32[1])};
+    const Id bit_offset{OpFunctionParameter(U32[1])};
+    const Id bit_count{OpFunctionParameter(U32[1])};
+    // Entry block: allocate every label up front, then branch into the loop header.
+    AddLabel();
+    const Id scope_device{Const(1u)}; // 1 == spv::Scope::Device
+    const Id ordering_relaxed{u32_zero_value}; // 0: no memory-semantics bits
+    const Id body_label{OpLabel()};
+    const Id continue_label{OpLabel()};
+    const Id endloop_label{OpLabel()};
+    const Id beginloop_label{OpLabel()};
+    OpBranch(beginloop_label);
+    // Loop header: declares endloop_label as merge block and continue_label as continue target.
+    AddLabel(beginloop_label);
+    OpLoopMerge(endloop_label, continue_label, spv::LoopControlMask::MaskNone);
+    OpBranch(body_label);
+    // Body: load the word, insert the new bits, then try to publish it with compare-exchange.
+    AddLabel(body_label);
+    const Id expected_value{OpLoad(U32[1], pointer)};
+    const Id desired_value{OpBitFieldInsert(U32[1], expected_value, value, bit_offset, bit_count)};
+    const Id actual_value{OpAtomicCompareExchange(U32[1], pointer, scope_device, ordering_relaxed,
+                                                  ordering_relaxed, desired_value, expected_value)};
+    const Id store_successful{OpIEqual(U1, expected_value, actual_value)};
+    OpBranchConditional(store_successful, endloop_label, continue_label);
+    // Merge block: taken once the exchange returned the same value that was loaded.
+    AddLabel(endloop_label);
+    OpReturn();
+    // Continue block: the exchange returned a different value, so retry from the loop header.
+    AddLabel(continue_label);
+    OpBranch(beginloop_label);
+
+    OpFunctionEnd();
+
+    write_storage_cas_loop_func = func;
+}
+
 void EmitContext::DefineGlobalMemoryFunctions(const Info& info) {
     if (!info.uses_global_memory || !profile.support_int64) {
         return;
diff --git a/src/shader_recompiler/backend/spirv/spirv_emit_context.h b/src/shader_recompiler/backend/spirv/spirv_emit_context.h
index 56019ad891..40adcb6b67 100644
--- a/src/shader_recompiler/backend/spirv/spirv_emit_context.h
+++ b/src/shader_recompiler/backend/spirv/spirv_emit_context.h
@@ -325,6 +325,8 @@ public:
     Id f32x2_min_cas{};
     Id f32x2_max_cas{};
 
+    Id write_storage_cas_loop_func{}; // Id of the CAS-loop store helper, if emitted
+
     Id load_global_func_u32{};
     Id load_global_func_u32x2{};
     Id load_global_func_u32x4{};
@@ -372,6 +374,7 @@ private:
     void DefineTextures(const Info& info, u32& binding, u32& scaling_index);
     void DefineImages(const Info& info, u32& binding, u32& scaling_index);
     void DefineAttributeMemAccess(const Info& info);
+    void DefineWriteStorageCasLoopFunction(const Info& info);
     void DefineGlobalMemoryFunctions(const Info& info);
     void DefineRescalingInput(const Info& info);
     void DefineRescalingInputPushConstant();