Fix refilling with proxies and a race condition

This commit is contained in:
Roderick van Domburg 2021-06-22 21:39:38 +02:00
parent eee79f2a1e
commit 3a7843d049
No known key found for this signature in database
GPG key ID: 7076AA781B43EFE6

View file

@ -1,6 +1,8 @@
use hyper::{Body, Request}; use hyper::{Body, Request};
use serde::Deserialize; use serde::Deserialize;
use std::error::Error; use std::error::Error;
use std::hint;
use std::sync::atomic::{AtomicUsize, Ordering};
pub type SocketAddress = (String, u16); pub type SocketAddress = (String, u16);
@ -34,11 +36,22 @@ impl Default for ApResolveData {
component! { component! {
ApResolver : ApResolverInner { ApResolver : ApResolverInner {
data: AccessPoints = AccessPoints::default(), data: AccessPoints = AccessPoints::default(),
spinlock: AtomicUsize = AtomicUsize::new(0),
} }
} }
impl ApResolver { impl ApResolver {
fn split_aps(data: Vec<String>) -> Vec<SocketAddress> { // return a port if a proxy URL and/or a proxy port was specified. This is useful even when
// there is no proxy, but firewalls only allow certain ports (e.g. 443 and not 4070).
fn port_config(&self) -> Option<u16> {
if self.session().config().proxy.is_some() || self.session().config().ap_port.is_some() {
Some(self.session().config().ap_port.unwrap_or(443))
} else {
None
}
}
fn process_data(&self, data: Vec<String>) -> Vec<SocketAddress> {
data.into_iter() data.into_iter()
.filter_map(|ap| { .filter_map(|ap| {
let mut split = ap.rsplitn(2, ':'); let mut split = ap.rsplitn(2, ':');
@ -47,21 +60,16 @@ impl ApResolver {
.expect("rsplitn should not return empty iterator"); .expect("rsplitn should not return empty iterator");
let host = split.next()?.to_owned(); let host = split.next()?.to_owned();
let port: u16 = port.parse().ok()?; let port: u16 = port.parse().ok()?;
if let Some(p) = self.port_config() {
if p != port {
return None;
}
}
Some((host, port)) Some((host, port))
}) })
.collect() .collect()
} }
fn find_ap(&self, data: &[SocketAddress]) -> usize {
match self.session().config().proxy {
Some(_) => data
.iter()
.position(|(_, port)| *port == self.session().config().ap_port.unwrap_or(443))
.expect("No access points available with that proxy port."),
None => 0, // just pick the first one
}
}
async fn try_apresolve(&self) -> Result<ApResolveData, Box<dyn Error>> { async fn try_apresolve(&self) -> Result<ApResolveData, Box<dyn Error>> {
let req = Request::builder() let req = Request::builder()
.method("GET") .method("GET")
@ -77,6 +85,7 @@ impl ApResolver {
async fn apresolve(&self) { async fn apresolve(&self) {
let result = self.try_apresolve().await; let result = self.try_apresolve().await;
self.lock(|inner| { self.lock(|inner| {
let data = match result { let data = match result {
Ok(data) => data, Ok(data) => data,
@ -86,9 +95,9 @@ impl ApResolver {
} }
}; };
inner.data.accesspoint = Self::split_aps(data.accesspoint); inner.data.accesspoint = self.process_data(data.accesspoint);
inner.data.dealer = Self::split_aps(data.dealer); inner.data.dealer = self.process_data(data.dealer);
inner.data.spclient = Self::split_aps(data.spclient); inner.data.spclient = self.process_data(data.spclient);
}) })
} }
@ -101,24 +110,32 @@ impl ApResolver {
} }
pub async fn resolve(&self, endpoint: &str) -> SocketAddress { pub async fn resolve(&self, endpoint: &str) -> SocketAddress {
// Use a spinlock to make this function atomic. Otherwise, various race conditions may
// occur, e.g. when the session is created, multiple components are launched almost in
// parallel and they will all call this function, while resolving is still in progress.
self.lock(|inner| {
while inner.spinlock.load(Ordering::SeqCst) != 0 {
hint::spin_loop()
}
inner.spinlock.store(1, Ordering::SeqCst);
});
if self.is_empty() { if self.is_empty() {
self.apresolve().await; self.apresolve().await;
} }
self.lock(|inner| match endpoint { self.lock(|inner| {
"accesspoint" => { let access_point = match endpoint {
let pos = self.find_ap(&inner.data.accesspoint); // take the first position instead of the last with `pop`, because Spotify returns
inner.data.accesspoint.remove(pos) // access points with ports 4070, 443 and 80 in order of preference from highest
} // to lowest.
"dealer" => { "accesspoint" => inner.data.accesspoint.remove(0),
let pos = self.find_ap(&inner.data.dealer); "dealer" => inner.data.dealer.remove(0),
inner.data.dealer.remove(pos) "spclient" => inner.data.spclient.remove(0),
} _ => unimplemented!(),
"spclient" => { };
let pos = self.find_ap(&inner.data.spclient); inner.spinlock.store(0, Ordering::SeqCst);
inner.data.spclient.remove(pos) access_point
}
_ => unimplemented!(),
}) })
} }
} }