Register message listeners before connecting

This commit is contained in:
Roderick van Domburg 2022-01-16 01:14:00 +01:00
parent 8811b89b2d
commit abbc3bade8
No known key found for this signature in database
GPG key ID: FE2585E713F9F30A
4 changed files with 141 additions and 138 deletions

View file

@ -8,7 +8,7 @@ use std::{
use futures_util::{ use futures_util::{
future::{self, FusedFuture}, future::{self, FusedFuture},
stream::FusedStream, stream::FusedStream,
FutureExt, StreamExt, TryFutureExt, FutureExt, StreamExt,
}; };
use protobuf::{self, Message}; use protobuf::{self, Message};
@ -21,6 +21,7 @@ use crate::{
config::ConnectConfig, config::ConnectConfig,
context::StationContext, context::StationContext,
core::{ core::{
authentication::Credentials,
mercury::{MercuryError, MercurySender}, mercury::{MercuryError, MercurySender},
session::UserAttributes, session::UserAttributes,
util::SeqGenerator, util::SeqGenerator,
@ -92,7 +93,7 @@ struct SpircTask {
play_request_id: Option<u64>, play_request_id: Option<u64>,
play_status: SpircPlayStatus, play_status: SpircPlayStatus,
remote_update: BoxedStream<Result<Frame, Error>>, remote_update: BoxedStream<Result<(String, Frame), Error>>,
connection_id_update: BoxedStream<Result<String, Error>>, connection_id_update: BoxedStream<Result<String, Error>>,
user_attributes_update: BoxedStream<Result<UserAttributesUpdate, Error>>, user_attributes_update: BoxedStream<Result<UserAttributesUpdate, Error>>,
user_attributes_mutation: BoxedStream<Result<UserAttributesMutation, Error>>, user_attributes_mutation: BoxedStream<Result<UserAttributesMutation, Error>>,
@ -255,9 +256,10 @@ fn url_encode(bytes: impl AsRef<[u8]>) -> String {
} }
impl Spirc { impl Spirc {
pub fn new( pub async fn new(
config: ConnectConfig, config: ConnectConfig,
session: Session, session: Session,
credentials: Credentials,
player: Player, player: Player,
mixer: Box<dyn Mixer>, mixer: Box<dyn Mixer>,
) -> Result<(Spirc, impl Future<Output = ()>), Error> { ) -> Result<(Spirc, impl Future<Output = ()>), Error> {
@ -265,23 +267,21 @@ impl Spirc {
let ident = session.device_id().to_owned(); let ident = session.device_id().to_owned();
// Uri updated in response to issue #288
let canonical_username = &session.username();
debug!("canonical_username: {}", canonical_username);
let uri = format!("hm://remote/user/{}/", url_encode(canonical_username));
let remote_update = Box::pin( let remote_update = Box::pin(
session session
.mercury() .mercury()
.subscribe(uri.clone()) .listen_for("hm://remote/user/")
.inspect_err(|x| error!("remote update error: {}", x))
.and_then(|x| async move { Ok(x) })
.map(Result::unwrap) // guaranteed to be safe by `and_then` above
.map(UnboundedReceiverStream::new) .map(UnboundedReceiverStream::new)
.flatten_stream() .flatten_stream()
.map(|response| -> Result<Frame, Error> { .map(|response| -> Result<(String, Frame), Error> {
let uri_split: Vec<&str> = response.uri.split('/').collect();
let username = match uri_split.get(uri_split.len() - 2) {
Some(s) => s.to_string(),
None => String::new(),
};
let data = response.payload.first().ok_or(SpircError::NoData)?; let data = response.payload.first().ok_or(SpircError::NoData)?;
Ok(Frame::parse_from_bytes(data)?) Ok((username, Frame::parse_from_bytes(data)?))
}), }),
); );
@ -324,7 +324,14 @@ impl Spirc {
}), }),
); );
let sender = session.mercury().sender(uri); // Connect *after* all message listeners are registered
session.connect(credentials).await?;
let canonical_username = &session.username();
debug!("canonical_username: {}", canonical_username);
let sender_uri = format!("hm://remote/user/{}/", url_encode(canonical_username));
let sender = session.mercury().sender(sender_uri);
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel(); let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
@ -414,13 +421,17 @@ impl SpircTask {
tokio::select! { tokio::select! {
remote_update = self.remote_update.next() => match remote_update { remote_update = self.remote_update.next() => match remote_update {
Some(result) => match result { Some(result) => match result {
Ok(update) => if let Err(e) = self.handle_remote_update(update) { Ok((username, frame)) => {
if username != self.session.username() {
error!("could not dispatch remote update: frame was intended for {}", username);
} else if let Err(e) = self.handle_remote_update(frame) {
error!("could not dispatch remote update: {}", e); error!("could not dispatch remote update: {}", e);
} }
},
Err(e) => error!("could not parse remote update: {}", e), Err(e) => error!("could not parse remote update: {}", e),
} }
None => { None => {
error!("subscription terminated"); error!("remote update selected, but none received");
break; break;
} }
}, },
@ -513,7 +524,7 @@ impl SpircTask {
} }
if self.sender.flush().await.is_err() { if self.sender.flush().await.is_err() {
warn!("Cannot flush spirc event sender."); warn!("Cannot flush spirc event sender when done.");
} }
} }
@ -754,7 +765,7 @@ impl SpircTask {
} }
fn handle_remote_update(&mut self, update: Frame) -> Result<(), Error> { fn handle_remote_update(&mut self, update: Frame) -> Result<(), Error> {
trace!("Received update frame: {:#?}", update,); trace!("Received update frame: {:#?}", update);
// First see if this update was intended for us. // First see if this update was intended for us.
let device_id = &self.ident; let device_id = &self.ident;

View file

@ -248,9 +248,6 @@ impl MercuryManager {
} }
Err(MercuryError::Response(response).into()) Err(MercuryError::Response(response).into())
} else if let PacketType::MercuryEvent = cmd { } else if let PacketType::MercuryEvent = cmd {
self.lock(|inner| {
let mut found = false;
// TODO: This is just a workaround to make utf-8 encoded usernames work. // TODO: This is just a workaround to make utf-8 encoded usernames work.
// A better solution would be to use an uri struct and urlencode it directly // A better solution would be to use an uri struct and urlencode it directly
// before sending while saving the subscription under its unencoded form. // before sending while saving the subscription under its unencoded form.
@ -263,6 +260,9 @@ impl MercuryManager {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("/"); .join("/");
let mut found = false;
self.lock(|inner| {
inner.subscriptions.retain(|&(ref prefix, ref sub)| { inner.subscriptions.retain(|&(ref prefix, ref sub)| {
if encoded_uri.starts_with(prefix) { if encoded_uri.starts_with(prefix) {
found = true; found = true;
@ -275,6 +275,7 @@ impl MercuryManager {
true true
} }
}); });
});
if !found { if !found {
debug!("unknown subscription uri={}", &response.uri); debug!("unknown subscription uri={}", &response.uri);
@ -283,7 +284,6 @@ impl MercuryManager {
} else { } else {
Ok(()) Ok(())
} }
})
} else if let Some(cb) = pending.callback { } else if let Some(cb) = pending.callback {
cb.send(Ok(response)).map_err(|_| MercuryError::Channel)?; cb.send(Ok(response)).map_err(|_| MercuryError::Channel)?;
Ok(()) Ok(())

View file

@ -46,6 +46,8 @@ pub enum SessionError {
AuthenticationError(#[from] AuthenticationError), AuthenticationError(#[from] AuthenticationError),
#[error("Cannot create session: {0}")] #[error("Cannot create session: {0}")]
IoError(#[from] io::Error), IoError(#[from] io::Error),
#[error("Session is not connected")]
NotConnected,
#[error("packet {0} unknown")] #[error("packet {0} unknown")]
Packet(u8), Packet(u8),
} }
@ -55,6 +57,7 @@ impl From<SessionError> for Error {
match err { match err {
SessionError::AuthenticationError(_) => Error::unauthenticated(err), SessionError::AuthenticationError(_) => Error::unauthenticated(err),
SessionError::IoError(_) => Error::unavailable(err), SessionError::IoError(_) => Error::unavailable(err),
SessionError::NotConnected => Error::unavailable(err),
SessionError::Packet(_) => Error::unimplemented(err), SessionError::Packet(_) => Error::unimplemented(err),
} }
} }
@ -83,7 +86,7 @@ struct SessionInternal {
data: RwLock<SessionData>, data: RwLock<SessionData>,
http_client: HttpClient, http_client: HttpClient,
tx_connection: mpsc::UnboundedSender<(u8, Vec<u8>)>, tx_connection: OnceCell<mpsc::UnboundedSender<(u8, Vec<u8>)>>,
apresolver: OnceCell<ApResolver>, apresolver: OnceCell<ApResolver>,
audio_key: OnceCell<AudioKeyManager>, audio_key: OnceCell<AudioKeyManager>,
@ -104,22 +107,17 @@ static SESSION_COUNTER: AtomicUsize = AtomicUsize::new(0);
pub struct Session(Arc<SessionInternal>); pub struct Session(Arc<SessionInternal>);
impl Session { impl Session {
pub async fn connect( pub fn new(config: SessionConfig, cache: Option<Cache>) -> Self {
config: SessionConfig,
credentials: Credentials,
cache: Option<Cache>,
) -> Result<Session, Error> {
let http_client = HttpClient::new(config.proxy.as_ref()); let http_client = HttpClient::new(config.proxy.as_ref());
let (sender_tx, sender_rx) = mpsc::unbounded_channel();
let session_id = SESSION_COUNTER.fetch_add(1, Ordering::AcqRel);
let session_id = SESSION_COUNTER.fetch_add(1, Ordering::AcqRel);
debug!("new Session[{}]", session_id); debug!("new Session[{}]", session_id);
let session = Session(Arc::new(SessionInternal { Self(Arc::new(SessionInternal {
config, config,
data: RwLock::new(SessionData::default()), data: RwLock::new(SessionData::default()),
http_client, http_client,
tx_connection: sender_tx, tx_connection: OnceCell::new(),
cache: cache.map(Arc::new), cache: cache.map(Arc::new),
apresolver: OnceCell::new(), apresolver: OnceCell::new(),
audio_key: OnceCell::new(), audio_key: OnceCell::new(),
@ -129,27 +127,33 @@ impl Session {
token_provider: OnceCell::new(), token_provider: OnceCell::new(),
handle: tokio::runtime::Handle::current(), handle: tokio::runtime::Handle::current(),
session_id, session_id,
})); }))
}
let ap = session.apresolver().resolve("accesspoint").await?; pub async fn connect(&self, credentials: Credentials) -> Result<(), Error> {
let ap = self.apresolver().resolve("accesspoint").await?;
info!("Connecting to AP \"{}:{}\"", ap.0, ap.1); info!("Connecting to AP \"{}:{}\"", ap.0, ap.1);
let mut transport = let mut transport = connection::connect(&ap.0, ap.1, self.config().proxy.as_ref()).await?;
connection::connect(&ap.0, ap.1, session.config().proxy.as_ref()).await?;
let reusable_credentials = let reusable_credentials =
connection::authenticate(&mut transport, credentials, &session.config().device_id) connection::authenticate(&mut transport, credentials, &self.config().device_id).await?;
.await?;
info!("Authenticated as \"{}\" !", reusable_credentials.username); info!("Authenticated as \"{}\" !", reusable_credentials.username);
session.0.data.write().user_data.canonical_username = reusable_credentials.username.clone(); self.set_username(&reusable_credentials.username);
if let Some(cache) = session.cache() { if let Some(cache) = self.cache() {
cache.save_credentials(&reusable_credentials); cache.save_credentials(&reusable_credentials);
} }
let (tx_connection, rx_connection) = mpsc::unbounded_channel();
self.0
.tx_connection
.set(tx_connection)
.map_err(|_| SessionError::NotConnected)?;
let (sink, stream) = transport.split(); let (sink, stream) = transport.split();
let sender_task = UnboundedReceiverStream::new(sender_rx) let sender_task = UnboundedReceiverStream::new(rx_connection)
.map(Ok) .map(Ok)
.forward(sink); .forward(sink);
let receiver_task = DispatchTask(stream, session.weak()); let receiver_task = DispatchTask(stream, self.weak());
tokio::spawn(async move { tokio::spawn(async move {
let result = future::try_join(sender_task, receiver_task).await; let result = future::try_join(sender_task, receiver_task).await;
@ -159,7 +163,7 @@ impl Session {
} }
}); });
Ok(session) Ok(())
} }
pub fn apresolver(&self) -> &ApResolver { pub fn apresolver(&self) -> &ApResolver {
@ -323,8 +327,10 @@ impl Session {
} }
pub fn send_packet(&self, cmd: PacketType, data: Vec<u8>) -> Result<(), Error> { pub fn send_packet(&self, cmd: PacketType, data: Vec<u8>) -> Result<(), Error> {
self.0.tx_connection.send((cmd as u8, data))?; match self.0.tx_connection.get() {
Ok(()) Some(tx) => Ok(tx.send((cmd as u8, data))?),
None => Err(SessionError::NotConnected.into()),
}
} }
pub fn cache(&self) -> Option<&Arc<Cache>> { pub fn cache(&self) -> Option<&Arc<Cache>> {
@ -366,6 +372,10 @@ impl Session {
self.0.data.read().user_data.canonical_username.clone() self.0.data.read().user_data.canonical_username.clone()
} }
pub fn set_username(&self, username: &str) {
self.0.data.write().user_data.canonical_username = username.to_owned();
}
pub fn country(&self) -> String { pub fn country(&self) -> String {
self.0.data.read().user_data.country.clone() self.0.data.read().user_data.country.clone()
} }

View file

@ -9,7 +9,7 @@ use std::{
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use futures_util::{future, FutureExt, StreamExt}; use futures_util::StreamExt;
use log::{error, info, trace, warn}; use log::{error, info, trace, warn};
use sha1::{Digest, Sha1}; use sha1::{Digest, Sha1};
use thiserror::Error; use thiserror::Error;
@ -1562,7 +1562,9 @@ async fn main() {
let mut player_event_channel: Option<UnboundedReceiver<PlayerEvent>> = None; let mut player_event_channel: Option<UnboundedReceiver<PlayerEvent>> = None;
let mut auto_connect_times: Vec<Instant> = vec![]; let mut auto_connect_times: Vec<Instant> = vec![];
let mut discovery = None; let mut discovery = None;
let mut connecting: Pin<Box<dyn future::FusedFuture<Output = _>>> = Box::pin(future::pending()); let mut connecting = false;
let session = Session::new(setup.session_config.clone(), setup.cache.clone());
if setup.enable_discovery { if setup.enable_discovery {
let device_id = setup.session_config.device_id.clone(); let device_id = setup.session_config.device_id.clone();
@ -1582,15 +1584,8 @@ async fn main() {
} }
if let Some(credentials) = setup.credentials { if let Some(credentials) = setup.credentials {
last_credentials = Some(credentials.clone()); last_credentials = Some(credentials);
connecting = Box::pin( connecting = true;
Session::connect(
setup.session_config.clone(),
credentials,
setup.cache.clone(),
)
.fuse(),
);
} }
loop { loop {
@ -1616,11 +1611,7 @@ async fn main() {
tokio::spawn(spirc_task); tokio::spawn(spirc_task);
} }
connecting = Box::pin(Session::connect( connecting = true;
setup.session_config.clone(),
credentials,
setup.cache.clone(),
).fuse());
}, },
None => { None => {
error!("Discovery stopped unexpectedly"); error!("Discovery stopped unexpectedly");
@ -1628,8 +1619,7 @@ async fn main() {
} }
} }
}, },
session = &mut connecting, if !connecting.is_terminated() => match session { _ = async {}, if connecting && last_credentials.is_some() => {
Ok(session) => {
let mixer_config = setup.mixer_config.clone(); let mixer_config = setup.mixer_config.clone();
let mixer = (setup.mixer)(mixer_config); let mixer = (setup.mixer)(mixer_config);
let player_config = setup.player_config.clone(); let player_config = setup.player_config.clone();
@ -1664,7 +1654,7 @@ async fn main() {
} }
}; };
let (spirc_, spirc_task_) = match Spirc::new(connect_config, session, player, mixer) { let (spirc_, spirc_task_) = match Spirc::new(connect_config, session.clone(), last_credentials.clone().unwrap(), player, mixer).await {
Ok((spirc_, spirc_task_)) => (spirc_, spirc_task_), Ok((spirc_, spirc_task_)) => (spirc_, spirc_task_),
Err(e) => { Err(e) => {
error!("could not initialize spirc: {}", e); error!("could not initialize spirc: {}", e);
@ -1674,17 +1664,14 @@ async fn main() {
spirc = Some(spirc_); spirc = Some(spirc_);
spirc_task = Some(Box::pin(spirc_task_)); spirc_task = Some(Box::pin(spirc_task_));
player_event_channel = Some(event_channel); player_event_channel = Some(event_channel);
},
Err(e) => { connecting = false;
error!("Connection failed: {}", e);
exit(1);
}
}, },
_ = async { _ = async {
if let Some(task) = spirc_task.as_mut() { if let Some(task) = spirc_task.as_mut() {
task.await; task.await;
} }
}, if spirc_task.is_some() => { }, if spirc_task.is_some() && !connecting => {
spirc_task = None; spirc_task = None;
warn!("Spirc shut down unexpectedly"); warn!("Spirc shut down unexpectedly");
@ -1695,14 +1682,9 @@ async fn main() {
}; };
match last_credentials.clone() { match last_credentials.clone() {
Some(credentials) if !reconnect_exceeds_rate_limit() => { Some(_) if !reconnect_exceeds_rate_limit() => {
auto_connect_times.push(Instant::now()); auto_connect_times.push(Instant::now());
connecting = true;
connecting = Box::pin(Session::connect(
setup.session_config.clone(),
credentials,
setup.cache.clone(),
).fuse());
}, },
_ => { _ => {
error!("Spirc shut down too often. Not reconnecting automatically."); error!("Spirc shut down too often. Not reconnecting automatically.");