From be320a9e10fda32a984b12cdfe3aaf09cc67b39a Mon Sep 17 00:00:00 2001
From: Fernando Sahmkow <fsahmkow27@gmail.com>
Date: Wed, 5 Feb 2020 15:48:20 -0400
Subject: [PATCH] Common: Polish Fiber class, add comments, asserts and more
 tests.

---
 src/common/fiber.cpp        | 55 ++++++++++++---------
 src/common/fiber.h          | 14 +++++-
 src/common/spin_lock.cpp    |  7 +++
 src/common/spin_lock.h      |  1 +
 src/tests/common/fibers.cpp | 95 ++++++++++++++++++++++++++++++++++++-
 5 files changed, 147 insertions(+), 25 deletions(-)

diff --git a/src/common/fiber.cpp b/src/common/fiber.cpp
index a2c0401c4d..a88a30cede 100644
--- a/src/common/fiber.cpp
+++ b/src/common/fiber.cpp
@@ -2,6 +2,7 @@
 // Licensed under GPLv2 or any later version
 // Refer to the license.txt file included.
 
+#include "common/assert.h"
 #include "common/fiber.h"
 #ifdef _MSC_VER
 #include <windows.h>
@@ -18,11 +19,11 @@ struct Fiber::FiberImpl {
 };
 
 void Fiber::start() {
-    if (previous_fiber) {
-        previous_fiber->guard.unlock();
-        previous_fiber = nullptr;
-    }
+    ASSERT(previous_fiber != nullptr);
+    previous_fiber->guard.unlock();
+    previous_fiber.reset();
     entry_point(start_parameter);
+    UNREACHABLE();
 }
 
 void __stdcall Fiber::FiberStartFunc(void* fiber_parameter)
@@ -43,12 +44,16 @@ Fiber::Fiber() : guard{}, entry_point{}, start_parameter{}, previous_fiber{} {
 
 Fiber::~Fiber() {
     // Make sure the Fiber is not being used
-    guard.lock();
-    guard.unlock();
+    bool locked = guard.try_lock();
+    ASSERT_MSG(locked, "Destroying a fiber that's still running");
+    if (locked) {
+        guard.unlock();
+    }
     DeleteFiber(impl->handle);
 }
 
 void Fiber::Exit() {
+    ASSERT_MSG(is_thread_fiber, "Exitting non main thread fiber");
     if (!is_thread_fiber) {
         return;
     }
@@ -57,14 +62,15 @@ void Fiber::Exit() {
 }
 
 void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) {
+    ASSERT_MSG(from != nullptr, "Yielding fiber is null!");
+    ASSERT_MSG(to != nullptr, "Next fiber is null!");
     to->guard.lock();
     to->previous_fiber = from;
     SwitchToFiber(to->impl->handle);
     auto previous_fiber = from->previous_fiber;
-    if (previous_fiber) {
-        previous_fiber->guard.unlock();
-        previous_fiber.reset();
-    }
+    ASSERT(previous_fiber != nullptr);
+    previous_fiber->guard.unlock();
+    previous_fiber.reset();
 }
 
 std::shared_ptr<Fiber> Fiber::ThreadToFiber() {
@@ -85,12 +91,12 @@ struct alignas(64) Fiber::FiberImpl {
 };
 
 void Fiber::start(boost::context::detail::transfer_t& transfer) {
-    if (previous_fiber) {
-        previous_fiber->impl->context = transfer.fctx;
-        previous_fiber->guard.unlock();
-        previous_fiber = nullptr;
-    }
+    ASSERT(previous_fiber != nullptr);
+    previous_fiber->impl->context = transfer.fctx;
+    previous_fiber->guard.unlock();
+    previous_fiber.reset();
     entry_point(start_parameter);
+    UNREACHABLE();
 }
 
 void Fiber::FiberStartFunc(boost::context::detail::transfer_t transfer)
@@ -113,11 +119,15 @@ Fiber::Fiber() : guard{}, entry_point{}, start_parameter{}, previous_fiber{} {
 
 Fiber::~Fiber() {
     // Make sure the Fiber is not being used
-    guard.lock();
-    guard.unlock();
+    bool locked = guard.try_lock();
+    ASSERT_MSG(locked, "Destroying a fiber that's still running");
+    if (locked) {
+        guard.unlock();
+    }
 }
 
 void Fiber::Exit() {
+    ASSERT_MSG(is_thread_fiber, "Exitting non main thread fiber");
     if (!is_thread_fiber) {
         return;
     }
@@ -125,15 +135,16 @@ void Fiber::Exit() {
 }
 
 void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) {
+    ASSERT_MSG(from != nullptr, "Yielding fiber is null!");
+    ASSERT_MSG(to != nullptr, "Next fiber is null!");
     to->guard.lock();
     to->previous_fiber = from;
     auto transfer = boost::context::detail::jump_fcontext(to->impl.context, nullptr);
     auto previous_fiber = from->previous_fiber;
-    if (previous_fiber) {
-        previous_fiber->impl->context = transfer.fctx;
-        previous_fiber->guard.unlock();
-        previous_fiber.reset();
-    }
+    ASSERT(previous_fiber != nullptr);
+    previous_fiber->impl->context = transfer.fctx;
+    previous_fiber->guard.unlock();
+    previous_fiber.reset();
 }
 
 std::shared_ptr<Fiber> Fiber::ThreadToFiber() {
diff --git a/src/common/fiber.h b/src/common/fiber.h
index 812d6644ac..89a01fdd8e 100644
--- a/src/common/fiber.h
+++ b/src/common/fiber.h
@@ -18,6 +18,18 @@ namespace boost::context::detail {
 
 namespace Common {
 
+/**
+ * Fiber class
+ * a fiber is a userspace thread with it's own context. They can be used to
+ * implement coroutines, emulated threading systems and certain asynchronous
+ * patterns.
+ *
+ * This class implements fibers at a low level, thus allowing greater freedom
+ * to implement such patterns. This fiber class is 'threadsafe' only one fiber
+ * can be running at a time and threads will be locked while trying to yield to
+ * a running fiber until it yields. WARNING exchanging two running fibers between
+ * threads will cause a deadlock.
+ */
 class Fiber {
 public:
     Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter);
@@ -53,8 +65,6 @@ private:
     static void FiberStartFunc(boost::context::detail::transfer_t transfer);
 #endif
 
-
-
     struct FiberImpl;
 
     SpinLock guard;
diff --git a/src/common/spin_lock.cpp b/src/common/spin_lock.cpp
index 8077b78d28..82a1d39fff 100644
--- a/src/common/spin_lock.cpp
+++ b/src/common/spin_lock.cpp
@@ -43,4 +43,11 @@ void SpinLock::unlock() {
     lck.clear(std::memory_order_release);
 }
 
+bool SpinLock::try_lock() {
+    if (lck.test_and_set(std::memory_order_acquire)) {
+        return false;
+    }
+    return true;
+}
+
 } // namespace Common
diff --git a/src/common/spin_lock.h b/src/common/spin_lock.h
index cbc67b6c85..70282a961f 100644
--- a/src/common/spin_lock.h
+++ b/src/common/spin_lock.h
@@ -12,6 +12,7 @@ class SpinLock {
 public:
     void lock();
     void unlock();
+    bool try_lock();
 
 private:
     std::atomic_flag lck = ATOMIC_FLAG_INIT;
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp
index ff840afa64..358393a192 100644
--- a/src/tests/common/fibers.cpp
+++ b/src/tests/common/fibers.cpp
@@ -64,7 +64,9 @@ static void ThreadStart1(u32 id, TestControl1& test_control) {
     test_control.ExecuteThread(id);
 }
 
-
+/** This test checks for fiber setup configuration and validates that fibers are
+ *  doing all the work required.
+ */
 TEST_CASE("Fibers::Setup", "[common]") {
     constexpr u32 num_threads = 7;
     TestControl1 test_control{};
@@ -188,6 +190,10 @@ static void ThreadStart2_2(u32 id, TestControl2& test_control) {
     test_control.Exit();
 }
 
+/** This test checks for fiber thread exchange configuration and validates that fibers are
+ *  that a fiber has been succesfully transfered from one thread to another and that the TLS
+ *  region of the thread is kept while changing fibers.
+ */
 TEST_CASE("Fibers::InterExchange", "[common]") {
     TestControl2 test_control{};
     test_control.thread_fibers.resize(2, nullptr);
@@ -210,5 +216,92 @@ TEST_CASE("Fibers::InterExchange", "[common]") {
     REQUIRE(test_control.value1 == cal_value);
 }
 
+class TestControl3 {
+public:
+    TestControl3() = default;
+
+    void DoWork1() {
+        value1 += 1;
+        Fiber::YieldTo(fiber1, fiber2);
+        std::thread::id this_id = std::this_thread::get_id();
+        u32 id = ids[this_id];
+        value3 += 1;
+        Fiber::YieldTo(fiber1, thread_fibers[id]);
+    }
+
+    void DoWork2() {
+        value2 += 1;
+        std::thread::id this_id = std::this_thread::get_id();
+        u32 id = ids[this_id];
+        Fiber::YieldTo(fiber2, thread_fibers[id]);
+    }
+
+    void ExecuteThread(u32 id);
+
+    void CallFiber1() {
+        std::thread::id this_id = std::this_thread::get_id();
+        u32 id = ids[this_id];
+        Fiber::YieldTo(thread_fibers[id], fiber1);
+    }
+
+    void Exit();
+
+    u32 value1{};
+    u32 value2{};
+    u32 value3{};
+    std::unordered_map<std::thread::id, u32> ids;
+    std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
+    std::shared_ptr<Common::Fiber> fiber1;
+    std::shared_ptr<Common::Fiber> fiber2;
+};
+
+static void WorkControl3_1(void* control) {
+    TestControl3* test_control = static_cast<TestControl3*>(control);
+    test_control->DoWork1();
+}
+
+static void WorkControl3_2(void* control) {
+    TestControl3* test_control = static_cast<TestControl3*>(control);
+    test_control->DoWork2();
+}
+
+void TestControl3::ExecuteThread(u32 id) {
+    std::thread::id this_id = std::this_thread::get_id();
+    ids[this_id] = id;
+    auto thread_fiber = Fiber::ThreadToFiber();
+    thread_fibers[id] = thread_fiber;
+}
+
+void TestControl3::Exit() {
+    std::thread::id this_id = std::this_thread::get_id();
+    u32 id = ids[this_id];
+    thread_fibers[id]->Exit();
+}
+
+static void ThreadStart3(u32 id, TestControl3& test_control) {
+    test_control.ExecuteThread(id);
+    test_control.CallFiber1();
+    test_control.Exit();
+}
+
+/** This test checks for one two threads racing for starting the same fiber.
+ *  It checks execution occured in an ordered manner and by no time there were
+ *  two contexts at the same time.
+ */
+TEST_CASE("Fibers::StartRace", "[common]") {
+    TestControl3 test_control{};
+    test_control.thread_fibers.resize(2, nullptr);
+    test_control.fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_1}, &test_control);
+    test_control.fiber2 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_2}, &test_control);
+    std::thread thread1(ThreadStart3, 0, std::ref(test_control));
+    std::thread thread2(ThreadStart3, 1, std::ref(test_control));
+    thread1.join();
+    thread2.join();
+    REQUIRE(test_control.value1 == 1);
+    REQUIRE(test_control.value2 == 1);
+    REQUIRE(test_control.value3 == 1);
+}
+
+
 
 } // namespace Common