From 01b21ee1e8e7455dd84ee7f22d33426caaaafdb3 Mon Sep 17 00:00:00 2001
From: Fernando Sahmkow <fsahmkow27@gmail.com>
Date: Tue, 25 Jun 2019 20:15:40 -0400
Subject: [PATCH] shader_ir: Corrections, documenting and asserting
 control_flow

---
 src/video_core/shader/control_flow.cpp | 80 ++++++++++++--------------
 src/video_core/shader/control_flow.h   | 16 ++++--
 src/video_core/shader/decode.cpp       | 10 ++--
 3 files changed, 54 insertions(+), 52 deletions(-)

diff --git a/src/video_core/shader/control_flow.cpp b/src/video_core/shader/control_flow.cpp
index c99d95b574..deef0cd3a8 100644
--- a/src/video_core/shader/control_flow.cpp
+++ b/src/video_core/shader/control_flow.cpp
@@ -1,3 +1,6 @@
+// Copyright 2019 yuzu Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
 
 #include <list>
 #include <map>
@@ -17,16 +20,18 @@ using Tegra::Shader::OpCode;
 
 constexpr s32 unassigned_branch = -2;
 
+/***
+ * 'ControlStack' represents a static stack of control jumps such as SSY and PBK
+ * stacks in Maxwell.
+ ***/
 struct ControlStack {
-    std::array<u32, 20> stack;
+    static constexpr std::size_t stack_fixed_size = 20;
+    std::array<u32, stack_fixed_size> stack{};
     u32 index{};
 
     ControlStack() {}
 
-    ControlStack(const ControlStack& cp) {
-        index = cp.index;
-        std::memcpy(stack.data(), cp.stack.data(), index * sizeof(u32));
-    }
+    ControlStack(const ControlStack& cp) = default;
 
     bool Compare(const ControlStack& cs) const {
         if (index != cs.index) {
@@ -35,6 +40,7 @@ struct ControlStack {
         return std::memcmp(stack.data(), cs.stack.data(), index * sizeof(u32)) == 0;
     }
 
+    /// This compare just compares the top of the stack against one another
     bool SoftCompare(const ControlStack& cs) const {
         if (index == 0 || cs.index == 0) {
             return index == cs.index;
@@ -51,7 +57,7 @@ struct ControlStack {
     }
 
     bool Push(u32 address) {
-        if (index >= 20) {
+        if (index >= stack.size()) {
             return false;
         }
         stack[index] = address;
@@ -70,21 +76,23 @@ struct ControlStack {
 
 struct Query {
     Query() {}
-    Query(const Query& q) : address{q.address}, ssy_stack{q.ssy_stack}, pbk_stack{q.pbk_stack} {}
-    u32 address;
+    Query(const Query& q) = default;
+    u32 address{};
     ControlStack ssy_stack{};
     ControlStack pbk_stack{};
 };
 
 struct BlockStack {
     BlockStack() = default;
-    BlockStack(const BlockStack& b) : ssy_stack{b.ssy_stack}, pbk_stack{b.pbk_stack} {}
+    BlockStack(const BlockStack& b) = default;
     BlockStack(const Query& q) : ssy_stack{q.ssy_stack}, pbk_stack{q.pbk_stack} {}
     ControlStack ssy_stack{};
     ControlStack pbk_stack{};
 };
 
 struct BlockBranchInfo {
+    BlockBranchInfo() = default;
+    BlockBranchInfo(const BlockBranchInfo& b) = default;
     Condition condition{};
     s32 address{exit_branch};
     bool kill{};
@@ -94,7 +102,7 @@ struct BlockBranchInfo {
 };
 
 struct BlockInfo {
-    BlockInfo() {}
+    BlockInfo() = default;
     u32 start{};
     u32 end{};
     bool visited{};
@@ -107,24 +115,15 @@ struct BlockInfo {
 
 struct CFGRebuildState {
     explicit CFGRebuildState(const ProgramCode& program_code, const std::size_t program_size)
-        : program_code{program_code}, program_size{program_size} {
-        // queries.clear();
-        block_info.clear();
-        labels.clear();
-        registered.clear();
-        ssy_labels.clear();
-        pbk_labels.clear();
-        inspect_queries.clear();
-        queries.clear();
-    }
+        : program_code{program_code}, program_size{program_size} {}
 
     std::vector<BlockInfo> block_info{};
     std::list<u32> inspect_queries{};
     std::list<Query> queries{};
     std::unordered_map<u32, u32> registered{};
     std::unordered_set<u32> labels{};
-    std::map<u32, u32> ssy_labels;
-    std::map<u32, u32> pbk_labels;
+    std::map<u32, u32> ssy_labels{};
+    std::map<u32, u32> pbk_labels{};
     std::unordered_map<u32, BlockStack> stacks{};
     const ProgramCode& program_code;
     const std::size_t program_size;
@@ -156,7 +155,7 @@ BlockInfo* CreateBlockInfo(CFGRebuildState& state, u32 start, u32 end) {
     auto& it = state.block_info.emplace_back();
     it.start = start;
     it.end = end;
-    u32 index = state.block_info.size() - 1;
+    const u32 index = static_cast<u32>(state.block_info.size() - 1);
     state.registered.insert({start, index});
     return &it;
 }
@@ -172,11 +171,10 @@ enum class ParseResult : u32 {
 };
 
 ParseResult ParseCode(CFGRebuildState& state, u32 address, ParseInfo& parse_info) {
-
     u32 offset = static_cast<u32>(address);
-    u32 end_address = static_cast<u32>(state.program_size - 10U) * 8U;
+    const u32 end_address = static_cast<u32>(state.program_size - 10U) * 8U;
 
-    auto insert_label = ([](CFGRebuildState& state, u32 address) {
+    const auto insert_label = ([](CFGRebuildState& state, u32 address) {
         auto pair = state.labels.emplace(address);
         if (pair.second) {
             state.inspect_queries.push_back(address);
@@ -361,20 +359,18 @@ bool TryInspectAddress(CFGRebuildState& state) {
     if (state.inspect_queries.empty()) {
         return false;
     }
-    u32 address = state.inspect_queries.front();
+    const u32 address = state.inspect_queries.front();
     state.inspect_queries.pop_front();
-    auto search_result = TryGetBlock(state, address);
-    BlockInfo* block_info;
+    const auto search_result = TryGetBlock(state, address);
     switch (search_result.first) {
     case BlockCollision::Found: {
         return true;
-        break;
     }
     case BlockCollision::Inside: {
         // This case is the tricky one:
         // We need to Split the block in 2 sepparate blocks
         auto it = search_result.second;
-        block_info = CreateBlockInfo(state, address, it->end);
+        BlockInfo* block_info = CreateBlockInfo(state, address, it->end);
         it->end = address - 1;
         block_info->branch = it->branch;
         BlockBranchInfo forward_branch{};
@@ -382,34 +378,32 @@ bool TryInspectAddress(CFGRebuildState& state) {
         forward_branch.ignore = true;
         it->branch = forward_branch;
         return true;
-        break;
     }
     default:
         break;
     }
     ParseInfo parse_info;
-    ParseResult parse_result = ParseCode(state, address, parse_info);
+    const ParseResult parse_result = ParseCode(state, address, parse_info);
     if (parse_result == ParseResult::AbnormalFlow) {
-        // if it's the end of the program, end it safely
         // if it's AbnormalFlow, we end it as false, ending the CFG reconstruction
         return false;
     }
 
-    block_info = CreateBlockInfo(state, address, parse_info.end_address);
+    BlockInfo* block_info = CreateBlockInfo(state, address, parse_info.end_address);
     block_info->branch = parse_info.branch_info;
     if (parse_info.branch_info.condition.IsUnconditional()) {
         return true;
     }
 
-    u32 fallthrough_address = parse_info.end_address + 1;
+    const u32 fallthrough_address = parse_info.end_address + 1;
     state.inspect_queries.push_front(fallthrough_address);
     return true;
 }
 
 bool TryQuery(CFGRebuildState& state) {
-    auto gather_labels = ([](ControlStack& cc, std::map<u32, u32>& labels, BlockInfo& block) {
+    const auto gather_labels = ([](ControlStack& cc, std::map<u32, u32>& labels, BlockInfo& block) {
         auto gather_start = labels.lower_bound(block.start);
-        auto gather_end = labels.upper_bound(block.end);
+        const auto gather_end = labels.upper_bound(block.end);
         while (gather_start != gather_end) {
             cc.Push(gather_start->second);
             gather_start++;
@@ -419,21 +413,21 @@ bool TryQuery(CFGRebuildState& state) {
         return false;
     }
     Query& q = state.queries.front();
-    u32 block_index = state.registered[q.address];
+    const u32 block_index = state.registered[q.address];
     BlockInfo& block = state.block_info[block_index];
     // If the block is visted, check if the stacks match, else gather the ssy/pbk
     // labels into the current stack and look if the branch at the end of the block
     // consumes a label. Schedule new queries accordingly
     if (block.visited) {
         BlockStack& stack = state.stacks[q.address];
-        bool all_okay = (stack.ssy_stack.Size() == 0 || q.ssy_stack.Compare(stack.ssy_stack)) &&
-                        (stack.pbk_stack.Size() == 0 || q.pbk_stack.Compare(stack.pbk_stack));
+        const bool all_okay =
+            (stack.ssy_stack.Size() == 0 || q.ssy_stack.Compare(stack.ssy_stack)) &&
+            (stack.pbk_stack.Size() == 0 || q.pbk_stack.Compare(stack.pbk_stack));
         state.queries.pop_front();
         return all_okay;
     }
     block.visited = true;
-    BlockStack bs{q};
-    state.stacks[q.address] = bs;
+    state.stacks[q.address] = BlockStack{q};
     Query q2(q);
     state.queries.pop_front();
     gather_labels(q2.ssy_stack, state.ssy_labels, block);
diff --git a/src/video_core/shader/control_flow.h b/src/video_core/shader/control_flow.h
index 4a2cd622c9..4689b0c102 100644
--- a/src/video_core/shader/control_flow.h
+++ b/src/video_core/shader/control_flow.h
@@ -1,3 +1,7 @@
+// Copyright 2019 yuzu Emulator Project
+// Licensed under GPLv2 or any later version
+// Refer to the license.txt file included.
+
 #pragma once
 
 #include <cstring>
@@ -20,12 +24,15 @@ struct Condition {
     ConditionCode cc{ConditionCode::T};
 
     bool IsUnconditional() const {
-        return (predicate == Pred::UnusedIndex) && (cc == ConditionCode::T);
+        return predicate == Pred::UnusedIndex && cc == ConditionCode::T;
+    }
+    bool operator==(const Condition& other) const {
+        return std::tie(predicate, cc) == std::tie(other.predicate, other.cc);
     }
 };
 
 struct ShaderBlock {
-    ShaderBlock() {}
+    ShaderBlock() = default;
     ShaderBlock(const ShaderBlock& sb) = default;
     u32 start{};
     u32 end{};
@@ -35,11 +42,12 @@ struct ShaderBlock {
         bool kills{};
         s32 address{};
         bool operator==(const Branch& b) const {
-            return std::memcmp(this, &b, sizeof(Branch)) == 0;
+            return std::tie(cond, kills, address) == std::tie(b.cond, b.kills, b.address);
         }
     } branch;
     bool operator==(const ShaderBlock& sb) const {
-        return std::memcmp(this, &sb, sizeof(ShaderBlock)) == 0;
+        return std::tie(start, end, ignore_branch, branch) ==
+               std::tie(sb.start, sb.end, sb.ignore_branch, sb.branch);
     }
 };
 
diff --git a/src/video_core/shader/decode.cpp b/src/video_core/shader/decode.cpp
index f9b1960daa..b4a282cbd1 100644
--- a/src/video_core/shader/decode.cpp
+++ b/src/video_core/shader/decode.cpp
@@ -46,7 +46,7 @@ void ShaderIR::Decode() {
         coverage_end = shader_info.end;
         if (shader_info.decompilable) {
             disable_flow_stack = true;
-            auto insert_block = ([this](NodeBlock& nodes, u32 label) {
+            const auto insert_block = ([this](NodeBlock& nodes, u32 label) {
                 if (label == exit_branch) {
                     return;
                 }
@@ -88,7 +88,6 @@ void ShaderIR::Decode() {
     for (u32 label = main_offset; label < shader_end; label++) {
         basic_blocks.insert({label, DecodeRange(label, label + 1)});
     }
-    return;
 }
 
 NodeBlock ShaderIR::DecodeRange(u32 begin, u32 end) {
@@ -104,16 +103,17 @@ void ShaderIR::DecodeRangeInner(NodeBlock& bb, u32 begin, u32 end) {
 }
 
 void ShaderIR::InsertControlFlow(NodeBlock& bb, const ShaderBlock& block) {
-    auto apply_conditions = ([&](const Condition& cond, Node n) -> Node {
+    const auto apply_conditions = ([&](const Condition& cond, Node n) -> Node {
         Node result = n;
         if (cond.cc != ConditionCode::T) {
             result = Conditional(GetConditionCode(cond.cc), {result});
         }
         if (cond.predicate != Pred::UnusedIndex) {
             u32 pred = static_cast<u32>(cond.predicate);
-            bool is_neg = pred > 7;
-            if (is_neg)
+            const bool is_neg = pred > 7;
+            if (is_neg) {
                 pred -= 8;
+            }
             result = Conditional(GetPredicate(pred, is_neg), {result});
         }
         return result;