From e01c5351781f7c0a8f5e6469a33563f26f8f0779 Mon Sep 17 00:00:00 2001
From: Liam <byteslice@airmail.cc>
Date: Sun, 17 Dec 2023 02:01:25 -0500
Subject: [PATCH] oboe_sink: implement channel count querying

---
 src/audio_core/sink/oboe_sink.cpp | 47 +++++++++++++++++++++++--------
 1 file changed, 35 insertions(+), 12 deletions(-)

diff --git a/src/audio_core/sink/oboe_sink.cpp b/src/audio_core/sink/oboe_sink.cpp
index fc62faaeee..c502205e1d 100644
--- a/src/audio_core/sink/oboe_sink.cpp
+++ b/src/audio_core/sink/oboe_sink.cpp
@@ -20,11 +20,10 @@ class OboeSinkStream final : public SinkStream,
                              public oboe::AudioStreamErrorCallback {
 public:
     explicit OboeSinkStream(Core::System& system_, StreamType type_, const std::string& name_,
-                            u32 device_channels_, u32 system_channels_)
+                            u32 system_channels_)
         : SinkStream(system_, type_) {
         name = name_;
         system_channels = system_channels_;
-        device_channels = device_channels_;
 
         this->OpenStream();
     }
@@ -62,6 +61,21 @@ public:
         }
     }
 
+public:
+    static s32 QueryChannelCount(oboe::Direction direction) {
+        std::shared_ptr<oboe::AudioStream> temp_stream;
+        oboe::AudioStreamBuilder builder;
+
+        const auto result = builder.setDirection(direction)
+                                ->setSampleRate(TargetSampleRate)
+                                ->setFormat(oboe::AudioFormat::I16)
+                                ->setFormatConversionAllowed(true)
+                                ->openStream(temp_stream);
+        ASSERT(result == oboe::Result::OK);
+
+        return temp_stream->getChannelCount() >= 6 ? 6 : 2;
+    }
+
 protected:
     oboe::DataCallbackResult onAudioReady(oboe::AudioStream*, void* audio_data,
                                           s32 num_buffer_frames) override {
@@ -105,8 +119,9 @@ private:
             }
         }();
 
-        const auto channel_mask = [&]() {
-            switch (device_channels) {
+        const auto expected_channels = QueryChannelCount(direction);
+        const auto expected_mask = [&]() {
+            switch (expected_channels) {
             case 1:
                 return oboe::ChannelMask::Mono;
             case 2:
@@ -122,25 +137,33 @@ private:
         oboe::AudioStreamBuilder builder;
         const auto result = builder.setDirection(direction)
                                 ->setSampleRate(TargetSampleRate)
-                                ->setChannelCount(device_channels)
-                                ->setChannelMask(channel_mask)
+                                ->setChannelCount(expected_channels)
+                                ->setChannelMask(expected_mask)
                                 ->setFormat(oboe::AudioFormat::I16)
                                 ->setFormatConversionAllowed(true)
                                 ->setDataCallback(this)
                                 ->setErrorCallback(this)
                                 ->openStream(m_stream);
-
         ASSERT(result == oboe::Result::OK);
-        return result == oboe::Result::OK;
+        return result == oboe::Result::OK && this->SetStreamProperties();
+    }
+
+    bool SetStreamProperties() {
+        ASSERT(m_stream);
+
+        device_channels = m_stream->getChannelCount();
+        LOG_INFO(Audio_Sink, "Opened Oboe stream with {} channels", device_channels);
+
+        return true;
     }
 
     std::shared_ptr<oboe::AudioStream> m_stream{};
 };
 
 OboeSink::OboeSink() {
-    // TODO: how do we get the number of channels, or device list?
-    // This seems to be missing from NDK.
-    device_channels = 2;
+    // TODO: This is not generally knowable
+    // The channel count is distinct based on direction and can change
+    device_channels = OboeSinkStream::QueryChannelCount(oboe::Direction::Output);
 }
 
 OboeSink::~OboeSink() = default;
@@ -148,7 +171,7 @@ OboeSink::~OboeSink() = default;
 SinkStream* OboeSink::AcquireSinkStream(Core::System& system, u32 system_channels,
                                         const std::string& name, StreamType type) {
     SinkStreamPtr& stream = sink_streams.emplace_back(
-        std::make_unique<OboeSinkStream>(system, type, name, device_channels, system_channels));
+        std::make_unique<OboeSinkStream>(system, type, name, system_channels));
 
     return stream.get();
 }