From b0f1255c8cb800e9f336be66b3f16c3d958673d2 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Fri, 16 Apr 2021 01:55:06 -0300
Subject: [PATCH] shader: Implement PrimitiveId

---
 src/shader_recompiler/backend/spirv/emit_context.cpp           | 3 +++
 src/shader_recompiler/backend/spirv/emit_context.h             | 1 +
 .../backend/spirv/emit_spirv_context_get_set.cpp               | 2 ++
 src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp      | 3 +++
 src/shader_recompiler/shader_info.h                            | 1 +
 5 files changed, 10 insertions(+)

diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 067f616137..3946dab143 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -1021,6 +1021,9 @@ void EmitContext::DefineInputs(const Info& info) {
         fswzadd_lut_b =
             ConstantComposite(F32[4], f32_minus_one, f32_minus_one, f32_one, f32_minus_one);
     }
+    if (info.loads_primitive_id) {
+        primitive_id = DefineInput(*this, U32[1], false, spv::BuiltIn::PrimitiveId);
+    }
     if (info.loads_position) {
         const bool is_fragment{stage != Stage::Fragment};
         const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index ba0a253b35..c7d6f8a38f 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -167,6 +167,7 @@ public:
     Id clip_distances{};
     Id layer{};
     Id viewport_index{};
+    Id primitive_id{};
 
     Id fswzadd_lut_a{};
     Id fswzadd_lut_b{};
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index 4a1aeece5a..23a74f9666 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -210,6 +210,8 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex) {
         return type->needs_cast ? ctx.OpBitcast(ctx.F32[1], value) : value;
     }
     switch (attr) {
+    case IR::Attribute::PrimitiveId:
+        return ctx.OpLoad(ctx.U32[1], ctx.primitive_id);
     case IR::Attribute::PositionX:
     case IR::Attribute::PositionY:
     case IR::Attribute::PositionZ:
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index aadcf7999c..c84bf211fb 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -34,6 +34,9 @@ void GetAttribute(Info& info, IR::Attribute attribute) {
         return;
     }
     switch (attribute) {
+    case IR::Attribute::PrimitiveId:
+        info.loads_primitive_id = true;
+        break;
     case IR::Attribute::PositionX:
     case IR::Attribute::PositionY:
     case IR::Attribute::PositionZ:
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index 4dbf9ed12a..d6cde15960 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -107,6 +107,7 @@ struct Info {
     std::array<bool, 30> uses_patches{};
 
     std::array<InputVarying, 32> input_generics{};
+    bool loads_primitive_id{};
     bool loads_position{};
     bool loads_instance_id{};
     bool loads_vertex_id{};