From 395bed3a0af90a53be44e81eadd06f4931c8e933 Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Wed, 23 Jun 2021 02:41:00 -0300
Subject: [PATCH] shader: Unify shader stage types

---
 src/shader_recompiler/stage.h                 | 11 ++++++++--
 src/video_core/engines/kepler_compute.cpp     |  1 -
 src/video_core/engines/maxwell_3d.cpp         |  1 -
 src/video_core/engines/maxwell_3d.h           |  1 -
 src/video_core/engines/shader_type.h          | 21 -------------------
 src/video_core/renderer_opengl/gl_device.cpp  | 18 +++++++++-------
 src/video_core/renderer_opengl/gl_device.h    | 11 +++++-----
 .../renderer_opengl/gl_rasterizer.cpp         |  2 --
 .../renderer_opengl/gl_shader_cache.cpp       |  1 -
 .../renderer_opengl/gl_shader_cache.h         |  1 -
 .../renderer_vulkan/maxwell_to_vk.cpp         | 15 ++++++-------
 .../renderer_vulkan/maxwell_to_vk.h           |  3 ++-
 .../renderer_vulkan/vk_graphics_pipeline.cpp  |  2 +-
 .../renderer_vulkan/vk_rasterizer.cpp         |  2 --
 src/video_core/shader_environment.cpp         |  2 +-
 15 files changed, 37 insertions(+), 55 deletions(-)
 delete mode 100644 src/video_core/engines/shader_type.h

diff --git a/src/shader_recompiler/stage.h b/src/shader_recompiler/stage.h
index 7d4f2c0bba..5c1c8d8fc4 100644
--- a/src/shader_recompiler/stage.h
+++ b/src/shader_recompiler/stage.h
@@ -9,13 +9,20 @@
 namespace Shader {
 
 enum class Stage : u32 {
-    Compute,
-    VertexA,
     VertexB,
     TessellationControl,
     TessellationEval,
     Geometry,
     Fragment,
+
+    Compute,
+
+    VertexA,
 };
+constexpr u32 MaxStageTypes = 6;
+
+[[nodiscard]] constexpr Stage StageFromIndex(size_t index) noexcept {
+    return static_cast<Stage>(static_cast<size_t>(Stage::VertexB) + index);
+}
 
 } // namespace Shader
diff --git a/src/video_core/engines/kepler_compute.cpp b/src/video_core/engines/kepler_compute.cpp
index cae93c4709..492b4c5a38 100644
--- a/src/video_core/engines/kepler_compute.cpp
+++ b/src/video_core/engines/kepler_compute.cpp
@@ -8,7 +8,6 @@
 #include "core/core.h"
 #include "video_core/engines/kepler_compute.h"
 #include "video_core/engines/maxwell_3d.h"
-#include "video_core/engines/shader_type.h"
 #include "video_core/memory_manager.h"
 #include "video_core/rasterizer_interface.h"
 #include "video_core/renderer_base.h"
diff --git a/src/video_core/engines/maxwell_3d.cpp b/src/video_core/engines/maxwell_3d.cpp
index 103a51fd05..b18b8a02a3 100644
--- a/src/video_core/engines/maxwell_3d.cpp
+++ b/src/video_core/engines/maxwell_3d.cpp
@@ -8,7 +8,6 @@
 #include "core/core.h"
 #include "core/core_timing.h"
 #include "video_core/engines/maxwell_3d.h"
-#include "video_core/engines/shader_type.h"
 #include "video_core/gpu.h"
 #include "video_core/memory_manager.h"
 #include "video_core/rasterizer_interface.h"
diff --git a/src/video_core/engines/maxwell_3d.h b/src/video_core/engines/maxwell_3d.h
index 04d5790f68..fc2c36c6b4 100644
--- a/src/video_core/engines/maxwell_3d.h
+++ b/src/video_core/engines/maxwell_3d.h
@@ -20,7 +20,6 @@
 #include "video_core/engines/const_buffer_info.h"
 #include "video_core/engines/engine_interface.h"
 #include "video_core/engines/engine_upload.h"
-#include "video_core/engines/shader_type.h"
 #include "video_core/gpu.h"
 #include "video_core/macro/macro.h"
 #include "video_core/textures/texture.h"
diff --git a/src/video_core/engines/shader_type.h b/src/video_core/engines/shader_type.h
deleted file mode 100644
index 49ce5cde5d..0000000000
--- a/src/video_core/engines/shader_type.h
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2019 yuzu Emulator Project
-// Licensed under GPLv2 or any later version
-// Refer to the license.txt file included.
-
-#pragma once
-
-#include "common/common_types.h"
-
-namespace Tegra::Engines {
-
-enum class ShaderType : u32 {
-    Vertex = 0,
-    TesselationControl = 1,
-    TesselationEval = 2,
-    Geometry = 3,
-    Fragment = 4,
-    Compute = 5,
-};
-static constexpr std::size_t MaxShaderTypes = 6;
-
-} // namespace Tegra::Engines
diff --git a/src/video_core/renderer_opengl/gl_device.cpp b/src/video_core/renderer_opengl/gl_device.cpp
index b1b5ba1ab9..27be347e69 100644
--- a/src/video_core/renderer_opengl/gl_device.cpp
+++ b/src/video_core/renderer_opengl/gl_device.cpp
@@ -17,6 +17,7 @@
 #include "common/logging/log.h"
 #include "common/scope_exit.h"
 #include "common/settings.h"
+#include "shader_recompiler/stage.h"
 #include "video_core/renderer_opengl/gl_device.h"
 #include "video_core/renderer_opengl/gl_resource_manager.h"
 
@@ -59,16 +60,18 @@ bool HasExtension(std::span<const std::string_view> extensions, std::string_view
     return std::ranges::find(extensions, extension) != extensions.end();
 }
 
-std::array<u32, Tegra::Engines::MaxShaderTypes> BuildMaxUniformBuffers() noexcept {
-    std::array<u32, Tegra::Engines::MaxShaderTypes> max;
-    std::ranges::transform(LIMIT_UBOS, max.begin(),
-                           [](GLenum pname) { return GetInteger<u32>(pname); });
+std::array<u32, Shader::MaxStageTypes> BuildMaxUniformBuffers() noexcept {
+    std::array<u32, Shader::MaxStageTypes> max;
+    std::ranges::transform(LIMIT_UBOS, max.begin(), &GetInteger<u32>);
     return max;
 }
 
 bool IsASTCSupported() {
-    static constexpr std::array targets = {GL_TEXTURE_2D, GL_TEXTURE_2D_ARRAY};
-    static constexpr std::array formats = {
+    static constexpr std::array targets{
+        GL_TEXTURE_2D,
+        GL_TEXTURE_2D_ARRAY,
+    };
+    static constexpr std::array formats{
         GL_COMPRESSED_RGBA_ASTC_4x4_KHR,           GL_COMPRESSED_RGBA_ASTC_5x4_KHR,
         GL_COMPRESSED_RGBA_ASTC_5x5_KHR,           GL_COMPRESSED_RGBA_ASTC_6x5_KHR,
         GL_COMPRESSED_RGBA_ASTC_6x6_KHR,           GL_COMPRESSED_RGBA_ASTC_8x5_KHR,
@@ -84,11 +87,10 @@ bool IsASTCSupported() {
         GL_COMPRESSED_SRGB8_ALPHA8_ASTC_10x8_KHR,  GL_COMPRESSED_SRGB8_ALPHA8_ASTC_10x10_KHR,
         GL_COMPRESSED_SRGB8_ALPHA8_ASTC_12x10_KHR, GL_COMPRESSED_SRGB8_ALPHA8_ASTC_12x12_KHR,
     };
-    static constexpr std::array required_support = {
+    static constexpr std::array required_support{
         GL_VERTEX_TEXTURE,   GL_TESS_CONTROL_TEXTURE, GL_TESS_EVALUATION_TEXTURE,
         GL_GEOMETRY_TEXTURE, GL_FRAGMENT_TEXTURE,     GL_COMPUTE_TEXTURE,
     };
-
     for (const GLenum target : targets) {
         for (const GLenum format : formats) {
             for (const GLenum support : required_support) {
diff --git a/src/video_core/renderer_opengl/gl_device.h b/src/video_core/renderer_opengl/gl_device.h
index 0bd277d38d..ad7b01b063 100644
--- a/src/video_core/renderer_opengl/gl_device.h
+++ b/src/video_core/renderer_opengl/gl_device.h
@@ -6,7 +6,7 @@
 
 #include <cstddef>
 #include "common/common_types.h"
-#include "video_core/engines/shader_type.h"
+#include "shader_recompiler/stage.h"
 
 namespace OpenGL {
 
@@ -16,8 +16,8 @@ public:
 
     [[nodiscard]] std::string GetVendorName() const;
 
-    u32 GetMaxUniformBuffers(Tegra::Engines::ShaderType shader_type) const noexcept {
-        return max_uniform_buffers[static_cast<std::size_t>(shader_type)];
+    u32 GetMaxUniformBuffers(Shader::Stage stage) const noexcept {
+        return max_uniform_buffers[static_cast<size_t>(stage)];
     }
 
     size_t GetUniformBufferAlignment() const {
@@ -148,8 +148,7 @@ private:
     static bool TestVariableAoffi();
     static bool TestPreciseBug();
 
-    std::string vendor_name;
-    std::array<u32, Tegra::Engines::MaxShaderTypes> max_uniform_buffers{};
+    std::array<u32, Shader::MaxStageTypes> max_uniform_buffers{};
     size_t uniform_buffer_alignment{};
     size_t shader_storage_alignment{};
     u32 max_vertex_attributes{};
@@ -181,6 +180,8 @@ private:
     bool has_sparse_texture_2{};
     bool warp_size_potentially_larger_than_guest{};
     bool need_fastmath_off{};
+
+    std::string vendor_name;
 };
 
 } // namespace OpenGL
diff --git a/src/video_core/renderer_opengl/gl_rasterizer.cpp b/src/video_core/renderer_opengl/gl_rasterizer.cpp
index e3d336f86a..0f0d780b5e 100644
--- a/src/video_core/renderer_opengl/gl_rasterizer.cpp
+++ b/src/video_core/renderer_opengl/gl_rasterizer.cpp
@@ -23,7 +23,6 @@
 #include "core/memory.h"
 #include "video_core/engines/kepler_compute.h"
 #include "video_core/engines/maxwell_3d.h"
-#include "video_core/engines/shader_type.h"
 #include "video_core/memory_manager.h"
 #include "video_core/renderer_opengl/gl_device.h"
 #include "video_core/renderer_opengl/gl_query_cache.h"
@@ -40,7 +39,6 @@ namespace OpenGL {
 using Maxwell = Tegra::Engines::Maxwell3D::Regs;
 using GLvec4 = std::array<GLfloat, 4>;
 
-using Tegra::Engines::ShaderType;
 using VideoCore::Surface::PixelFormat;
 using VideoCore::Surface::SurfaceTarget;
 using VideoCore::Surface::SurfaceType;
diff --git a/src/video_core/renderer_opengl/gl_shader_cache.cpp b/src/video_core/renderer_opengl/gl_shader_cache.cpp
index f2f18b18ae..5af9b77451 100644
--- a/src/video_core/renderer_opengl/gl_shader_cache.cpp
+++ b/src/video_core/renderer_opengl/gl_shader_cache.cpp
@@ -26,7 +26,6 @@
 #include "shader_recompiler/profile.h"
 #include "video_core/engines/kepler_compute.h"
 #include "video_core/engines/maxwell_3d.h"
-#include "video_core/engines/shader_type.h"
 #include "video_core/memory_manager.h"
 #include "video_core/renderer_opengl/gl_rasterizer.h"
 #include "video_core/renderer_opengl/gl_resource_manager.h"
diff --git a/src/video_core/renderer_opengl/gl_shader_cache.h b/src/video_core/renderer_opengl/gl_shader_cache.h
index 16873fcece..9d53062937 100644
--- a/src/video_core/renderer_opengl/gl_shader_cache.h
+++ b/src/video_core/renderer_opengl/gl_shader_cache.h
@@ -17,7 +17,6 @@
 #include "shader_recompiler/host_translate_info.h"
 #include "shader_recompiler/object_pool.h"
 #include "shader_recompiler/profile.h"
-#include "video_core/engines/shader_type.h"
 #include "video_core/renderer_opengl/gl_compute_pipeline.h"
 #include "video_core/renderer_opengl/gl_graphics_pipeline.h"
 #include "video_core/renderer_opengl/gl_shader_context.h"
diff --git a/src/video_core/renderer_vulkan/maxwell_to_vk.cpp b/src/video_core/renderer_vulkan/maxwell_to_vk.cpp
index 8f0b0b8ecc..8f9b9a11a2 100644
--- a/src/video_core/renderer_vulkan/maxwell_to_vk.cpp
+++ b/src/video_core/renderer_vulkan/maxwell_to_vk.cpp
@@ -266,19 +266,20 @@ FormatInfo SurfaceFormat(const Device& device, FormatType format_type, bool with
     return {device.GetSupportedFormat(tuple.format, usage, format_type), attachable, storage};
 }
 
-VkShaderStageFlagBits ShaderStage(Tegra::Engines::ShaderType stage) {
+VkShaderStageFlagBits ShaderStage(Shader::Stage stage) {
     switch (stage) {
-    case Tegra::Engines::ShaderType::Vertex:
+    case Shader::Stage::VertexA:
+    case Shader::Stage::VertexB:
         return VK_SHADER_STAGE_VERTEX_BIT;
-    case Tegra::Engines::ShaderType::TesselationControl:
+    case Shader::Stage::TessellationControl:
         return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
-    case Tegra::Engines::ShaderType::TesselationEval:
+    case Shader::Stage::TessellationEval:
         return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
-    case Tegra::Engines::ShaderType::Geometry:
+    case Shader::Stage::Geometry:
         return VK_SHADER_STAGE_GEOMETRY_BIT;
-    case Tegra::Engines::ShaderType::Fragment:
+    case Shader::Stage::Fragment:
         return VK_SHADER_STAGE_FRAGMENT_BIT;
-    case Tegra::Engines::ShaderType::Compute:
+    case Shader::Stage::Compute:
         return VK_SHADER_STAGE_COMPUTE_BIT;
     }
     UNIMPLEMENTED_MSG("Unimplemented shader stage={}", stage);
diff --git a/src/video_core/renderer_vulkan/maxwell_to_vk.h b/src/video_core/renderer_vulkan/maxwell_to_vk.h
index 50a599c116..8a96160397 100644
--- a/src/video_core/renderer_vulkan/maxwell_to_vk.h
+++ b/src/video_core/renderer_vulkan/maxwell_to_vk.h
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "common/common_types.h"
+#include "shader_recompiler/stage.h"
 #include "video_core/engines/maxwell_3d.h"
 #include "video_core/surface.h"
 #include "video_core/textures/texture.h"
@@ -45,7 +46,7 @@ struct FormatInfo {
 [[nodiscard]] FormatInfo SurfaceFormat(const Device& device, FormatType format_type, bool with_srgb,
                                        PixelFormat pixel_format);
 
-VkShaderStageFlagBits ShaderStage(Tegra::Engines::ShaderType stage);
+VkShaderStageFlagBits ShaderStage(Shader::Stage stage);
 
 VkPrimitiveTopology PrimitiveTopology(const Device& device, Maxwell::PrimitiveTopology topology);
 
diff --git a/src/video_core/renderer_vulkan/vk_graphics_pipeline.cpp b/src/video_core/renderer_vulkan/vk_graphics_pipeline.cpp
index 2b59a9d880..9eb353a885 100644
--- a/src/video_core/renderer_vulkan/vk_graphics_pipeline.cpp
+++ b/src/video_core/renderer_vulkan/vk_graphics_pipeline.cpp
@@ -737,7 +737,7 @@ void GraphicsPipeline::MakePipeline(VkRenderPass render_pass) {
                 .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
                 .pNext = nullptr,
                 .flags = 0,
-                .stage = MaxwellToVK::ShaderStage(static_cast<Tegra::Engines::ShaderType>(stage)),
+                .stage = MaxwellToVK::ShaderStage(Shader::StageFromIndex(stage)),
                 .module = *spv_modules[stage],
                 .pName = "main",
                 .pSpecializationInfo = nullptr,
diff --git a/src/video_core/renderer_vulkan/vk_rasterizer.cpp b/src/video_core/renderer_vulkan/vk_rasterizer.cpp
index c57e16c508..f04c3394ca 100644
--- a/src/video_core/renderer_vulkan/vk_rasterizer.cpp
+++ b/src/video_core/renderer_vulkan/vk_rasterizer.cpp
@@ -58,8 +58,6 @@ struct DrawParams {
     bool is_indexed;
 };
 
-constexpr auto COMPUTE_SHADER_INDEX = static_cast<size_t>(Tegra::Engines::ShaderType::Compute);
-
 VkViewport GetViewportState(const Device& device, const Maxwell& regs, size_t index) {
     const auto& src = regs.viewport_transform[index];
     const float width = src.scale_x * 2.0f;
diff --git a/src/video_core/shader_environment.cpp b/src/video_core/shader_environment.cpp
index 6243cd1764..d463e2b560 100644
--- a/src/video_core/shader_environment.cpp
+++ b/src/video_core/shader_environment.cpp
@@ -22,7 +22,7 @@
 namespace VideoCommon {
 
 constexpr std::array<char, 8> MAGIC_NUMBER{'y', 'u', 'z', 'u', 'c', 'a', 'c', 'h'};
-constexpr u32 CACHE_VERSION = 3;
+constexpr u32 CACHE_VERSION = 4;
 
 constexpr size_t INST_SIZE = sizeof(u64);