Add basic websocket support

This commit is contained in:
johannesd3 2021-05-22 19:05:13 +02:00
parent 08ba3ad7d7
commit 1ade02b7ad
No known key found for this signature in database
GPG key ID: 8C2739E91D410F75
11 changed files with 1040 additions and 68 deletions

136
Cargo.lock generated
View file

@ -918,6 +918,15 @@ dependencies = [
"hashbrown", "hashbrown",
] ]
[[package]]
name = "input_buffer"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413"
dependencies = [
"bytes",
]
[[package]] [[package]]
name = "instant" name = "instant"
version = "0.1.9" version = "0.1.9"
@ -1229,6 +1238,7 @@ dependencies = [
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-tungstenite",
"tokio-util", "tokio-util",
"url", "url",
"uuid", "uuid",
@ -1911,6 +1921,21 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "ring"
version = "0.16.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc"
dependencies = [
"cc",
"libc",
"once_cell",
"spin",
"untrusted",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "rodio" name = "rodio"
version = "0.14.0" version = "0.14.0"
@ -1945,6 +1970,19 @@ dependencies = [
"semver", "semver",
] ]
[[package]]
name = "rustls"
version = "0.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7"
dependencies = [
"base64",
"log",
"ring",
"sct",
"webpki",
]
[[package]] [[package]]
name = "ryu" name = "ryu"
version = "1.0.5" version = "1.0.5"
@ -1966,6 +2004,16 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "sct"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce"
dependencies = [
"ring",
"untrusted",
]
[[package]] [[package]]
name = "sdl2" name = "sdl2"
version = "0.34.5" version = "0.34.5"
@ -2103,6 +2151,12 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]] [[package]]
name = "stdweb" name = "stdweb"
version = "0.1.3" version = "0.1.3"
@ -2275,6 +2329,17 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "tokio-rustls"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6"
dependencies = [
"rustls",
"tokio",
"webpki",
]
[[package]] [[package]]
name = "tokio-stream" name = "tokio-stream"
version = "0.1.5" version = "0.1.5"
@ -2286,6 +2351,23 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e96bb520beab540ab664bd5a9cfeaa1fcd846fa68c830b42e2c8963071251d2"
dependencies = [
"futures-util",
"log",
"pin-project",
"rustls",
"tokio",
"tokio-rustls",
"tungstenite",
"webpki",
"webpki-roots",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.6.6" version = "0.6.6"
@ -2341,6 +2423,29 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093"
dependencies = [
"base64",
"byteorder",
"bytes",
"http",
"httparse",
"input_buffer",
"log",
"rand",
"rustls",
"sha-1",
"thiserror",
"url",
"utf-8",
"webpki",
"webpki-roots",
]
[[package]] [[package]]
name = "typenum" name = "typenum"
version = "1.13.0" version = "1.13.0"
@ -2389,6 +2494,12 @@ version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
[[package]]
name = "untrusted"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]] [[package]]
name = "url" name = "url"
version = "2.2.2" version = "2.2.2"
@ -2401,6 +2512,12 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "0.8.2" version = "0.8.2"
@ -2561,6 +2678,25 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "webpki"
version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "webpki-roots"
version = "0.21.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940"
dependencies = [
"webpki",
]
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"

View file

@ -39,8 +39,9 @@ serde_json = "1.0"
sha-1 = "0.9" sha-1 = "0.9"
shannon = "0.2.0" shannon = "0.2.0"
thiserror = "1.0.7" thiserror = "1.0.7"
tokio = { version = "1.0", features = ["io-util", "net", "rt", "sync"] } tokio = { version = "1.5", features = ["io-util", "macros", "net", "rt", "time", "sync"] }
tokio-stream = "0.1.1" tokio-stream = "0.1.1"
tokio-tungstenite = { version = "0.14", default-features = false, features = ["rustls-tls"] }
tokio-util = { version = "0.6", features = ["codec"] } tokio-util = { version = "0.6", features = ["codec"] }
url = "2.1" url = "2.1"
uuid = { version = "0.8", default-features = false, features = ["v4"] } uuid = { version = "0.8", default-features = false, features = ["v4"] }

View file

@ -1,12 +1,12 @@
use std::error::Error; use std::error::Error;
use hyper::client::HttpConnector; use hyper::client::HttpConnector;
use hyper::{Body, Client, Method, Request, Uri}; use hyper::{Body, Client, Method, Request};
use hyper_proxy::{Intercept, Proxy, ProxyConnector}; use hyper_proxy::{Intercept, Proxy, ProxyConnector};
use serde::Deserialize; use serde::Deserialize;
use url::Url; use url::Url;
use super::AP_FALLBACK; use super::ap_fallback;
const APRESOLVE_ENDPOINT: &str = "http://apresolve.spotify.com:80"; const APRESOLVE_ENDPOINT: &str = "http://apresolve.spotify.com:80";
@ -18,7 +18,7 @@ struct ApResolveData {
async fn try_apresolve( async fn try_apresolve(
proxy: Option<&Url>, proxy: Option<&Url>,
ap_port: Option<u16>, ap_port: Option<u16>,
) -> Result<String, Box<dyn Error>> { ) -> Result<(String, u16), Box<dyn Error>> {
let port = ap_port.unwrap_or(443); let port = ap_port.unwrap_or(443);
let mut req = Request::new(Body::empty()); let mut req = Request::new(Body::empty());
@ -43,27 +43,29 @@ async fn try_apresolve(
let body = hyper::body::to_bytes(response.into_body()).await?; let body = hyper::body::to_bytes(response.into_body()).await?;
let data: ApResolveData = serde_json::from_slice(body.as_ref())?; let data: ApResolveData = serde_json::from_slice(body.as_ref())?;
let mut aps = data.ap_list.into_iter().filter_map(|ap| {
let mut split = ap.rsplitn(2, ':');
let port = split
.next()
.expect("rsplitn should not return empty iterator");
let host = split.next()?.to_owned();
let port: u16 = port.parse().ok()?;
Some((host, port))
});
let ap = if ap_port.is_some() || proxy.is_some() { let ap = if ap_port.is_some() || proxy.is_some() {
data.ap_list.into_iter().find_map(|ap| { aps.find(|(_, p)| *p == port)
if ap.parse::<Uri>().ok()?.port()? == port {
Some(ap)
} else {
None
}
})
} else { } else {
data.ap_list.into_iter().next() aps.next()
} }
.ok_or("empty AP List")?; .ok_or("no valid AP in list")?;
Ok(ap) Ok(ap)
} }
pub async fn apresolve(proxy: Option<&Url>, ap_port: Option<u16>) -> String { pub async fn apresolve(proxy: Option<&Url>, ap_port: Option<u16>) -> (String, u16) {
try_apresolve(proxy, ap_port).await.unwrap_or_else(|e| { try_apresolve(proxy, ap_port).await.unwrap_or_else(|e| {
warn!("Failed to resolve Access Point: {}", e); warn!("Failed to resolve Access Point: {}, using fallback.", e);
warn!("Using fallback \"{}\"", AP_FALLBACK); ap_fallback()
AP_FALLBACK.into()
}) })
} }

View file

@ -5,7 +5,6 @@ pub use self::codec::ApCodec;
pub use self::handshake::handshake; pub use self::handshake::handshake;
use std::io::{self, ErrorKind}; use std::io::{self, ErrorKind};
use std::net::ToSocketAddrs;
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use protobuf::{self, Message, ProtobufError}; use protobuf::{self, Message, ProtobufError};
@ -16,7 +15,6 @@ use url::Url;
use crate::authentication::Credentials; use crate::authentication::Credentials;
use crate::protocol::keyexchange::{APLoginFailed, ErrorCode}; use crate::protocol::keyexchange::{APLoginFailed, ErrorCode};
use crate::proxytunnel;
use crate::version; use crate::version;
pub type Transport = Framed<TcpStream, ApCodec>; pub type Transport = Framed<TcpStream, ApCodec>;
@ -58,50 +56,8 @@ impl From<APLoginFailed> for AuthenticationError {
} }
} }
pub async fn connect(addr: String, proxy: Option<&Url>) -> io::Result<Transport> { pub async fn connect(host: &str, port: u16, proxy: Option<&Url>) -> io::Result<Transport> {
let socket = if let Some(proxy_url) = proxy { let socket = crate::socket::connect(host, port, proxy).await?;
info!("Using proxy \"{}\"", proxy_url);
let socket_addr = proxy_url.socket_addrs(|| None).and_then(|addrs| {
addrs.into_iter().next().ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"Can't resolve proxy server address",
)
})
})?;
let socket = TcpStream::connect(&socket_addr).await?;
let uri = addr.parse::<http::Uri>().map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidData,
"Can't parse access point address",
)
})?;
let host = uri.host().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"The access point address contains no hostname",
)
})?;
let port = uri.port().ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"The access point address contains no port",
)
})?;
proxytunnel::proxy_connect(socket, host, port.as_str()).await?
} else {
let socket_addr = addr.to_socket_addrs()?.next().ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"Can't resolve access point address",
)
})?;
TcpStream::connect(&socket_addr).await?
};
handshake(socket).await handshake(socket).await
} }

117
core/src/dealer/maps.rs Normal file
View file

@ -0,0 +1,117 @@
use std::collections::HashMap;
#[derive(Debug)]
pub struct AlreadyHandledError(());
pub enum HandlerMap<T> {
Leaf(T),
Branch(HashMap<String, HandlerMap<T>>),
}
impl<T> Default for HandlerMap<T> {
fn default() -> Self {
Self::Branch(HashMap::new())
}
}
impl<T> HandlerMap<T> {
pub fn insert<'a>(
&mut self,
mut path: impl Iterator<Item = &'a str>,
handler: T,
) -> Result<(), AlreadyHandledError> {
match self {
Self::Leaf(_) => Err(AlreadyHandledError(())),
Self::Branch(children) => {
if let Some(component) = path.next() {
let node = children.entry(component.to_owned()).or_default();
node.insert(path, handler)
} else if children.is_empty() {
*self = Self::Leaf(handler);
Ok(())
} else {
Err(AlreadyHandledError(()))
}
}
}
}
pub fn get<'a>(&self, mut path: impl Iterator<Item = &'a str>) -> Option<&T> {
match self {
Self::Leaf(t) => Some(t),
Self::Branch(m) => {
let component = path.next()?;
m.get(component)?.get(path)
}
}
}
pub fn remove<'a>(&mut self, mut path: impl Iterator<Item = &'a str>) -> Option<T> {
match self {
Self::Leaf(_) => match std::mem::take(self) {
Self::Leaf(t) => Some(t),
_ => unreachable!(),
},
Self::Branch(map) => {
let component = path.next()?;
let next = map.get_mut(component)?;
let result = next.remove(path);
match &*next {
Self::Branch(b) if b.is_empty() => {
map.remove(component);
}
_ => (),
}
result
}
}
}
}
pub struct SubscriberMap<T> {
subscribed: Vec<T>,
children: HashMap<String, SubscriberMap<T>>,
}
impl<T> Default for SubscriberMap<T> {
fn default() -> Self {
Self {
subscribed: Vec::new(),
children: HashMap::new(),
}
}
}
impl<T> SubscriberMap<T> {
pub fn insert<'a>(&mut self, mut path: impl Iterator<Item = &'a str>, handler: T) {
if let Some(component) = path.next() {
self.children
.entry(component.to_owned())
.or_default()
.insert(path, handler);
} else {
self.subscribed.push(handler);
}
}
pub fn is_empty(&self) -> bool {
self.children.is_empty() && self.subscribed.is_empty()
}
pub fn retain<'a>(
&mut self,
mut path: impl Iterator<Item = &'a str>,
fun: &mut impl FnMut(&T) -> bool,
) {
self.subscribed.retain(|x| fun(x));
if let Some(next) = path.next() {
if let Some(y) = self.children.get_mut(next) {
y.retain(path, fun);
if y.is_empty() {
self.children.remove(next);
}
}
}
}
}

586
core/src/dealer/mod.rs Normal file
View file

@ -0,0 +1,586 @@
mod maps;
mod protocol;
use std::iter;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::{atomic, Arc, Mutex};
use std::task::Poll;
use std::time::Duration;
use futures_core::{Future, Stream};
use futures_util::future::join_all;
use futures_util::{SinkExt, StreamExt};
use tokio::select;
use tokio::sync::mpsc::{self, UnboundedReceiver};
use tokio::sync::Semaphore;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite;
use tungstenite::error::UrlError;
use url::Url;
use self::maps::*;
use self::protocol::*;
pub use self::protocol::{Message, Request};
use crate::socket;
use crate::util::{keep_flushing, CancelOnDrop, TimeoutOnDrop};
type WsMessage = tungstenite::Message;
type WsError = tungstenite::Error;
type WsResult<T> = Result<T, tungstenite::Error>;
pub struct Response {
pub success: bool,
}
pub struct Responder {
key: String,
tx: mpsc::UnboundedSender<WsMessage>,
sent: bool,
}
impl Responder {
fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
Self {
key,
tx,
sent: false,
}
}
// Should only be called once
fn send_internal(&mut self, response: Response) {
let response = serde_json::json!({
"type": "reply",
"key": &self.key,
"payload": {
"success": response.success,
}
})
.to_string();
if let Err(e) = self.tx.send(WsMessage::Text(response)) {
warn!("Wasn't able to reply to dealer request: {}", e);
}
}
pub fn send(mut self, success: Response) {
self.send_internal(success);
self.sent = true;
}
pub fn force_unanswered(mut self) {
self.sent = true;
}
}
impl Drop for Responder {
fn drop(&mut self) {
if !self.sent {
self.send_internal(Response { success: false });
}
}
}
pub trait IntoResponse {
fn respond(self, responder: Responder);
}
impl IntoResponse for Response {
fn respond(self, responder: Responder) {
responder.send(self)
}
}
impl<F> IntoResponse for F
where
F: Future<Output = Response> + Send + 'static,
{
fn respond(self, responder: Responder) {
tokio::spawn(async move {
responder.send(self.await);
});
}
}
impl<F, R> RequestHandler for F
where
F: (Fn(Request<Payload>) -> R) + Send + Sync + 'static,
R: IntoResponse,
{
fn handle_request(&self, request: Request<Payload>, responder: Responder) {
self(request).respond(responder);
}
}
pub trait RequestHandler: Send + Sync + 'static {
fn handle_request(&self, request: Request<Payload>, responder: Responder);
}
type MessageHandler = mpsc::UnboundedSender<Message<JsonValue>>;
// TODO: Maybe it's possible to unregister subscription directly when they
// are dropped instead of on next failed attempt.
pub struct Subscription(UnboundedReceiver<Message<JsonValue>>);
impl Stream for Subscription {
type Item = Message<JsonValue>;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.0.poll_recv(cx)
}
}
fn split_uri(s: &str) -> Option<impl Iterator<Item = &'_ str>> {
let (scheme, sep, rest) = if let Some(rest) = s.strip_prefix("hm://") {
("hm", '/', rest)
} else if let Some(rest) = s.strip_suffix("spotify:") {
("spotify", ':', rest)
} else {
return None;
};
let rest = rest.trim_end_matches(sep);
let mut split = rest.split(sep);
if rest.is_empty() {
assert_eq!(split.next(), Some(""));
}
Some(iter::once(scheme).chain(split))
}
#[derive(Debug, Clone)]
pub enum AddHandlerError {
AlreadyHandled,
InvalidUri,
}
#[derive(Debug, Clone)]
pub enum SubscriptionError {
InvalidUri,
}
fn add_handler<H>(
map: &mut HandlerMap<Box<dyn RequestHandler>>,
uri: &str,
handler: H,
) -> Result<(), AddHandlerError>
where
H: RequestHandler,
{
let split = split_uri(uri).ok_or(AddHandlerError::InvalidUri)?;
map.insert(split, Box::new(handler))
.map_err(|_| AddHandlerError::AlreadyHandled)
}
fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
map.remove(split_uri(uri)?)
}
fn subscribe(
map: &mut SubscriberMap<MessageHandler>,
uris: &[&str],
) -> Result<Subscription, SubscriptionError> {
let (tx, rx) = mpsc::unbounded_channel();
for &uri in uris {
let split = split_uri(uri).ok_or(SubscriptionError::InvalidUri)?;
map.insert(split, tx.clone());
}
Ok(Subscription(rx))
}
#[derive(Default)]
pub struct Builder {
message_handlers: SubscriberMap<MessageHandler>,
request_handlers: HandlerMap<Box<dyn RequestHandler>>,
}
macro_rules! create_dealer {
($builder:expr, $shared:ident -> $body:expr) => {
match $builder {
builder => {
let shared = Arc::new(DealerShared {
message_handlers: Mutex::new(builder.message_handlers),
request_handlers: Mutex::new(builder.request_handlers),
notify_drop: Semaphore::new(0),
});
let handle = {
let $shared = Arc::clone(&shared);
tokio::spawn($body)
};
Dealer {
shared,
handle: TimeoutOnDrop::new(handle, Duration::from_secs(3)),
}
}
}
};
}
impl Builder {
pub fn new() -> Self {
Self::default()
}
pub fn add_handler(
&mut self,
uri: &str,
handler: impl RequestHandler,
) -> Result<(), AddHandlerError> {
add_handler(&mut self.request_handlers, uri, handler)
}
pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
subscribe(&mut self.message_handlers, uris)
}
pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
where
Fut: Future<Output = Url> + Send + 'static,
F: (FnMut() -> Fut) + Send + 'static,
{
create_dealer!(self, shared -> run(shared, None, get_url, proxy))
}
pub async fn launch<Fut, F>(self, mut get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
where
Fut: Future<Output = Url> + Send + 'static,
F: (FnMut() -> Fut) + Send + 'static,
{
let dealer = create_dealer!(self, shared -> {
// Try to connect.
let url = get_url().await;
let tasks = connect(&url, proxy.as_ref(), &shared).await?;
// If a connection is established, continue in a background task.
run(shared, Some(tasks), get_url, proxy)
});
Ok(dealer)
}
}
struct DealerShared {
message_handlers: Mutex<SubscriberMap<MessageHandler>>,
request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
// Semaphore with 0 permits. By closing this semaphore, we indicate
// that the actual Dealer struct has been dropped.
notify_drop: Semaphore,
}
impl DealerShared {
fn dispatch_message(&self, msg: Message<JsonValue>) {
if let Some(split) = split_uri(&msg.uri) {
self.message_handlers
.lock()
.unwrap()
.retain(split, &mut |tx| tx.send(msg.clone()).is_ok());
}
}
fn dispatch_request(
&self,
request: Request<Payload>,
send_tx: &mpsc::UnboundedSender<WsMessage>,
) {
// ResponseSender will automatically send "success: false" if it is dropped without an answer.
let responder = Responder::new(request.key.clone(), send_tx.clone());
let split = if let Some(split) = split_uri(&request.message_ident) {
split
} else {
warn!(
"Dealer request with invalid message_ident: {}",
&request.message_ident
);
return;
};
{
let handler_map = self.request_handlers.lock().unwrap();
if let Some(handler) = handler_map.get(split) {
handler.handle_request(request, responder);
return;
}
}
warn!("No handler for message_ident: {}", &request.message_ident);
}
fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
match m {
MessageOrRequest::Message(m) => self.dispatch_message(m),
MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
}
}
async fn closed(&self) {
self.notify_drop.acquire().await.unwrap_err();
}
fn is_closed(&self) -> bool {
self.notify_drop.is_closed()
}
}
pub struct Dealer {
shared: Arc<DealerShared>,
handle: TimeoutOnDrop<()>,
}
impl Dealer {
pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), AddHandlerError>
where
H: RequestHandler,
{
add_handler(
&mut self.shared.request_handlers.lock().unwrap(),
uri,
handler,
)
}
pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
remove_handler(&mut self.shared.request_handlers.lock().unwrap(), uri)
}
pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
subscribe(&mut self.shared.message_handlers.lock().unwrap(), uris)
}
pub async fn close(mut self) {
debug!("closing dealer");
self.shared.notify_drop.close();
if let Some(handle) = self.handle.take() {
CancelOnDrop(handle).await.unwrap();
}
}
}
/// Initializes a connection and returns futures that will finish when the connection is closed/lost.
async fn connect(
address: &Url,
proxy: Option<&Url>,
shared: &Arc<DealerShared>,
) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
let host = address
.host_str()
.ok_or(WsError::Url(UrlError::NoHostName))?;
let default_port = match address.scheme() {
"ws" => 80,
"wss" => 443,
_ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme)),
};
let port = address.port().unwrap_or(default_port);
let stream = socket::connect(host, port, proxy).await?;
let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address, stream)
.await?
.0
.split();
let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();
// Spawn a task that will forward messages from the channel to the websocket.
let send_task = {
let shared = Arc::clone(&shared);
tokio::spawn(async move {
let result = loop {
select! {
biased;
() = shared.closed() => {
break Ok(None);
}
msg = send_rx.recv() => {
if let Some(msg) = msg {
// New message arrived through channel
if let WsMessage::Close(close_frame) = msg {
break Ok(close_frame);
}
if let Err(e) = ws_tx.feed(msg).await {
break Err(e);
}
} else {
break Ok(None);
}
},
e = keep_flushing(&mut ws_tx) => {
break Err(e)
}
}
};
send_rx.close();
// I don't trust in tokio_tungstenite's implementation of Sink::close.
let result = match result {
Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
Err(e) => {
warn!("Dealer finished with an error: {}", e);
ws_tx.send(WsMessage::Close(None)).await
}
};
if let Err(e) = result {
warn!("Error while closing websocket: {}", e);
}
debug!("Dropping send task");
})
};
let shared = Arc::clone(&shared);
// A task that receives messages from the web socket.
let receive_task = tokio::spawn(async {
let pong_received = AtomicBool::new(true);
let send_tx = send_tx;
let shared = shared;
let receive_task = async {
let mut ws_rx = ws_rx;
loop {
match ws_rx.next().await {
Some(Ok(msg)) => match msg {
WsMessage::Text(t) => match serde_json::from_str(&t) {
Ok(m) => shared.dispatch(m, &send_tx),
Err(e) => info!("Received invalid message: {}", e),
},
WsMessage::Binary(_) => {
info!("Received invalid binary message");
}
WsMessage::Pong(_) => {
debug!("Received pong");
pong_received.store(true, atomic::Ordering::Relaxed);
}
_ => (), // tungstenite handles Close and Ping automatically
},
Some(Err(e)) => {
warn!("Websocket connection failed: {}", e);
break;
}
None => {
debug!("Websocket connection closed.");
break;
}
}
}
};
// Sends pings and checks whether a pong comes back.
let ping_task = async {
use tokio::time::{interval, sleep};
let mut timer = interval(Duration::from_secs(30));
loop {
timer.tick().await;
pong_received.store(false, atomic::Ordering::Relaxed);
if send_tx.send(WsMessage::Ping(vec![])).is_err() {
// The sender is closed.
break;
}
debug!("Sent ping");
sleep(Duration::from_secs(3)).await;
if !pong_received.load(atomic::Ordering::SeqCst) {
// No response
warn!("Websocket peer does not respond.");
break;
}
}
};
// Exit this task as soon as one our subtasks fails.
// In both cases the connection is probably lost.
select! {
() = ping_task => (),
() = receive_task => ()
}
// Try to take send_task down with us, in case it's still alive.
let _ = send_tx.send(WsMessage::Close(None));
debug!("Dropping receive task");
});
Ok((send_task, receive_task))
}
/// The main background task for `Dealer`, which coordinates reconnecting.
async fn run<F, Fut>(
shared: Arc<DealerShared>,
initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
mut get_url: F,
proxy: Option<Url>,
) where
Fut: Future<Output = Url> + Send + 'static,
F: (FnMut() -> Fut) + Send + 'static,
{
let init_task = |t| Some(TimeoutOnDrop::new(t, Duration::from_secs(3)));
let mut tasks = if let Some((s, r)) = initial_tasks {
(init_task(s), init_task(r))
} else {
(None, None)
};
while !shared.is_closed() {
match &mut tasks {
(Some(t0), Some(t1)) => {
select! {
() = shared.closed() => break,
r = t0 => {
r.unwrap(); // Whatever has gone wrong (probably panicked), we can't handle it, so let's panic too.
tasks.0.take();
},
r = t1 => {
r.unwrap();
tasks.1.take();
}
}
}
_ => {
let url = select! {
() = shared.closed() => {
break
},
e = get_url() => e
};
match connect(&url, proxy.as_ref(), &shared).await {
Ok((s, r)) => tasks = (init_task(s), init_task(r)),
Err(e) => {
warn!("Error while connecting: {}", e);
}
}
}
}
}
let tasks = tasks.0.into_iter().chain(tasks.1);
let _ = join_all(tasks).await;
}

View file

@ -0,0 +1,39 @@
use std::collections::HashMap;
use serde::Deserialize;
pub type JsonValue = serde_json::Value;
pub type JsonObject = serde_json::Map<String, JsonValue>;
#[derive(Clone, Debug, Deserialize)]
pub struct Request<P> {
#[serde(default)]
pub headers: HashMap<String, String>,
pub message_ident: String,
pub key: String,
pub payload: P,
}
#[derive(Clone, Debug, Deserialize)]
pub struct Payload {
pub message_id: i32,
pub sent_by_device_id: String,
pub command: JsonObject,
}
#[derive(Clone, Debug, Deserialize)]
pub struct Message<P> {
#[serde(default)]
pub headers: HashMap<String, String>,
pub method: Option<String>,
#[serde(default)]
pub payloads: Vec<P>,
pub uri: String,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessageOrRequest {
Message(Message<JsonValue>),
Request(Request<Payload>),
}

View file

@ -14,25 +14,30 @@ pub mod cache;
pub mod channel; pub mod channel;
pub mod config; pub mod config;
mod connection; mod connection;
#[allow(dead_code)]
mod dealer;
#[doc(hidden)] #[doc(hidden)]
pub mod diffie_hellman; pub mod diffie_hellman;
pub mod keymaster; pub mod keymaster;
pub mod mercury; pub mod mercury;
mod proxytunnel; mod proxytunnel;
pub mod session; pub mod session;
mod socket;
pub mod spotify_id; pub mod spotify_id;
#[doc(hidden)] #[doc(hidden)]
pub mod util; pub mod util;
pub mod version; pub mod version;
const AP_FALLBACK: &str = "ap.spotify.com:443"; fn ap_fallback() -> (String, u16) {
(String::from("ap.spotify.com"), 443)
}
#[cfg(feature = "apresolve")] #[cfg(feature = "apresolve")]
mod apresolve; mod apresolve;
#[cfg(not(feature = "apresolve"))] #[cfg(not(feature = "apresolve"))]
mod apresolve { mod apresolve {
pub async fn apresolve(_: Option<&url::Url>, _: Option<u16>) -> String { pub async fn apresolve(_: Option<&url::Url>, _: Option<u16>) -> (String, u16) {
return super::AP_FALLBACK.into(); super::ap_fallback()
} }
} }

View file

@ -69,8 +69,8 @@ impl Session {
) -> Result<Session, SessionError> { ) -> Result<Session, SessionError> {
let ap = apresolve(config.proxy.as_ref(), config.ap_port).await; let ap = apresolve(config.proxy.as_ref(), config.ap_port).await;
info!("Connecting to AP \"{}\"", ap); info!("Connecting to AP \"{}:{}\"", ap.0, ap.1);
let mut conn = connection::connect(ap, config.proxy.as_ref()).await?; let mut conn = connection::connect(&ap.0, ap.1, config.proxy.as_ref()).await?;
let reusable_credentials = let reusable_credentials =
connection::authenticate(&mut conn, credentials, &config.device_id).await?; connection::authenticate(&mut conn, credentials, &config.device_id).await?;

35
core/src/socket.rs Normal file
View file

@ -0,0 +1,35 @@
use std::io;
use std::net::ToSocketAddrs;
use tokio::net::TcpStream;
use url::Url;
use crate::proxytunnel;
pub async fn connect(host: &str, port: u16, proxy: Option<&Url>) -> io::Result<TcpStream> {
let socket = if let Some(proxy_url) = proxy {
info!("Using proxy \"{}\"", proxy_url);
let socket_addr = proxy_url.socket_addrs(|| None).and_then(|addrs| {
addrs.into_iter().next().ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"Can't resolve proxy server address",
)
})
})?;
let socket = TcpStream::connect(&socket_addr).await?;
proxytunnel::proxy_connect(socket, host, &port.to_string()).await?
} else {
let socket_addr = (host, port).to_socket_addrs()?.next().ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"Can't resolve access point address",
)
})?;
TcpStream::connect(&socket_addr).await?
};
Ok(socket)
}

View file

@ -1,4 +1,99 @@
use std::future::Future;
use std::mem; use std::mem;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use futures_core::ready;
use futures_util::FutureExt;
use futures_util::Sink;
use futures_util::{future, SinkExt};
use tokio::task::JoinHandle;
use tokio::time::timeout;
/// Returns a future that will flush the sink, even if flushing is temporarily completed.
/// Finishes only if the sink throws an error.
pub(crate) fn keep_flushing<'a, T, S: Sink<T> + Unpin + 'a>(
mut s: S,
) -> impl Future<Output = S::Error> + 'a {
future::poll_fn(move |cx| match s.poll_flush_unpin(cx) {
Poll::Ready(Err(e)) => Poll::Ready(e),
_ => Poll::Pending,
})
}
pub struct CancelOnDrop<T>(pub JoinHandle<T>);
impl<T> Future for CancelOnDrop<T> {
type Output = <JoinHandle<T> as Future>::Output;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_unpin(cx)
}
}
impl<T> Drop for CancelOnDrop<T> {
fn drop(&mut self) {
self.0.abort();
}
}
pub struct TimeoutOnDrop<T: Send + 'static> {
handle: Option<JoinHandle<T>>,
timeout: tokio::time::Duration,
}
impl<T: Send + 'static> TimeoutOnDrop<T> {
pub fn new(handle: JoinHandle<T>, timeout: tokio::time::Duration) -> Self {
Self {
handle: Some(handle),
timeout,
}
}
pub fn take(&mut self) -> Option<JoinHandle<T>> {
self.handle.take()
}
}
impl<T: Send + 'static> Future for TimeoutOnDrop<T> {
type Output = <JoinHandle<T> as Future>::Output;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let r = ready!(self
.handle
.as_mut()
.expect("Polled after ready")
.poll_unpin(cx));
self.handle = None;
Poll::Ready(r)
}
}
impl<T: Send + 'static> Drop for TimeoutOnDrop<T> {
fn drop(&mut self) {
let mut handle = if let Some(handle) = self.handle.take() {
handle
} else {
return;
};
if (&mut handle).now_or_never().is_some() {
// Already finished
return;
}
match tokio::runtime::Handle::try_current() {
Ok(h) => {
h.spawn(timeout(self.timeout, CancelOnDrop(handle)));
}
Err(_) => {
// Not in tokio context, can't spawn
handle.abort();
}
}
}
}
pub trait Seq { pub trait Seq {
fn next(&self) -> Self; fn next(&self) -> Self;