mirror of
https://github.com/librespot-org/librespot.git
synced 2024-11-08 16:45:43 +00:00
Add basic websocket support
This commit is contained in:
parent
08ba3ad7d7
commit
1ade02b7ad
11 changed files with 1040 additions and 68 deletions
136
Cargo.lock
generated
136
Cargo.lock
generated
|
@ -918,6 +918,15 @@ dependencies = [
|
|||
"hashbrown",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "input_buffer"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f97967975f448f1a7ddb12b0bc41069d09ed6a1c161a92687e057325db35d413"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "instant"
|
||||
version = "0.1.9"
|
||||
|
@ -1229,6 +1238,7 @@ dependencies = [
|
|||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"url",
|
||||
"uuid",
|
||||
|
@ -1911,6 +1921,21 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.16.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"spin",
|
||||
"untrusted",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rodio"
|
||||
version = "0.14.0"
|
||||
|
@ -1945,6 +1970,19 @@ dependencies = [
|
|||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.19.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"log",
|
||||
"ring",
|
||||
"sct",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.5"
|
||||
|
@ -1966,6 +2004,16 @@ version = "1.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
|
||||
|
||||
[[package]]
|
||||
name = "sct"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sdl2"
|
||||
version = "0.34.5"
|
||||
|
@ -2103,6 +2151,12 @@ dependencies = [
|
|||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
|
||||
|
||||
[[package]]
|
||||
name = "stdweb"
|
||||
version = "0.1.3"
|
||||
|
@ -2275,6 +2329,17 @@ dependencies = [
|
|||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6"
|
||||
dependencies = [
|
||||
"rustls",
|
||||
"tokio",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-stream"
|
||||
version = "0.1.5"
|
||||
|
@ -2286,6 +2351,23 @@ dependencies = [
|
|||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e96bb520beab540ab664bd5a9cfeaa1fcd846fa68c830b42e2c8963071251d2"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"pin-project",
|
||||
"rustls",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tungstenite",
|
||||
"webpki",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.6.6"
|
||||
|
@ -2341,6 +2423,29 @@ version = "0.2.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5fe8dada8c1a3aeca77d6b51a4f1314e0f4b8e438b7b1b71e3ddaca8080e4093"
|
||||
dependencies = [
|
||||
"base64",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"http",
|
||||
"httparse",
|
||||
"input_buffer",
|
||||
"log",
|
||||
"rand",
|
||||
"rustls",
|
||||
"sha-1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
"webpki",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.13.0"
|
||||
|
@ -2389,6 +2494,12 @@ version = "0.2.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.2.2"
|
||||
|
@ -2401,6 +2512,12 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "uuid"
|
||||
version = "0.8.2"
|
||||
|
@ -2561,6 +2678,25 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki"
|
||||
version = "0.21.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.21.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aabe153544e473b775453675851ecc86863d2a81d786d741f6b76778f2a48940"
|
||||
dependencies = [
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
|
|
@ -39,8 +39,9 @@ serde_json = "1.0"
|
|||
sha-1 = "0.9"
|
||||
shannon = "0.2.0"
|
||||
thiserror = "1.0.7"
|
||||
tokio = { version = "1.0", features = ["io-util", "net", "rt", "sync"] }
|
||||
tokio = { version = "1.5", features = ["io-util", "macros", "net", "rt", "time", "sync"] }
|
||||
tokio-stream = "0.1.1"
|
||||
tokio-tungstenite = { version = "0.14", default-features = false, features = ["rustls-tls"] }
|
||||
tokio-util = { version = "0.6", features = ["codec"] }
|
||||
url = "2.1"
|
||||
uuid = { version = "0.8", default-features = false, features = ["v4"] }
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
use std::error::Error;
|
||||
|
||||
use hyper::client::HttpConnector;
|
||||
use hyper::{Body, Client, Method, Request, Uri};
|
||||
use hyper::{Body, Client, Method, Request};
|
||||
use hyper_proxy::{Intercept, Proxy, ProxyConnector};
|
||||
use serde::Deserialize;
|
||||
use url::Url;
|
||||
|
||||
use super::AP_FALLBACK;
|
||||
use super::ap_fallback;
|
||||
|
||||
const APRESOLVE_ENDPOINT: &str = "http://apresolve.spotify.com:80";
|
||||
|
||||
|
@ -18,7 +18,7 @@ struct ApResolveData {
|
|||
async fn try_apresolve(
|
||||
proxy: Option<&Url>,
|
||||
ap_port: Option<u16>,
|
||||
) -> Result<String, Box<dyn Error>> {
|
||||
) -> Result<(String, u16), Box<dyn Error>> {
|
||||
let port = ap_port.unwrap_or(443);
|
||||
|
||||
let mut req = Request::new(Body::empty());
|
||||
|
@ -43,27 +43,29 @@ async fn try_apresolve(
|
|||
let body = hyper::body::to_bytes(response.into_body()).await?;
|
||||
let data: ApResolveData = serde_json::from_slice(body.as_ref())?;
|
||||
|
||||
let mut aps = data.ap_list.into_iter().filter_map(|ap| {
|
||||
let mut split = ap.rsplitn(2, ':');
|
||||
let port = split
|
||||
.next()
|
||||
.expect("rsplitn should not return empty iterator");
|
||||
let host = split.next()?.to_owned();
|
||||
let port: u16 = port.parse().ok()?;
|
||||
Some((host, port))
|
||||
});
|
||||
let ap = if ap_port.is_some() || proxy.is_some() {
|
||||
data.ap_list.into_iter().find_map(|ap| {
|
||||
if ap.parse::<Uri>().ok()?.port()? == port {
|
||||
Some(ap)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
aps.find(|(_, p)| *p == port)
|
||||
} else {
|
||||
data.ap_list.into_iter().next()
|
||||
aps.next()
|
||||
}
|
||||
.ok_or("empty AP List")?;
|
||||
.ok_or("no valid AP in list")?;
|
||||
|
||||
Ok(ap)
|
||||
}
|
||||
|
||||
pub async fn apresolve(proxy: Option<&Url>, ap_port: Option<u16>) -> String {
|
||||
pub async fn apresolve(proxy: Option<&Url>, ap_port: Option<u16>) -> (String, u16) {
|
||||
try_apresolve(proxy, ap_port).await.unwrap_or_else(|e| {
|
||||
warn!("Failed to resolve Access Point: {}", e);
|
||||
warn!("Using fallback \"{}\"", AP_FALLBACK);
|
||||
AP_FALLBACK.into()
|
||||
warn!("Failed to resolve Access Point: {}, using fallback.", e);
|
||||
ap_fallback()
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ pub use self::codec::ApCodec;
|
|||
pub use self::handshake::handshake;
|
||||
|
||||
use std::io::{self, ErrorKind};
|
||||
use std::net::ToSocketAddrs;
|
||||
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use protobuf::{self, Message, ProtobufError};
|
||||
|
@ -16,7 +15,6 @@ use url::Url;
|
|||
|
||||
use crate::authentication::Credentials;
|
||||
use crate::protocol::keyexchange::{APLoginFailed, ErrorCode};
|
||||
use crate::proxytunnel;
|
||||
use crate::version;
|
||||
|
||||
pub type Transport = Framed<TcpStream, ApCodec>;
|
||||
|
@ -58,50 +56,8 @@ impl From<APLoginFailed> for AuthenticationError {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn connect(addr: String, proxy: Option<&Url>) -> io::Result<Transport> {
|
||||
let socket = if let Some(proxy_url) = proxy {
|
||||
info!("Using proxy \"{}\"", proxy_url);
|
||||
|
||||
let socket_addr = proxy_url.socket_addrs(|| None).and_then(|addrs| {
|
||||
addrs.into_iter().next().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"Can't resolve proxy server address",
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let socket = TcpStream::connect(&socket_addr).await?;
|
||||
|
||||
let uri = addr.parse::<http::Uri>().map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Can't parse access point address",
|
||||
)
|
||||
})?;
|
||||
let host = uri.host().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"The access point address contains no hostname",
|
||||
)
|
||||
})?;
|
||||
let port = uri.port().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"The access point address contains no port",
|
||||
)
|
||||
})?;
|
||||
|
||||
proxytunnel::proxy_connect(socket, host, port.as_str()).await?
|
||||
} else {
|
||||
let socket_addr = addr.to_socket_addrs()?.next().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"Can't resolve access point address",
|
||||
)
|
||||
})?;
|
||||
|
||||
TcpStream::connect(&socket_addr).await?
|
||||
};
|
||||
pub async fn connect(host: &str, port: u16, proxy: Option<&Url>) -> io::Result<Transport> {
|
||||
let socket = crate::socket::connect(host, port, proxy).await?;
|
||||
|
||||
handshake(socket).await
|
||||
}
|
||||
|
|
117
core/src/dealer/maps.rs
Normal file
117
core/src/dealer/maps.rs
Normal file
|
@ -0,0 +1,117 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AlreadyHandledError(());
|
||||
|
||||
pub enum HandlerMap<T> {
|
||||
Leaf(T),
|
||||
Branch(HashMap<String, HandlerMap<T>>),
|
||||
}
|
||||
|
||||
impl<T> Default for HandlerMap<T> {
|
||||
fn default() -> Self {
|
||||
Self::Branch(HashMap::new())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> HandlerMap<T> {
|
||||
pub fn insert<'a>(
|
||||
&mut self,
|
||||
mut path: impl Iterator<Item = &'a str>,
|
||||
handler: T,
|
||||
) -> Result<(), AlreadyHandledError> {
|
||||
match self {
|
||||
Self::Leaf(_) => Err(AlreadyHandledError(())),
|
||||
Self::Branch(children) => {
|
||||
if let Some(component) = path.next() {
|
||||
let node = children.entry(component.to_owned()).or_default();
|
||||
node.insert(path, handler)
|
||||
} else if children.is_empty() {
|
||||
*self = Self::Leaf(handler);
|
||||
Ok(())
|
||||
} else {
|
||||
Err(AlreadyHandledError(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get<'a>(&self, mut path: impl Iterator<Item = &'a str>) -> Option<&T> {
|
||||
match self {
|
||||
Self::Leaf(t) => Some(t),
|
||||
Self::Branch(m) => {
|
||||
let component = path.next()?;
|
||||
m.get(component)?.get(path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove<'a>(&mut self, mut path: impl Iterator<Item = &'a str>) -> Option<T> {
|
||||
match self {
|
||||
Self::Leaf(_) => match std::mem::take(self) {
|
||||
Self::Leaf(t) => Some(t),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
Self::Branch(map) => {
|
||||
let component = path.next()?;
|
||||
let next = map.get_mut(component)?;
|
||||
let result = next.remove(path);
|
||||
match &*next {
|
||||
Self::Branch(b) if b.is_empty() => {
|
||||
map.remove(component);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SubscriberMap<T> {
|
||||
subscribed: Vec<T>,
|
||||
children: HashMap<String, SubscriberMap<T>>,
|
||||
}
|
||||
|
||||
impl<T> Default for SubscriberMap<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
subscribed: Vec::new(),
|
||||
children: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> SubscriberMap<T> {
|
||||
pub fn insert<'a>(&mut self, mut path: impl Iterator<Item = &'a str>, handler: T) {
|
||||
if let Some(component) = path.next() {
|
||||
self.children
|
||||
.entry(component.to_owned())
|
||||
.or_default()
|
||||
.insert(path, handler);
|
||||
} else {
|
||||
self.subscribed.push(handler);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.children.is_empty() && self.subscribed.is_empty()
|
||||
}
|
||||
|
||||
pub fn retain<'a>(
|
||||
&mut self,
|
||||
mut path: impl Iterator<Item = &'a str>,
|
||||
fun: &mut impl FnMut(&T) -> bool,
|
||||
) {
|
||||
self.subscribed.retain(|x| fun(x));
|
||||
|
||||
if let Some(next) = path.next() {
|
||||
if let Some(y) = self.children.get_mut(next) {
|
||||
y.retain(path, fun);
|
||||
if y.is_empty() {
|
||||
self.children.remove(next);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
586
core/src/dealer/mod.rs
Normal file
586
core/src/dealer/mod.rs
Normal file
|
@ -0,0 +1,586 @@
|
|||
mod maps;
|
||||
mod protocol;
|
||||
|
||||
use std::iter;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::{atomic, Arc, Mutex};
|
||||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures_core::{Future, Stream};
|
||||
use futures_util::future::join_all;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use tokio::select;
|
||||
use tokio::sync::mpsc::{self, UnboundedReceiver};
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_tungstenite::tungstenite;
|
||||
use tungstenite::error::UrlError;
|
||||
use url::Url;
|
||||
|
||||
use self::maps::*;
|
||||
use self::protocol::*;
|
||||
pub use self::protocol::{Message, Request};
|
||||
use crate::socket;
|
||||
use crate::util::{keep_flushing, CancelOnDrop, TimeoutOnDrop};
|
||||
|
||||
type WsMessage = tungstenite::Message;
|
||||
type WsError = tungstenite::Error;
|
||||
type WsResult<T> = Result<T, tungstenite::Error>;
|
||||
|
||||
pub struct Response {
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
pub struct Responder {
|
||||
key: String,
|
||||
tx: mpsc::UnboundedSender<WsMessage>,
|
||||
sent: bool,
|
||||
}
|
||||
|
||||
impl Responder {
|
||||
fn new(key: String, tx: mpsc::UnboundedSender<WsMessage>) -> Self {
|
||||
Self {
|
||||
key,
|
||||
tx,
|
||||
sent: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Should only be called once
|
||||
fn send_internal(&mut self, response: Response) {
|
||||
let response = serde_json::json!({
|
||||
"type": "reply",
|
||||
"key": &self.key,
|
||||
"payload": {
|
||||
"success": response.success,
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
if let Err(e) = self.tx.send(WsMessage::Text(response)) {
|
||||
warn!("Wasn't able to reply to dealer request: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn send(mut self, success: Response) {
|
||||
self.send_internal(success);
|
||||
self.sent = true;
|
||||
}
|
||||
|
||||
pub fn force_unanswered(mut self) {
|
||||
self.sent = true;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Responder {
|
||||
fn drop(&mut self) {
|
||||
if !self.sent {
|
||||
self.send_internal(Response { success: false });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait IntoResponse {
|
||||
fn respond(self, responder: Responder);
|
||||
}
|
||||
|
||||
impl IntoResponse for Response {
|
||||
fn respond(self, responder: Responder) {
|
||||
responder.send(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> IntoResponse for F
|
||||
where
|
||||
F: Future<Output = Response> + Send + 'static,
|
||||
{
|
||||
fn respond(self, responder: Responder) {
|
||||
tokio::spawn(async move {
|
||||
responder.send(self.await);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<F, R> RequestHandler for F
|
||||
where
|
||||
F: (Fn(Request<Payload>) -> R) + Send + Sync + 'static,
|
||||
R: IntoResponse,
|
||||
{
|
||||
fn handle_request(&self, request: Request<Payload>, responder: Responder) {
|
||||
self(request).respond(responder);
|
||||
}
|
||||
}
|
||||
|
||||
pub trait RequestHandler: Send + Sync + 'static {
|
||||
fn handle_request(&self, request: Request<Payload>, responder: Responder);
|
||||
}
|
||||
|
||||
type MessageHandler = mpsc::UnboundedSender<Message<JsonValue>>;
|
||||
|
||||
// TODO: Maybe it's possible to unregister subscription directly when they
|
||||
// are dropped instead of on next failed attempt.
|
||||
pub struct Subscription(UnboundedReceiver<Message<JsonValue>>);
|
||||
|
||||
impl Stream for Subscription {
|
||||
type Item = Message<JsonValue>;
|
||||
|
||||
fn poll_next(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> Poll<Option<Self::Item>> {
|
||||
self.0.poll_recv(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn split_uri(s: &str) -> Option<impl Iterator<Item = &'_ str>> {
|
||||
let (scheme, sep, rest) = if let Some(rest) = s.strip_prefix("hm://") {
|
||||
("hm", '/', rest)
|
||||
} else if let Some(rest) = s.strip_suffix("spotify:") {
|
||||
("spotify", ':', rest)
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let rest = rest.trim_end_matches(sep);
|
||||
let mut split = rest.split(sep);
|
||||
|
||||
if rest.is_empty() {
|
||||
assert_eq!(split.next(), Some(""));
|
||||
}
|
||||
|
||||
Some(iter::once(scheme).chain(split))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AddHandlerError {
|
||||
AlreadyHandled,
|
||||
InvalidUri,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SubscriptionError {
|
||||
InvalidUri,
|
||||
}
|
||||
|
||||
fn add_handler<H>(
|
||||
map: &mut HandlerMap<Box<dyn RequestHandler>>,
|
||||
uri: &str,
|
||||
handler: H,
|
||||
) -> Result<(), AddHandlerError>
|
||||
where
|
||||
H: RequestHandler,
|
||||
{
|
||||
let split = split_uri(uri).ok_or(AddHandlerError::InvalidUri)?;
|
||||
map.insert(split, Box::new(handler))
|
||||
.map_err(|_| AddHandlerError::AlreadyHandled)
|
||||
}
|
||||
|
||||
fn remove_handler<T>(map: &mut HandlerMap<T>, uri: &str) -> Option<T> {
|
||||
map.remove(split_uri(uri)?)
|
||||
}
|
||||
|
||||
fn subscribe(
|
||||
map: &mut SubscriberMap<MessageHandler>,
|
||||
uris: &[&str],
|
||||
) -> Result<Subscription, SubscriptionError> {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
|
||||
for &uri in uris {
|
||||
let split = split_uri(uri).ok_or(SubscriptionError::InvalidUri)?;
|
||||
map.insert(split, tx.clone());
|
||||
}
|
||||
|
||||
Ok(Subscription(rx))
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Builder {
|
||||
message_handlers: SubscriberMap<MessageHandler>,
|
||||
request_handlers: HandlerMap<Box<dyn RequestHandler>>,
|
||||
}
|
||||
|
||||
macro_rules! create_dealer {
|
||||
($builder:expr, $shared:ident -> $body:expr) => {
|
||||
match $builder {
|
||||
builder => {
|
||||
let shared = Arc::new(DealerShared {
|
||||
message_handlers: Mutex::new(builder.message_handlers),
|
||||
request_handlers: Mutex::new(builder.request_handlers),
|
||||
notify_drop: Semaphore::new(0),
|
||||
});
|
||||
|
||||
let handle = {
|
||||
let $shared = Arc::clone(&shared);
|
||||
tokio::spawn($body)
|
||||
};
|
||||
|
||||
Dealer {
|
||||
shared,
|
||||
handle: TimeoutOnDrop::new(handle, Duration::from_secs(3)),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl Builder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn add_handler(
|
||||
&mut self,
|
||||
uri: &str,
|
||||
handler: impl RequestHandler,
|
||||
) -> Result<(), AddHandlerError> {
|
||||
add_handler(&mut self.request_handlers, uri, handler)
|
||||
}
|
||||
|
||||
pub fn subscribe(&mut self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
|
||||
subscribe(&mut self.message_handlers, uris)
|
||||
}
|
||||
|
||||
pub fn launch_in_background<Fut, F>(self, get_url: F, proxy: Option<Url>) -> Dealer
|
||||
where
|
||||
Fut: Future<Output = Url> + Send + 'static,
|
||||
F: (FnMut() -> Fut) + Send + 'static,
|
||||
{
|
||||
create_dealer!(self, shared -> run(shared, None, get_url, proxy))
|
||||
}
|
||||
|
||||
pub async fn launch<Fut, F>(self, mut get_url: F, proxy: Option<Url>) -> WsResult<Dealer>
|
||||
where
|
||||
Fut: Future<Output = Url> + Send + 'static,
|
||||
F: (FnMut() -> Fut) + Send + 'static,
|
||||
{
|
||||
let dealer = create_dealer!(self, shared -> {
|
||||
// Try to connect.
|
||||
let url = get_url().await;
|
||||
let tasks = connect(&url, proxy.as_ref(), &shared).await?;
|
||||
|
||||
// If a connection is established, continue in a background task.
|
||||
run(shared, Some(tasks), get_url, proxy)
|
||||
});
|
||||
|
||||
Ok(dealer)
|
||||
}
|
||||
}
|
||||
|
||||
struct DealerShared {
|
||||
message_handlers: Mutex<SubscriberMap<MessageHandler>>,
|
||||
request_handlers: Mutex<HandlerMap<Box<dyn RequestHandler>>>,
|
||||
|
||||
// Semaphore with 0 permits. By closing this semaphore, we indicate
|
||||
// that the actual Dealer struct has been dropped.
|
||||
notify_drop: Semaphore,
|
||||
}
|
||||
|
||||
impl DealerShared {
|
||||
fn dispatch_message(&self, msg: Message<JsonValue>) {
|
||||
if let Some(split) = split_uri(&msg.uri) {
|
||||
self.message_handlers
|
||||
.lock()
|
||||
.unwrap()
|
||||
.retain(split, &mut |tx| tx.send(msg.clone()).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_request(
|
||||
&self,
|
||||
request: Request<Payload>,
|
||||
send_tx: &mpsc::UnboundedSender<WsMessage>,
|
||||
) {
|
||||
// ResponseSender will automatically send "success: false" if it is dropped without an answer.
|
||||
let responder = Responder::new(request.key.clone(), send_tx.clone());
|
||||
|
||||
let split = if let Some(split) = split_uri(&request.message_ident) {
|
||||
split
|
||||
} else {
|
||||
warn!(
|
||||
"Dealer request with invalid message_ident: {}",
|
||||
&request.message_ident
|
||||
);
|
||||
return;
|
||||
};
|
||||
|
||||
{
|
||||
let handler_map = self.request_handlers.lock().unwrap();
|
||||
|
||||
if let Some(handler) = handler_map.get(split) {
|
||||
handler.handle_request(request, responder);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
warn!("No handler for message_ident: {}", &request.message_ident);
|
||||
}
|
||||
|
||||
fn dispatch(&self, m: MessageOrRequest, send_tx: &mpsc::UnboundedSender<WsMessage>) {
|
||||
match m {
|
||||
MessageOrRequest::Message(m) => self.dispatch_message(m),
|
||||
MessageOrRequest::Request(r) => self.dispatch_request(r, send_tx),
|
||||
}
|
||||
}
|
||||
|
||||
async fn closed(&self) {
|
||||
self.notify_drop.acquire().await.unwrap_err();
|
||||
}
|
||||
|
||||
fn is_closed(&self) -> bool {
|
||||
self.notify_drop.is_closed()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Dealer {
|
||||
shared: Arc<DealerShared>,
|
||||
handle: TimeoutOnDrop<()>,
|
||||
}
|
||||
|
||||
impl Dealer {
|
||||
pub fn add_handler<H>(&self, uri: &str, handler: H) -> Result<(), AddHandlerError>
|
||||
where
|
||||
H: RequestHandler,
|
||||
{
|
||||
add_handler(
|
||||
&mut self.shared.request_handlers.lock().unwrap(),
|
||||
uri,
|
||||
handler,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn remove_handler(&self, uri: &str) -> Option<Box<dyn RequestHandler>> {
|
||||
remove_handler(&mut self.shared.request_handlers.lock().unwrap(), uri)
|
||||
}
|
||||
|
||||
pub fn subscribe(&self, uris: &[&str]) -> Result<Subscription, SubscriptionError> {
|
||||
subscribe(&mut self.shared.message_handlers.lock().unwrap(), uris)
|
||||
}
|
||||
|
||||
pub async fn close(mut self) {
|
||||
debug!("closing dealer");
|
||||
|
||||
self.shared.notify_drop.close();
|
||||
|
||||
if let Some(handle) = self.handle.take() {
|
||||
CancelOnDrop(handle).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Initializes a connection and returns futures that will finish when the connection is closed/lost.
|
||||
async fn connect(
|
||||
address: &Url,
|
||||
proxy: Option<&Url>,
|
||||
shared: &Arc<DealerShared>,
|
||||
) -> WsResult<(JoinHandle<()>, JoinHandle<()>)> {
|
||||
let host = address
|
||||
.host_str()
|
||||
.ok_or(WsError::Url(UrlError::NoHostName))?;
|
||||
|
||||
let default_port = match address.scheme() {
|
||||
"ws" => 80,
|
||||
"wss" => 443,
|
||||
_ => return Err(WsError::Url(UrlError::UnsupportedUrlScheme)),
|
||||
};
|
||||
|
||||
let port = address.port().unwrap_or(default_port);
|
||||
|
||||
let stream = socket::connect(host, port, proxy).await?;
|
||||
|
||||
let (mut ws_tx, ws_rx) = tokio_tungstenite::client_async_tls(address, stream)
|
||||
.await?
|
||||
.0
|
||||
.split();
|
||||
|
||||
let (send_tx, mut send_rx) = mpsc::unbounded_channel::<WsMessage>();
|
||||
|
||||
// Spawn a task that will forward messages from the channel to the websocket.
|
||||
let send_task = {
|
||||
let shared = Arc::clone(&shared);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let result = loop {
|
||||
select! {
|
||||
biased;
|
||||
() = shared.closed() => {
|
||||
break Ok(None);
|
||||
}
|
||||
msg = send_rx.recv() => {
|
||||
if let Some(msg) = msg {
|
||||
// New message arrived through channel
|
||||
if let WsMessage::Close(close_frame) = msg {
|
||||
break Ok(close_frame);
|
||||
}
|
||||
|
||||
if let Err(e) = ws_tx.feed(msg).await {
|
||||
break Err(e);
|
||||
}
|
||||
} else {
|
||||
break Ok(None);
|
||||
}
|
||||
},
|
||||
e = keep_flushing(&mut ws_tx) => {
|
||||
break Err(e)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
send_rx.close();
|
||||
|
||||
// I don't trust in tokio_tungstenite's implementation of Sink::close.
|
||||
let result = match result {
|
||||
Ok(close_frame) => ws_tx.send(WsMessage::Close(close_frame)).await,
|
||||
Err(WsError::AlreadyClosed) | Err(WsError::ConnectionClosed) => ws_tx.flush().await,
|
||||
Err(e) => {
|
||||
warn!("Dealer finished with an error: {}", e);
|
||||
ws_tx.send(WsMessage::Close(None)).await
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = result {
|
||||
warn!("Error while closing websocket: {}", e);
|
||||
}
|
||||
|
||||
debug!("Dropping send task");
|
||||
})
|
||||
};
|
||||
|
||||
let shared = Arc::clone(&shared);
|
||||
|
||||
// A task that receives messages from the web socket.
|
||||
let receive_task = tokio::spawn(async {
|
||||
let pong_received = AtomicBool::new(true);
|
||||
let send_tx = send_tx;
|
||||
let shared = shared;
|
||||
|
||||
let receive_task = async {
|
||||
let mut ws_rx = ws_rx;
|
||||
|
||||
loop {
|
||||
match ws_rx.next().await {
|
||||
Some(Ok(msg)) => match msg {
|
||||
WsMessage::Text(t) => match serde_json::from_str(&t) {
|
||||
Ok(m) => shared.dispatch(m, &send_tx),
|
||||
Err(e) => info!("Received invalid message: {}", e),
|
||||
},
|
||||
WsMessage::Binary(_) => {
|
||||
info!("Received invalid binary message");
|
||||
}
|
||||
WsMessage::Pong(_) => {
|
||||
debug!("Received pong");
|
||||
pong_received.store(true, atomic::Ordering::Relaxed);
|
||||
}
|
||||
_ => (), // tungstenite handles Close and Ping automatically
|
||||
},
|
||||
Some(Err(e)) => {
|
||||
warn!("Websocket connection failed: {}", e);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
debug!("Websocket connection closed.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Sends pings and checks whether a pong comes back.
|
||||
let ping_task = async {
|
||||
use tokio::time::{interval, sleep};
|
||||
|
||||
let mut timer = interval(Duration::from_secs(30));
|
||||
|
||||
loop {
|
||||
timer.tick().await;
|
||||
|
||||
pong_received.store(false, atomic::Ordering::Relaxed);
|
||||
if send_tx.send(WsMessage::Ping(vec![])).is_err() {
|
||||
// The sender is closed.
|
||||
break;
|
||||
}
|
||||
|
||||
debug!("Sent ping");
|
||||
|
||||
sleep(Duration::from_secs(3)).await;
|
||||
|
||||
if !pong_received.load(atomic::Ordering::SeqCst) {
|
||||
// No response
|
||||
warn!("Websocket peer does not respond.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Exit this task as soon as one our subtasks fails.
|
||||
// In both cases the connection is probably lost.
|
||||
select! {
|
||||
() = ping_task => (),
|
||||
() = receive_task => ()
|
||||
}
|
||||
|
||||
// Try to take send_task down with us, in case it's still alive.
|
||||
let _ = send_tx.send(WsMessage::Close(None));
|
||||
|
||||
debug!("Dropping receive task");
|
||||
});
|
||||
|
||||
Ok((send_task, receive_task))
|
||||
}
|
||||
|
||||
/// The main background task for `Dealer`, which coordinates reconnecting.
|
||||
async fn run<F, Fut>(
|
||||
shared: Arc<DealerShared>,
|
||||
initial_tasks: Option<(JoinHandle<()>, JoinHandle<()>)>,
|
||||
mut get_url: F,
|
||||
proxy: Option<Url>,
|
||||
) where
|
||||
Fut: Future<Output = Url> + Send + 'static,
|
||||
F: (FnMut() -> Fut) + Send + 'static,
|
||||
{
|
||||
let init_task = |t| Some(TimeoutOnDrop::new(t, Duration::from_secs(3)));
|
||||
|
||||
let mut tasks = if let Some((s, r)) = initial_tasks {
|
||||
(init_task(s), init_task(r))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
while !shared.is_closed() {
|
||||
match &mut tasks {
|
||||
(Some(t0), Some(t1)) => {
|
||||
select! {
|
||||
() = shared.closed() => break,
|
||||
r = t0 => {
|
||||
r.unwrap(); // Whatever has gone wrong (probably panicked), we can't handle it, so let's panic too.
|
||||
tasks.0.take();
|
||||
},
|
||||
r = t1 => {
|
||||
r.unwrap();
|
||||
tasks.1.take();
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
let url = select! {
|
||||
() = shared.closed() => {
|
||||
break
|
||||
},
|
||||
e = get_url() => e
|
||||
};
|
||||
|
||||
match connect(&url, proxy.as_ref(), &shared).await {
|
||||
Ok((s, r)) => tasks = (init_task(s), init_task(r)),
|
||||
Err(e) => {
|
||||
warn!("Error while connecting: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tasks = tasks.0.into_iter().chain(tasks.1);
|
||||
|
||||
let _ = join_all(tasks).await;
|
||||
}
|
39
core/src/dealer/protocol.rs
Normal file
39
core/src/dealer/protocol.rs
Normal file
|
@ -0,0 +1,39 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
pub type JsonValue = serde_json::Value;
|
||||
pub type JsonObject = serde_json::Map<String, JsonValue>;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Request<P> {
|
||||
#[serde(default)]
|
||||
pub headers: HashMap<String, String>,
|
||||
pub message_ident: String,
|
||||
pub key: String,
|
||||
pub payload: P,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Payload {
|
||||
pub message_id: i32,
|
||||
pub sent_by_device_id: String,
|
||||
pub command: JsonObject,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
pub struct Message<P> {
|
||||
#[serde(default)]
|
||||
pub headers: HashMap<String, String>,
|
||||
pub method: Option<String>,
|
||||
#[serde(default)]
|
||||
pub payloads: Vec<P>,
|
||||
pub uri: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MessageOrRequest {
|
||||
Message(Message<JsonValue>),
|
||||
Request(Request<Payload>),
|
||||
}
|
|
@ -14,25 +14,30 @@ pub mod cache;
|
|||
pub mod channel;
|
||||
pub mod config;
|
||||
mod connection;
|
||||
#[allow(dead_code)]
|
||||
mod dealer;
|
||||
#[doc(hidden)]
|
||||
pub mod diffie_hellman;
|
||||
pub mod keymaster;
|
||||
pub mod mercury;
|
||||
mod proxytunnel;
|
||||
pub mod session;
|
||||
mod socket;
|
||||
pub mod spotify_id;
|
||||
#[doc(hidden)]
|
||||
pub mod util;
|
||||
pub mod version;
|
||||
|
||||
const AP_FALLBACK: &str = "ap.spotify.com:443";
|
||||
fn ap_fallback() -> (String, u16) {
|
||||
(String::from("ap.spotify.com"), 443)
|
||||
}
|
||||
|
||||
#[cfg(feature = "apresolve")]
|
||||
mod apresolve;
|
||||
|
||||
#[cfg(not(feature = "apresolve"))]
|
||||
mod apresolve {
|
||||
pub async fn apresolve(_: Option<&url::Url>, _: Option<u16>) -> String {
|
||||
return super::AP_FALLBACK.into();
|
||||
pub async fn apresolve(_: Option<&url::Url>, _: Option<u16>) -> (String, u16) {
|
||||
super::ap_fallback()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,8 +69,8 @@ impl Session {
|
|||
) -> Result<Session, SessionError> {
|
||||
let ap = apresolve(config.proxy.as_ref(), config.ap_port).await;
|
||||
|
||||
info!("Connecting to AP \"{}\"", ap);
|
||||
let mut conn = connection::connect(ap, config.proxy.as_ref()).await?;
|
||||
info!("Connecting to AP \"{}:{}\"", ap.0, ap.1);
|
||||
let mut conn = connection::connect(&ap.0, ap.1, config.proxy.as_ref()).await?;
|
||||
|
||||
let reusable_credentials =
|
||||
connection::authenticate(&mut conn, credentials, &config.device_id).await?;
|
||||
|
|
35
core/src/socket.rs
Normal file
35
core/src/socket.rs
Normal file
|
@ -0,0 +1,35 @@
|
|||
use std::io;
|
||||
use std::net::ToSocketAddrs;
|
||||
|
||||
use tokio::net::TcpStream;
|
||||
use url::Url;
|
||||
|
||||
use crate::proxytunnel;
|
||||
|
||||
pub async fn connect(host: &str, port: u16, proxy: Option<&Url>) -> io::Result<TcpStream> {
|
||||
let socket = if let Some(proxy_url) = proxy {
|
||||
info!("Using proxy \"{}\"", proxy_url);
|
||||
|
||||
let socket_addr = proxy_url.socket_addrs(|| None).and_then(|addrs| {
|
||||
addrs.into_iter().next().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"Can't resolve proxy server address",
|
||||
)
|
||||
})
|
||||
})?;
|
||||
let socket = TcpStream::connect(&socket_addr).await?;
|
||||
|
||||
proxytunnel::proxy_connect(socket, host, &port.to_string()).await?
|
||||
} else {
|
||||
let socket_addr = (host, port).to_socket_addrs()?.next().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::NotFound,
|
||||
"Can't resolve access point address",
|
||||
)
|
||||
})?;
|
||||
|
||||
TcpStream::connect(&socket_addr).await?
|
||||
};
|
||||
Ok(socket)
|
||||
}
|
|
@ -1,4 +1,99 @@
|
|||
use std::future::Future;
|
||||
use std::mem;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
|
||||
use futures_core::ready;
|
||||
use futures_util::FutureExt;
|
||||
use futures_util::Sink;
|
||||
use futures_util::{future, SinkExt};
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::timeout;
|
||||
|
||||
/// Returns a future that will flush the sink, even if flushing is temporarily completed.
|
||||
/// Finishes only if the sink throws an error.
|
||||
pub(crate) fn keep_flushing<'a, T, S: Sink<T> + Unpin + 'a>(
|
||||
mut s: S,
|
||||
) -> impl Future<Output = S::Error> + 'a {
|
||||
future::poll_fn(move |cx| match s.poll_flush_unpin(cx) {
|
||||
Poll::Ready(Err(e)) => Poll::Ready(e),
|
||||
_ => Poll::Pending,
|
||||
})
|
||||
}
|
||||
|
||||
pub struct CancelOnDrop<T>(pub JoinHandle<T>);
|
||||
|
||||
impl<T> Future for CancelOnDrop<T> {
|
||||
type Output = <JoinHandle<T> as Future>::Output;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
self.0.poll_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for CancelOnDrop<T> {
|
||||
fn drop(&mut self) {
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TimeoutOnDrop<T: Send + 'static> {
|
||||
handle: Option<JoinHandle<T>>,
|
||||
timeout: tokio::time::Duration,
|
||||
}
|
||||
|
||||
impl<T: Send + 'static> TimeoutOnDrop<T> {
|
||||
pub fn new(handle: JoinHandle<T>, timeout: tokio::time::Duration) -> Self {
|
||||
Self {
|
||||
handle: Some(handle),
|
||||
timeout,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take(&mut self) -> Option<JoinHandle<T>> {
|
||||
self.handle.take()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send + 'static> Future for TimeoutOnDrop<T> {
|
||||
type Output = <JoinHandle<T> as Future>::Output;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let r = ready!(self
|
||||
.handle
|
||||
.as_mut()
|
||||
.expect("Polled after ready")
|
||||
.poll_unpin(cx));
|
||||
self.handle = None;
|
||||
Poll::Ready(r)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Send + 'static> Drop for TimeoutOnDrop<T> {
|
||||
fn drop(&mut self) {
|
||||
let mut handle = if let Some(handle) = self.handle.take() {
|
||||
handle
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
|
||||
if (&mut handle).now_or_never().is_some() {
|
||||
// Already finished
|
||||
return;
|
||||
}
|
||||
|
||||
match tokio::runtime::Handle::try_current() {
|
||||
Ok(h) => {
|
||||
h.spawn(timeout(self.timeout, CancelOnDrop(handle)));
|
||||
}
|
||||
Err(_) => {
|
||||
// Not in tokio context, can't spawn
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Seq {
|
||||
fn next(&self) -> Self;
|
||||
|
|
Loading…
Reference in a new issue