Refactor AudioFileFetch using async/await

Previously, polling `AudioFileFetch` consisted of three parts: handling
stream loader commands, handling received data, and triggering preloading
in stream mode when the number of open requests is sufficiently small.
The first two steps used channels that were polled and, if something was
available, handled. The third step was executed on every call of `poll`.

The first two could easily be refactored into a `tokio::select!` loop.
Therefore, counting the number of open requests was also refactored to
fit into this scheme. Previously, they were counted in a shared
`AtomicUsize`. Now, the number of open requests is stored exclusively in
`AudioFileFetch`, incremented when a request is started, and decremented
when a message arrives on a channel that each request task fires on
completion.
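
A minimal sketch of that bookkeeping, using the names from the new code
in the diff below:

    // The fetch task owns the counter; request tasks only send a signal.
    let (download_finish_tx, mut download_finish_rx) = mpsc::unbounded_channel::<()>();

    // When a request is started:
    fetch.number_of_open_requests += 1;
    session.spawn(audio_file_fetch_receive_data(
        /* ... */
        download_finish_tx.clone(),
    ));

    // At the end of audio_file_fetch_receive_data:
    let _ = finish_tx.send(());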

This allows us to `select!` on that channel in the loop as well, and
since loading ahead only makes sense when the number of open requests
decreases, the third step is now executed only in that case.

`AudioFileFetch` no longer implements `Future`; instead, it is used as a
helper struct inside the async fn `audio_file_fetch`.
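
In condensed form, the new control loop (simplified from the diff below;
error handling and the preload heuristics are omitted) looks roughly like
this:

    loop {
        tokio::select! {
            // 1. stream loader commands
            cmd = stream_loader_command_rx.recv() => { /* handle command or break */ }
            // 2. received data
            data = file_data_rx.recv() => { /* write to file, update ping estimate */ }
            // 3. a request finished: update the counter and, in streaming
            //    mode, consider loading ahead
            _ = download_finish_rx.recv() => {
                fetch.number_of_open_requests -= 1;
                if fetch.get_download_strategy() == DownloadStrategy::Streaming() {
                    fetch.trigger_preload();
                }
            }
        }
    }
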
author johannesd3 2021-02-28 11:36:14 +01:00
parent 173a36332f
commit e71a004e93

@@ -1,11 +1,8 @@
 use std::cmp::{max, min};
 use std::fs;
-use std::future::Future;
 use std::io::{self, Read, Seek, SeekFrom, Write};
-use std::pin::Pin;
 use std::sync::atomic::{self, AtomicUsize};
 use std::sync::{Arc, Condvar, Mutex};
-use std::task::{Context, Poll};
 use std::time::{Duration, Instant};
 
 use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
@@ -236,7 +233,7 @@ struct AudioFileDownloadStatus {
     downloaded: RangeSet,
 }
 
-#[derive(Copy, Clone)]
+#[derive(Copy, Clone, PartialEq, Eq)]
 enum DownloadStrategy {
     RandomAccess(),
     Streaming(),
@@ -249,7 +246,6 @@ struct AudioFileShared {
     cond: Condvar,
     download_status: Mutex<AudioFileDownloadStatus>,
     download_strategy: Mutex<DownloadStrategy>,
-    number_of_open_requests: AtomicUsize,
     ping_time_ms: AtomicUsize,
     read_position: AtomicUsize,
 }
@@ -358,7 +354,6 @@ impl AudioFileStreaming {
                 downloaded: RangeSet::new(),
             }),
             download_strategy: Mutex::new(DownloadStrategy::RandomAccess()), // start with random access mode until someone tells us otherwise
-            number_of_open_requests: AtomicUsize::new(0),
             ping_time_ms: AtomicUsize::new(0),
             read_position: AtomicUsize::new(0),
         });
@@ -373,7 +368,7 @@ impl AudioFileStreaming {
         let (stream_loader_command_tx, stream_loader_command_rx) =
             mpsc::unbounded_channel::<StreamLoaderCommand>();
 
-        let fetcher = AudioFileFetch::new(
+        session.spawn(audio_file_fetch(
             session.clone(),
             shared.clone(),
             initial_data_rx,
@@ -382,9 +377,8 @@ impl AudioFileStreaming {
             write_file,
             stream_loader_command_rx,
             complete_tx,
-        );
-        session.spawn(fetcher);
+        ));
 
         Ok(AudioFileStreaming {
             read_file,
             position: 0,
@@ -442,17 +436,11 @@ async fn audio_file_fetch_receive_data(
     initial_data_offset: usize,
     initial_request_length: usize,
     request_sent_time: Instant,
+    mut measure_ping_time: bool,
+    finish_tx: mpsc::UnboundedSender<()>,
 ) {
     let mut data_offset = initial_data_offset;
     let mut request_length = initial_request_length;
-    let mut measure_ping_time = shared
-        .number_of_open_requests
-        .load(atomic::Ordering::SeqCst)
-        == 0;
-    shared
-        .number_of_open_requests
-        .fetch_add(1, atomic::Ordering::SeqCst);
 
     let result = loop {
         let data = match data_rx.next().await {
@@ -501,9 +489,7 @@ async fn audio_file_fetch_receive_data(
         shared.cond.notify_all();
     }
 
-    shared
-        .number_of_open_requests
-        .fetch_sub(1, atomic::Ordering::SeqCst);
+    let _ = finish_tx.send(());
 
     if result.is_err() {
         warn!(
@@ -517,162 +503,6 @@ async fn audio_file_fetch_receive_data(
         );
     }
 }
 
-/*
-async fn audio_file_fetch(
-    session: Session,
-    shared: Arc<AudioFileShared>,
-    initial_data_rx: ChannelData,
-    initial_request_sent_time: Instant,
-    initial_data_length: usize,
-    output: NamedTempFile,
-    stream_loader_command_rx: mpsc::UnboundedReceiver<StreamLoaderCommand>,
-    complete_tx: oneshot::Sender<NamedTempFile>,
-) {
-    let (file_data_tx, file_data_rx) = unbounded::<ReceivedData>();
-
-    let requested_range = Range::new(0, initial_data_length);
-    let mut download_status = shared.download_status.lock().unwrap();
-    download_status.requested.add_range(&requested_range);
-
-    session.spawn(audio_file_fetch_receive_data(
-        shared.clone(),
-        file_data_tx.clone(),
-        initial_data_rx,
-        0,
-        initial_data_length,
-        initial_request_sent_time,
-    ));
-
-    let mut network_response_times_ms: Vec::new();
-
-    let f1 = file_data_rx.map(|x| Ok::<_, ()>(x)).try_for_each(|x| {
-        match x {
-            ReceivedData::ResponseTimeMs(response_time_ms) => {
-                trace!("Ping time estimated as: {} ms.", response_time_ms);
-
-                // record the response time
-                network_response_times_ms.push(response_time_ms);
-
-                // prune old response times. Keep at most three.
-                while network_response_times_ms.len() > 3 {
-                    network_response_times_ms.remove(0);
-                }
-
-                // stats::median is experimental. So we calculate the median of up to three ourselves.
-                let ping_time_ms: usize = match network_response_times_ms.len() {
-                    1 => network_response_times_ms[0] as usize,
-                    2 => {
-                        ((network_response_times_ms[0] + network_response_times_ms[1]) / 2) as usize
-                    }
-                    3 => {
-                        let mut times = network_response_times_ms.clone();
-                        times.sort();
-                        times[1]
-                    }
-                    _ => unreachable!(),
-                };
-
-                // store our new estimate for everyone to see
-                shared
-                    .ping_time_ms
-                    .store(ping_time_ms, atomic::Ordering::Relaxed);
-            }
-            ReceivedData::Data(data) => {
-                output
-                    .as_mut()
-                    .unwrap()
-                    .seek(SeekFrom::Start(data.offset as u64))
-                    .unwrap();
-                output
-                    .as_mut()
-                    .unwrap()
-                    .write_all(data.data.as_ref())
-                    .unwrap();
-
-                let mut full = false;
-
-                {
-                    let mut download_status = shared.download_status.lock().unwrap();
-
-                    let received_range = Range::new(data.offset, data.data.len());
-                    download_status.downloaded.add_range(&received_range);
-                    shared.cond.notify_all();
-
-                    if download_status.downloaded.contained_length_from_value(0)
-                        >= shared.file_size
-                    {
-                        full = true;
-                    }
-
-                    drop(download_status);
-                }
-
-                if full {
-                    self.finish();
-                    return future::ready(Err(()));
-                }
-            }
-        }
-        future::ready(Ok(()))
-    });
-
-    let f2 = stream_loader_command_rx.map(Ok::<_, ()>).try_for_each(|x| {
-        match cmd {
-            StreamLoaderCommand::Fetch(request) => {
-                self.download_range(request.start, request.length);
-            }
-            StreamLoaderCommand::RandomAccessMode() => {
-                *(shared.download_strategy.lock().unwrap()) = DownloadStrategy::RandomAccess();
-            }
-            StreamLoaderCommand::StreamMode() => {
-                *(shared.download_strategy.lock().unwrap()) = DownloadStrategy::Streaming();
-            }
-            StreamLoaderCommand::Close() => return future::ready(Err(())),
-        }
-        Ok(())
-    });
-
-    let f3 = future::poll_fn(|_| {
-        if let DownloadStrategy::Streaming() = self.get_download_strategy() {
-            let number_of_open_requests = shared
-                .number_of_open_requests
-                .load(atomic::Ordering::SeqCst);
-            let max_requests_to_send =
-                MAX_PREFETCH_REQUESTS - min(MAX_PREFETCH_REQUESTS, number_of_open_requests);
-
-            if max_requests_to_send > 0 {
-                let bytes_pending: usize = {
-                    let download_status = shared.download_status.lock().unwrap();
-                    download_status
-                        .requested
-                        .minus(&download_status.downloaded)
-                        .len()
-                };
-
-                let ping_time_seconds =
-                    0.001 * shared.ping_time_ms.load(atomic::Ordering::Relaxed) as f64;
-                let download_rate = session.channel().get_download_rate_estimate();
-
-                let desired_pending_bytes = max(
-                    (PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * shared.stream_data_rate as f64)
-                        as usize,
-                    (FAST_PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * download_rate as f64)
-                        as usize,
-                );
-
-                if bytes_pending < desired_pending_bytes {
-                    self.pre_fetch_more_data(
-                        desired_pending_bytes - bytes_pending,
-                        max_requests_to_send,
-                    );
-                }
-            }
-        }
-        Poll::Pending
-    });
-
-    future::select_all(vec![f1, f2, f3]).await
-}*/
-
 struct AudioFileFetch {
     session: Session,
@@ -680,54 +510,21 @@ struct AudioFileFetch {
     output: Option<NamedTempFile>,
 
     file_data_tx: mpsc::UnboundedSender<ReceivedData>,
-    file_data_rx: mpsc::UnboundedReceiver<ReceivedData>,
-
-    stream_loader_command_rx: mpsc::UnboundedReceiver<StreamLoaderCommand>,
     complete_tx: Option<oneshot::Sender<NamedTempFile>>,
     network_response_times_ms: Vec<usize>,
+    number_of_open_requests: usize,
+    download_finish_tx: mpsc::UnboundedSender<()>,
+}
+
+// Might be replaced by enum from std once stable
+#[derive(PartialEq, Eq)]
+enum ControlFlow {
+    Break,
+    Continue,
 }
 
 impl AudioFileFetch {
-    fn new(
-        session: Session,
-        shared: Arc<AudioFileShared>,
-        initial_data_rx: ChannelData,
-        initial_request_sent_time: Instant,
-        initial_data_length: usize,
-        output: NamedTempFile,
-        stream_loader_command_rx: mpsc::UnboundedReceiver<StreamLoaderCommand>,
-        complete_tx: oneshot::Sender<NamedTempFile>,
-    ) -> AudioFileFetch {
-        let (file_data_tx, file_data_rx) = mpsc::unbounded_channel::<ReceivedData>();
-
-        {
-            let requested_range = Range::new(0, initial_data_length);
-            let mut download_status = shared.download_status.lock().unwrap();
-            download_status.requested.add_range(&requested_range);
-        }
-
-        session.spawn(audio_file_fetch_receive_data(
-            shared.clone(),
-            file_data_tx.clone(),
-            initial_data_rx,
-            0,
-            initial_data_length,
-            initial_request_sent_time,
-        ));
-
-        AudioFileFetch {
-            session,
-            shared,
-            output: Some(output),
-            file_data_tx,
-            file_data_rx,
-            stream_loader_command_rx,
-            complete_tx: Some(complete_tx),
-            network_response_times_ms: Vec::new(),
-        }
-    }
-
     fn get_download_strategy(&mut self) -> DownloadStrategy {
         *(self.shared.download_strategy.lock().unwrap())
     }
@@ -785,7 +582,11 @@ impl AudioFileFetch {
                 range.start,
                 range.length,
                 Instant::now(),
+                self.number_of_open_requests == 0,
+                self.download_finish_tx.clone(),
             ));
+
+            self.number_of_open_requests += 1;
         }
     }
@@ -833,103 +634,86 @@ impl AudioFileFetch {
         }
     }
 
-    fn poll_file_data_rx(&mut self, cx: &mut Context<'_>) -> Poll<()> {
-        loop {
-            match self.file_data_rx.poll_recv(cx) {
-                Poll::Ready(None) => return Poll::Ready(()),
-                Poll::Ready(Some(ReceivedData::ResponseTimeMs(response_time_ms))) => {
-                    trace!("Ping time estimated as: {} ms.", response_time_ms);
-
-                    // record the response time
-                    self.network_response_times_ms.push(response_time_ms);
-
-                    // prune old response times. Keep at most three.
-                    while self.network_response_times_ms.len() > 3 {
-                        self.network_response_times_ms.remove(0);
-                    }
-
-                    // stats::median is experimental. So we calculate the median of up to three ourselves.
-                    let ping_time_ms: usize = match self.network_response_times_ms.len() {
-                        1 => self.network_response_times_ms[0] as usize,
-                        2 => {
-                            ((self.network_response_times_ms[0]
-                                + self.network_response_times_ms[1])
-                                / 2) as usize
-                        }
-                        3 => {
-                            let mut times = self.network_response_times_ms.clone();
-                            times.sort_unstable();
-                            times[1]
-                        }
-                        _ => unreachable!(),
-                    };
-
-                    // store our new estimate for everyone to see
-                    self.shared
-                        .ping_time_ms
-                        .store(ping_time_ms, atomic::Ordering::Relaxed);
-                }
-                Poll::Ready(Some(ReceivedData::Data(data))) => {
-                    self.output
-                        .as_mut()
-                        .unwrap()
-                        .seek(SeekFrom::Start(data.offset as u64))
-                        .unwrap();
-                    self.output
-                        .as_mut()
-                        .unwrap()
-                        .write_all(data.data.as_ref())
-                        .unwrap();
-
-                    let mut full = false;
-
-                    {
-                        let mut download_status = self.shared.download_status.lock().unwrap();
-
-                        let received_range = Range::new(data.offset, data.data.len());
-                        download_status.downloaded.add_range(&received_range);
-                        self.shared.cond.notify_all();
-
-                        if download_status.downloaded.contained_length_from_value(0)
-                            >= self.shared.file_size
-                        {
-                            full = true;
-                        }
-
-                        drop(download_status);
-                    }
-
-                    if full {
-                        self.finish();
-                        return Poll::Ready(());
-                    }
-                }
-                Poll::Pending => return Poll::Pending,
-            }
-        }
-    }
-
-    fn poll_stream_loader_command_rx(&mut self, cx: &mut Context<'_>) -> Poll<()> {
-        loop {
-            match self.stream_loader_command_rx.poll_recv(cx) {
-                Poll::Ready(None) => return Poll::Ready(()),
-                Poll::Ready(Some(cmd)) => match cmd {
-                    StreamLoaderCommand::Fetch(request) => {
-                        self.download_range(request.start, request.length);
-                    }
-                    StreamLoaderCommand::RandomAccessMode() => {
-                        *(self.shared.download_strategy.lock().unwrap()) =
-                            DownloadStrategy::RandomAccess();
-                    }
-                    StreamLoaderCommand::StreamMode() => {
-                        *(self.shared.download_strategy.lock().unwrap()) =
-                            DownloadStrategy::Streaming();
-                    }
-                    StreamLoaderCommand::Close() => return Poll::Ready(()),
-                },
-                Poll::Pending => return Poll::Pending,
-            }
-        }
-    }
+    fn handle_file_data(&mut self, data: ReceivedData) -> ControlFlow {
+        match data {
+            ReceivedData::ResponseTimeMs(response_time_ms) => {
+                trace!("Ping time estimated as: {} ms.", response_time_ms);
+
+                // record the response time
+                self.network_response_times_ms.push(response_time_ms);
+
+                // prune old response times. Keep at most three.
+                while self.network_response_times_ms.len() > 3 {
+                    self.network_response_times_ms.remove(0);
+                }
+
+                // stats::median is experimental. So we calculate the median of up to three ourselves.
+                let ping_time_ms: usize = match self.network_response_times_ms.len() {
+                    1 => self.network_response_times_ms[0] as usize,
+                    2 => {
+                        ((self.network_response_times_ms[0] + self.network_response_times_ms[1])
+                            / 2) as usize
+                    }
+                    3 => {
+                        let mut times = self.network_response_times_ms.clone();
+                        times.sort_unstable();
+                        times[1]
+                    }
+                    _ => unreachable!(),
+                };
+
+                // store our new estimate for everyone to see
+                self.shared
+                    .ping_time_ms
+                    .store(ping_time_ms, atomic::Ordering::Relaxed);
+            }
+            ReceivedData::Data(data) => {
+                self.output
+                    .as_mut()
+                    .unwrap()
+                    .seek(SeekFrom::Start(data.offset as u64))
+                    .unwrap();
+                self.output
+                    .as_mut()
+                    .unwrap()
+                    .write_all(data.data.as_ref())
+                    .unwrap();
+
+                let mut download_status = self.shared.download_status.lock().unwrap();
+
+                let received_range = Range::new(data.offset, data.data.len());
+                download_status.downloaded.add_range(&received_range);
+                self.shared.cond.notify_all();
+
+                let full = download_status.downloaded.contained_length_from_value(0)
+                    >= self.shared.file_size;
+
+                drop(download_status);
+
+                if full {
+                    self.finish();
+                    return ControlFlow::Break;
+                }
+            }
+        }
+        ControlFlow::Continue
+    }
+
+    fn handle_stream_loader_command(&mut self, cmd: StreamLoaderCommand) -> ControlFlow {
+        match cmd {
+            StreamLoaderCommand::Fetch(request) => {
+                self.download_range(request.start, request.length);
+            }
+            StreamLoaderCommand::RandomAccessMode() => {
+                *(self.shared.download_strategy.lock().unwrap()) = DownloadStrategy::RandomAccess();
+            }
+            StreamLoaderCommand::StreamMode() => {
+                *(self.shared.download_strategy.lock().unwrap()) = DownloadStrategy::Streaming();
+                self.trigger_preload();
+            }
+            StreamLoaderCommand::Close() => return ControlFlow::Break,
+        }
+        ControlFlow::Continue
+    }
 
     fn finish(&mut self) {
@@ -939,57 +723,102 @@ impl AudioFileFetch {
         output.seek(SeekFrom::Start(0)).unwrap();
         let _ = complete_tx.send(output);
     }
+
+    fn trigger_preload(&mut self) {
+        if self.number_of_open_requests >= MAX_PREFETCH_REQUESTS {
+            return;
+        }
+
+        let max_requests_to_send = MAX_PREFETCH_REQUESTS - self.number_of_open_requests;
+
+        let bytes_pending: usize = {
+            let download_status = self.shared.download_status.lock().unwrap();
+            download_status
+                .requested
+                .minus(&download_status.downloaded)
+                .len()
+        };
+
+        let ping_time_seconds =
+            0.001 * self.shared.ping_time_ms.load(atomic::Ordering::Relaxed) as f64;
+        let download_rate = self.session.channel().get_download_rate_estimate();
+
+        let desired_pending_bytes = max(
+            (PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * self.shared.stream_data_rate as f64)
+                as usize,
+            (FAST_PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * download_rate as f64) as usize,
+        );
+
+        if bytes_pending < desired_pending_bytes {
+            self.pre_fetch_more_data(desired_pending_bytes - bytes_pending, max_requests_to_send);
+        }
+    }
 }
 
-impl Future for AudioFileFetch {
-    type Output = ();
-
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
-        if let Poll::Ready(()) = self.poll_stream_loader_command_rx(cx) {
-            return Poll::Ready(());
-        }
-
-        if let Poll::Ready(()) = self.poll_file_data_rx(cx) {
-            return Poll::Ready(());
-        }
-
-        if let DownloadStrategy::Streaming() = self.get_download_strategy() {
-            let number_of_open_requests = self
-                .shared
-                .number_of_open_requests
-                .load(atomic::Ordering::SeqCst);
-            let max_requests_to_send =
-                MAX_PREFETCH_REQUESTS - min(MAX_PREFETCH_REQUESTS, number_of_open_requests);
-
-            if max_requests_to_send > 0 {
-                let bytes_pending: usize = {
-                    let download_status = self.shared.download_status.lock().unwrap();
-                    download_status
-                        .requested
-                        .minus(&download_status.downloaded)
-                        .len()
-                };
-
-                let ping_time_seconds =
-                    0.001 * self.shared.ping_time_ms.load(atomic::Ordering::Relaxed) as f64;
-                let download_rate = self.session.channel().get_download_rate_estimate();
-
-                let desired_pending_bytes = max(
-                    (PREFETCH_THRESHOLD_FACTOR
-                        * ping_time_seconds
-                        * self.shared.stream_data_rate as f64) as usize,
-                    (FAST_PREFETCH_THRESHOLD_FACTOR * ping_time_seconds * download_rate as f64)
-                        as usize,
-                );
-
-                if bytes_pending < desired_pending_bytes {
-                    self.pre_fetch_more_data(
-                        desired_pending_bytes - bytes_pending,
-                        max_requests_to_send,
-                    );
-                }
-            }
-        }
-
-        Poll::Pending
-    }
-}
+async fn audio_file_fetch(
+    session: Session,
+    shared: Arc<AudioFileShared>,
+    initial_data_rx: ChannelData,
+    initial_request_sent_time: Instant,
+    initial_data_length: usize,
+    output: NamedTempFile,
+    mut stream_loader_command_rx: mpsc::UnboundedReceiver<StreamLoaderCommand>,
+    complete_tx: oneshot::Sender<NamedTempFile>,
+) {
+    let (file_data_tx, mut file_data_rx) = mpsc::unbounded_channel();
+    let (download_finish_tx, mut download_finish_rx) = mpsc::unbounded_channel();
+
+    {
+        let requested_range = Range::new(0, initial_data_length);
+        let mut download_status = shared.download_status.lock().unwrap();
+        download_status.requested.add_range(&requested_range);
+    }
+
+    session.spawn(audio_file_fetch_receive_data(
+        shared.clone(),
+        file_data_tx.clone(),
+        initial_data_rx,
+        0,
+        initial_data_length,
+        initial_request_sent_time,
+        true,
+        download_finish_tx.clone(),
+    ));
+
+    let mut fetch = AudioFileFetch {
+        session,
+        shared,
+        output: Some(output),
+        file_data_tx,
+        complete_tx: Some(complete_tx),
+        network_response_times_ms: Vec::new(),
+        number_of_open_requests: 1,
+        download_finish_tx,
+    };
+
+    loop {
+        tokio::select! {
+            cmd = stream_loader_command_rx.recv() => {
+                if cmd.map_or(true, |cmd| fetch.handle_stream_loader_command(cmd) == ControlFlow::Break) {
+                    break;
+                }
+            },
+            data = file_data_rx.recv() => {
+                if data.map_or(true, |data| fetch.handle_file_data(data) == ControlFlow::Break) {
+                    break;
+                }
+            },
+            _ = download_finish_rx.recv() => {
+                fetch.number_of_open_requests -= 1;
+
+                if fetch.get_download_strategy() == DownloadStrategy::Streaming() {
+                    fetch.trigger_preload();
+                }
+            }
+        }
+    }
+}