librespot/core/src/dealer/mod.rs

587 lines
16 KiB
Rust
Raw Normal View History

2021-05-22 17:05:13 +00:00
mod maps;
mod protocol;
use std::iter;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::{atomic, Arc, Mutex};
use std::task::Poll;
use std::time::Duration;
use futures_core::{Future, Stream};
use futures_util::future::join_all;
use futures_util::{SinkExt, StreamExt};
use tokio::select;
use tokio::sync::mpsc::{self, UnboundedReceiver};
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite;
use tungstenite::error::UrlError;
use url::Url;
use self::maps::*;
use self::protocol::*;
pub use self::protocol::{Message, Request};
use crate::socket;
use crate::util::{keep_flushing, CancelOnDrop, TimeoutOnDrop};
type WsMessage = tungstenite::Message;
type WsError = tungstenite::Error;
type WsResult<T> = Result<T, tungstenite::Error>;
pub struct Response {
pub success: bool,
}
pub struct Responder {
key: String,
tx: mpsc::UnboundedSender<WsMessage>,
sent: bool,
}
impl Responder {
fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
Self {
key,
tx,
sent: false,
}
}
// Should only be called once
fn send_internal(&mut self, response: Response) {
let response = serde_json::json!({
"type": "reply",
"key": &self.key,
"payload": {
"success": response.success,
}
})
.to_string();
if let Err(e) = self.tx.send(WsMessage::Text(response)) {
warn!("Wasn't able to reply to dealer request: {}", e);
}
}
pub fn send(mut self, success: Response) {
self.send_internal(success);
self.sent = true;
}
pub fn force_unanswered(mut self) {
self.sent = true;
}
}
impl Drop for Responder {
fn drop(&mut self) {
if !self.sent {
self.send_internal(Response { success: false });
}
}
}
pub trait IntoResponse {
fn respond(self, responder: Responder);
}
impl IntoResponse for Response {
fn respond(self, responder: Responder) {
responder.send(self)
}
}
impl<F> IntoResponse for F
where
F: Future<Output = Response> + Send + 'static,
{
fn respond(self, responder: Responder) {
tokio::spawn(async move {
responder.send(self.await);
});
}
}
impl<F, R> RequestHandler for F
where
F: (Fn(Request<Payload>) -> R) + Send + Sync + 'static,
R: IntoResponse,
{
fn handle_request(&self, request: Request<Payload>, responder: Responder) {
self(request).respond(responder);
}
}
pub trait RequestHandler: Send + Sync + 'static {
fn handle_request(&self, request: Request<Payload>, responder: Responder);
}
type MessageHandler = mpsc::UnboundedSender<Message<JsonValue>>;
// TODO: Maybe it's possible to unregister subscription directly when they
// are dropped instead of on next failed attempt.
pub struct Subscription(UnboundedReceiver<Message<JsonValue>>);
impl Stream for Subscription {
type Item = Message<JsonValue>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.0.poll_recv(cx)
}
}
fn split_uri(s: &str) -> Option<impl Iterator<Item = &'_ str>> {
let (scheme, sep, rest) = if let Some(rest) = s.strip_prefix("hm://") {
("hm", '/', rest)
} else if let Some(rest) = s.strip_suffix("spotify:") {
("spotify", ':', rest)
} else {
return None;
};
let rest = rest.trim_end_matches(sep);
let mut split = rest.split(sep);
if rest.is_empty() {
assert_eq!(split.next(), Some(""));
}
Some(iter::once(scheme).chain(split))
}
#[derive(Debug, Clone)]
pub enum AddHandlerError {
AlreadyHandled,
InvalidUri,
}
#[derive(Debug, Clone)]
pub enum SubscriptionError {
InvalidUri,
}
fn add_handler<H>(
map: &mut HandlerMap<Box<dyn RequestHandler>>,
uri: &str,
handler: H,
) -> Result<(), AddHandlerError>
where
H: RequestHandler,
{
let split = split_uri(uri).ok_or(AddHandlerError::InvalidUri)?;
map.insert(split, Box::new(handler))
.map_err(|_| AddHandlerError::AlreadyHandled)
}
fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
map.remove(split_uri(uri)?)
}
fn subscribe(
map: &mut SubscriberMap<MessageHandler>,
uris: &[&str],
) -> Result<Subscription, SubscriptionError> {
let (tx, rx) = mpsc::unbounded_channel();
for &uri in uris {
let split = split_uri(uri).ok_or(SubscriptionError::InvalidUri)?;
map.insert(split, tx.clone());
}
Ok(Subscription(rx))
}
#[derive(Default)]
pub struct Builder {
message_handlers: SubscriberMap<MessageHandler>,
request_handlers: HandlerMap<Box<dyn RequestHandler>>,
}
macro_rules! create_dealer {
($builder:expr, $shared:ident -> $body:expr) => {
match $builder {
builder => {
let shared = Arc::new(DealerShared {
message_handlers: Mutex::new(builder.message_handlers),
request_handlers: Mutex::new(builder.request_handlers),
notify_drop: Semaphore::new(0),
});
let handle = {
let $shared = Arc::clone(&shared);
tokio::spawn($body)
};
Dealer {
shared,
handle: TimeoutOnDrop::new(handle, Duration::from_secs(3)),
}
}
}
};
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
pub fn add_handler(
&mut self,
uri: &str,
handler: impl RequestHandler,
) -> Result<(), AddHandlerError> {
add_handler(&mut self.request_handlers, uri, handler)
}
pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
subscribe(&mut self.message_handlers, uris)
}
pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
where
Fut: Future<Output = Url> + Send + 'static,
F: (FnMut() -> Fut) + Send + 'static,
{
create_dealer!(self, shared -> run(shared, None, get_url, proxy))
}
pub async fn launch<Fut, F>(self, mut get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
where
Fut: Future<Output = Url> + Send + 'static,
F: (FnMut() -> Fut) + Send + 'static,
{
let dealer = create_dealer!(self, shared -> {
// Try to connect.
let url = get_url().await;
let tasks = connect(&url, proxy.as_ref(), &shared).await?;
// If a connection is established, continue in a background task.
run(shared, Some(tasks), get_url, proxy)
});
Ok(dealer)
}
}
struct DealerShared {
message_handlers: Mutex<SubscriberMap<MessageHandler>>,
request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
// Semaphore with 0 permits. By closing this semaphore, we indicate
// that the actual Dealer struct has been dropped.
notify_drop: Semaphore,
}
impl DealerShared {
fn dispatch_message(&self, msg: Message<JsonValue>) {
if let Some(split) = split_uri(&msg.uri) {
self.message_handlers
.lock()
.unwrap()
.retain(split, &mut |tx| tx.send(msg.clone()).is_ok());
}
}
fn dispatch_request(
&self,
request: Request<Payload>,
send_tx: &mpsc::UnboundedSender<WsMessage>,
) {
// ResponseSender will automatically send "success: false" if it is dropped without an answer.
let responder = Responder::new(request.key.clone(), send_tx.clone());
let split = if let Some(split) = split_uri(&request.message_ident) {
split
} else {
warn!(
"Dealer request with invalid message_ident: {}",
&request.message_ident
);
return;
};
{
let handler_map = self.request_handlers.lock().unwrap();
if let Some(handler) = handler_map.get(split) {
handler.handle_request(request, responder);
return;
}
}
warn!("No handler for message_ident: {}", &request.message_ident);
}
fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
match m {
MessageOrRequest::Message(m) => self.dispatch_message(m),
MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
}
}
async fn closed(&self) {
self.notify_drop.acquire().await.unwrap_err();
}
fn is_closed(&self) -> bool {
self.notify_drop.is_closed()
}
}
pub struct Dealer {
shared: Arc<DealerShared>,
handle: TimeoutOnDrop<()>,
}
impl Dealer {
pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), AddHandlerError>
where
H: RequestHandler,
{
add_handler(
&mut self.shared.request_handlers.lock().unwrap(),
uri,
handler,
)
}
pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
remove_handler(&mut self.shared.request_handlers.lock().unwrap(), uri)
}
pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
subscribe(&mut self.shared.message_handlers.lock().unwrap(), uris)
}
pub async fn close(mut self) {
debug!("closing dealer");
self.shared.notify_drop.close();
if let Some(handle) = self.handle.take() {
CancelOnDrop(handle).await.unwrap();
}
}
}
/// Initializes a connection and returns futures that will finish when the connection is closed/lost.
async fn connect(
address: &Url,
proxy: Option<&Url>,
shared: &Arc<DealerShared>,
) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
let host = address
.host_str()
.ok_or(WsError::Url(UrlError::NoHostName))?;
let default_port = match address.scheme() {
"ws" => 80,
"wss" => 443,
_ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme)),
};
let port = address.port().unwrap_or(default_port);
let stream = socket::connect(host, port, proxy).await?;
let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address, stream)
.await?
.0
.split();
let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();
// Spawn a task that will forward messages from the channel to the websocket.
let send_task = {
let shared = Arc::clone(&shared);
tokio::spawn(async move {
let result = loop {
select! {
biased;
() = shared.closed() => {
break Ok(None);
}
msg = send_rx.recv() => {
if let Some(msg) = msg {
// New message arrived through channel
if let WsMessage::Close(close_frame) = msg {
break Ok(close_frame);
}
if let Err(e) = ws_tx.feed(msg).await {
break Err(e);
}
} else {
break Ok(None);
}
},
e = keep_flushing(&mut ws_tx) => {
break Err(e)
}
}
};
send_rx.close();
// I don't trust in tokio_tungstenite's implementation of Sink::close.
let result = match result {
Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
Err(e) => {
warn!("Dealer finished with an error: {}", e);
ws_tx.send(WsMessage::Close(None)).await
}
};
if let Err(e) = result {
warn!("Error while closing websocket: {}", e);
}
debug!("Dropping send task");
})
};
let shared = Arc::clone(&shared);
// A task that receives messages from the web socket.
let receive_task = tokio::spawn(async {
let pong_received = AtomicBool::new(true);
let send_tx = send_tx;
let shared = shared;
let receive_task = async {
let mut ws_rx = ws_rx;
loop {
match ws_rx.next().await {
Some(Ok(msg)) => match msg {
WsMessage::Text(t) => match serde_json::from_str(&t) {
Ok(m) => shared.dispatch(m, &send_tx),
Err(e) => info!("Received invalid message: {}", e),
},
WsMessage::Binary(_) => {
info!("Received invalid binary message");
}
WsMessage::Pong(_) => {
debug!("Received pong");
pong_received.store(true, atomic::Ordering::Relaxed);
}
_ => (), // tungstenite handles Close and Ping automatically
},
Some(Err(e)) => {
warn!("Websocket connection failed: {}", e);
break;
}
None => {
debug!("Websocket connection closed.");
break;
}
}
}
};
// Sends pings and checks whether a pong comes back.
let ping_task = async {
use tokio::time::{interval, sleep};
let mut timer = interval(Duration::from_secs(30));
loop {
timer.tick().await;
pong_received.store(false, atomic::Ordering::Relaxed);
if send_tx.send(WsMessage::Ping(vec![])).is_err() {
// The sender is closed.
break;
}
debug!("Sent ping");
sleep(Duration::from_secs(3)).await;
if !pong_received.load(atomic::Ordering::SeqCst) {
// No response
warn!("Websocket peer does not respond.");
break;
}
}
};
// Exit this task as soon as one our subtasks fails.
// In both cases the connection is probably lost.
select! {
() = ping_task => (),
() = receive_task => ()
}
// Try to take send_task down with us, in case it's still alive.
let _ = send_tx.send(WsMessage::Close(None));
debug!("Dropping receive task");
});
Ok((send_task, receive_task))
}
/// The main background task for `Dealer`, which coordinates reconnecting.
async fn run<F, Fut>(
shared: Arc<DealerShared>,
initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
mut get_url: F,
proxy: Option<Url>,
) where
Fut: Future<Output = Url> + Send + 'static,
F: (FnMut() -> Fut) + Send + 'static,
{
let init_task = |t| Some(TimeoutOnDrop::new(t, Duration::from_secs(3)));
let mut tasks = if let Some((s, r)) = initial_tasks {
(init_task(s), init_task(r))
} else {
(None, None)
};
while !shared.is_closed() {
match &mut tasks {
(Some(t0), Some(t1)) => {
select! {
() = shared.closed() => break,
r = t0 => {
r.unwrap(); // Whatever has gone wrong (probably panicked), we can't handle it, so let's panic too.
tasks.0.take();
},
r = t1 => {
r.unwrap();
tasks.1.take();
}
}
}
_ => {
let url = select! {
() = shared.closed() => {
break
},
e = get_url() => e
};
match connect(&url, proxy.as_ref(), &shared).await {
Ok((s, r)) => tasks = (init_task(s), init_task(r)),
Err(e) => {
warn!("Error while connecting: {}", e);
}
}
}
}
}
let tasks = tasks.0.into_iter().chain(tasks.1);
let _ = join_all(tasks).await;
}