Register message listeners before connecting

This commit is contained in:
Roderick van Domburg 2022-01-16 01:14:00 +01:00
parent 8811b89b2d
commit abbc3bade8
No known key found for this signature in database
GPG key ID: FE2585E713F9F30A
4 changed files with 141 additions and 138 deletions

View file

@ -8,7 +8,7 @@ use std::{
use futures_util::{
future::{self, FusedFuture},
stream::FusedStream,
FutureExt, StreamExt, TryFutureExt,
FutureExt, StreamExt,
};
use protobuf::{self, Message};
@ -21,6 +21,7 @@ use crate::{
config::ConnectConfig,
context::StationContext,
core::{
authentication::Credentials,
mercury::{MercuryError, MercurySender},
session::UserAttributes,
util::SeqGenerator,
@ -92,7 +93,7 @@ struct SpircTask {
play_request_id: Option<u64>,
play_status: SpircPlayStatus,
remote_update: BoxedStream<Result<Frame, Error>>,
remote_update: BoxedStream<Result<(String, Frame), Error>>,
connection_id_update: BoxedStream<Result<String, Error>>,
user_attributes_update: BoxedStream<Result<UserAttributesUpdate, Error>>,
user_attributes_mutation: BoxedStream<Result<UserAttributesMutation, Error>>,
@ -255,9 +256,10 @@ fn url_encode(bytes: impl AsRef<[u8]>) -> String {
}
impl Spirc {
pub fn new(
pub async fn new(
config: ConnectConfig,
session: Session,
credentials: Credentials,
player: Player,
mixer: Box<dyn Mixer>,
) -> Result<(Spirc, impl Future<Output = ()>), Error> {
@ -265,23 +267,21 @@ impl Spirc {
let ident = session.device_id().to_owned();
// Uri updated in response to issue #288
let canonical_username = &session.username();
debug!("canonical_username: {}", canonical_username);
let uri = format!("hm://remote/user/{}/", url_encode(canonical_username));
let remote_update = Box::pin(
session
.mercury()
.subscribe(uri.clone())
.inspect_err(|x| error!("remote update error: {}", x))
.and_then(|x| async move { Ok(x) })
.map(Result::unwrap) // guaranteed to be safe by `and_then` above
.listen_for("hm://remote/user/")
.map(UnboundedReceiverStream::new)
.flatten_stream()
.map(|response| -> Result<Frame, Error> {
.map(|response| -> Result<(String, Frame), Error> {
let uri_split: Vec<&str> = response.uri.split('/').collect();
let username = match uri_split.get(uri_split.len() - 2) {
Some(s) => s.to_string(),
None => String::new(),
};
let data = response.payload.first().ok_or(SpircError::NoData)?;
Ok(Frame::parse_from_bytes(data)?)
Ok((username, Frame::parse_from_bytes(data)?))
}),
);
@ -324,7 +324,14 @@ impl Spirc {
}),
);
let sender = session.mercury().sender(uri);
// Connect *after* all message listeners are registered
session.connect(credentials).await?;
let canonical_username = &session.username();
debug!("canonical_username: {}", canonical_username);
let sender_uri = format!("hm://remote/user/{}/", url_encode(canonical_username));
let sender = session.mercury().sender(sender_uri);
let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
@ -414,13 +421,17 @@ impl SpircTask {
tokio::select! {
remote_update = self.remote_update.next() => match remote_update {
Some(result) => match result {
Ok(update) => if let Err(e) = self.handle_remote_update(update) {
Ok((username, frame)) => {
if username != self.session.username() {
error!("could not dispatch remote update: frame was intended for {}", username);
} else if let Err(e) = self.handle_remote_update(frame) {
error!("could not dispatch remote update: {}", e);
}
},
Err(e) => error!("could not parse remote update: {}", e),
}
None => {
error!("subscription terminated");
error!("remote update selected, but none received");
break;
}
},
@ -513,7 +524,7 @@ impl SpircTask {
}
if self.sender.flush().await.is_err() {
warn!("Cannot flush spirc event sender.");
warn!("Cannot flush spirc event sender when done.");
}
}
@ -754,7 +765,7 @@ impl SpircTask {
}
fn handle_remote_update(&mut self, update: Frame) -> Result<(), Error> {
trace!("Received update frame: {:#?}", update,);
trace!("Received update frame: {:#?}", update);
// First see if this update was intended for us.
let device_id = &self.ident;

View file

@ -248,9 +248,6 @@ impl MercuryManager {
}
Err(MercuryError::Response(response).into())
} else if let PacketType::MercuryEvent = cmd {
self.lock(|inner| {
let mut found = false;
// TODO: This is just a workaround to make utf-8 encoded usernames work.
// A better solution would be to use an uri struct and urlencode it directly
// before sending while saving the subscription under its unencoded form.
@ -263,6 +260,9 @@ impl MercuryManager {
.collect::<Vec<String>>()
.join("/");
let mut found = false;
self.lock(|inner| {
inner.subscriptions.retain(|&(ref prefix, ref sub)| {
if encoded_uri.starts_with(prefix) {
found = true;
@ -275,6 +275,7 @@ impl MercuryManager {
true
}
});
});
if !found {
debug!("unknown subscription uri={}", &response.uri);
@ -283,7 +284,6 @@ impl MercuryManager {
} else {
Ok(())
}
})
} else if let Some(cb) = pending.callback {
cb.send(Ok(response)).map_err(|_| MercuryError::Channel)?;
Ok(())

View file

@ -46,6 +46,8 @@ pub enum SessionError {
AuthenticationError(#[from] AuthenticationError),
#[error("Cannot create session: {0}")]
IoError(#[from] io::Error),
#[error("Session is not connected")]
NotConnected,
#[error("packet {0} unknown")]
Packet(u8),
}
@ -55,6 +57,7 @@ impl From<SessionError> for Error {
match err {
SessionError::AuthenticationError(_) => Error::unauthenticated(err),
SessionError::IoError(_) => Error::unavailable(err),
SessionError::NotConnected => Error::unavailable(err),
SessionError::Packet(_) => Error::unimplemented(err),
}
}
@ -83,7 +86,7 @@ struct SessionInternal {
data: RwLock<SessionData>,
http_client: HttpClient,
tx_connection: mpsc::UnboundedSender<(u8, Vec<u8>)>,
tx_connection: OnceCell<mpsc::UnboundedSender<(u8, Vec<u8>)>>,
apresolver: OnceCell<ApResolver>,
audio_key: OnceCell<AudioKeyManager>,
@ -104,22 +107,17 @@ static SESSION_COUNTER: AtomicUsize = AtomicUsize::new(0);
pub struct Session(Arc<SessionInternal>);
impl Session {
pub async fn connect(
config: SessionConfig,
credentials: Credentials,
cache: Option<Cache>,
) -> Result<Session, Error> {
pub fn new(config: SessionConfig, cache: Option<Cache>) -> Self {
let http_client = HttpClient::new(config.proxy.as_ref());
let (sender_tx, sender_rx) = mpsc::unbounded_channel();
let session_id = SESSION_COUNTER.fetch_add(1, Ordering::AcqRel);
let session_id = SESSION_COUNTER.fetch_add(1, Ordering::AcqRel);
debug!("new Session[{}]", session_id);
let session = Session(Arc::new(SessionInternal {
Self(Arc::new(SessionInternal {
config,
data: RwLock::new(SessionData::default()),
http_client,
tx_connection: sender_tx,
tx_connection: OnceCell::new(),
cache: cache.map(Arc::new),
apresolver: OnceCell::new(),
audio_key: OnceCell::new(),
@ -129,27 +127,33 @@ impl Session {
token_provider: OnceCell::new(),
handle: tokio::runtime::Handle::current(),
session_id,
}));
}))
}
let ap = session.apresolver().resolve("accesspoint").await?;
pub async fn connect(&self, credentials: Credentials) -> Result<(), Error> {
let ap = self.apresolver().resolve("accesspoint").await?;
info!("Connecting to AP \"{}:{}\"", ap.0, ap.1);
let mut transport =
connection::connect(&ap.0, ap.1, session.config().proxy.as_ref()).await?;
let mut transport = connection::connect(&ap.0, ap.1, self.config().proxy.as_ref()).await?;
let reusable_credentials =
connection::authenticate(&mut transport, credentials, &session.config().device_id)
.await?;
connection::authenticate(&mut transport, credentials, &self.config().device_id).await?;
info!("Authenticated as \"{}\" !", reusable_credentials.username);
session.0.data.write().user_data.canonical_username = reusable_credentials.username.clone();
if let Some(cache) = session.cache() {
self.set_username(&reusable_credentials.username);
if let Some(cache) = self.cache() {
cache.save_credentials(&reusable_credentials);
}
let (tx_connection, rx_connection) = mpsc::unbounded_channel();
self.0
.tx_connection
.set(tx_connection)
.map_err(|_| SessionError::NotConnected)?;
let (sink, stream) = transport.split();
let sender_task = UnboundedReceiverStream::new(sender_rx)
let sender_task = UnboundedReceiverStream::new(rx_connection)
.map(Ok)
.forward(sink);
let receiver_task = DispatchTask(stream, session.weak());
let receiver_task = DispatchTask(stream, self.weak());
tokio::spawn(async move {
let result = future::try_join(sender_task, receiver_task).await;
@ -159,7 +163,7 @@ impl Session {
}
});
Ok(session)
Ok(())
}
pub fn apresolver(&self) -> &ApResolver {
@ -323,8 +327,10 @@ impl Session {
}
pub fn send_packet(&self, cmd: PacketType, data: Vec<u8>) -> Result<(), Error> {
self.0.tx_connection.send((cmd as u8, data))?;
Ok(())
match self.0.tx_connection.get() {
Some(tx) => Ok(tx.send((cmd as u8, data))?),
None => Err(SessionError::NotConnected.into()),
}
}
pub fn cache(&self) -> Option<&Arc<Cache>> {
@ -366,6 +372,10 @@ impl Session {
self.0.data.read().user_data.canonical_username.clone()
}
pub fn set_username(&self, username: &str) {
self.0.data.write().user_data.canonical_username = username.to_owned();
}
pub fn country(&self) -> String {
self.0.data.read().user_data.country.clone()
}

View file

@ -9,7 +9,7 @@ use std::{
time::{Duration, Instant},
};
use futures_util::{future, FutureExt, StreamExt};
use futures_util::StreamExt;
use log::{error, info, trace, warn};
use sha1::{Digest, Sha1};
use thiserror::Error;
@ -1562,7 +1562,9 @@ async fn main() {
let mut player_event_channel: Option<UnboundedReceiver<PlayerEvent>> = None;
let mut auto_connect_times: Vec<Instant> = vec![];
let mut discovery = None;
let mut connecting: Pin<Box<dyn future::FusedFuture<Output = _>>> = Box::pin(future::pending());
let mut connecting = false;
let session = Session::new(setup.session_config.clone(), setup.cache.clone());
if setup.enable_discovery {
let device_id = setup.session_config.device_id.clone();
@ -1582,15 +1584,8 @@ async fn main() {
}
if let Some(credentials) = setup.credentials {
last_credentials = Some(credentials.clone());
connecting = Box::pin(
Session::connect(
setup.session_config.clone(),
credentials,
setup.cache.clone(),
)
.fuse(),
);
last_credentials = Some(credentials);
connecting = true;
}
loop {
@ -1616,11 +1611,7 @@ async fn main() {
tokio::spawn(spirc_task);
}
connecting = Box::pin(Session::connect(
setup.session_config.clone(),
credentials,
setup.cache.clone(),
).fuse());
connecting = true;
},
None => {
error!("Discovery stopped unexpectedly");
@ -1628,8 +1619,7 @@ async fn main() {
}
}
},
session = &mut connecting, if !connecting.is_terminated() => match session {
Ok(session) => {
_ = async {}, if connecting && last_credentials.is_some() => {
let mixer_config = setup.mixer_config.clone();
let mixer = (setup.mixer)(mixer_config);
let player_config = setup.player_config.clone();
@ -1664,7 +1654,7 @@ async fn main() {
}
};
let (spirc_, spirc_task_) = match Spirc::new(connect_config, session, player, mixer) {
let (spirc_, spirc_task_) = match Spirc::new(connect_config, session.clone(), last_credentials.clone().unwrap(), player, mixer).await {
Ok((spirc_, spirc_task_)) => (spirc_, spirc_task_),
Err(e) => {
error!("could not initialize spirc: {}", e);
@ -1674,17 +1664,14 @@ async fn main() {
spirc = Some(spirc_);
spirc_task = Some(Box::pin(spirc_task_));
player_event_channel = Some(event_channel);
},
Err(e) => {
error!("Connection failed: {}", e);
exit(1);
}
connecting = false;
},
_ = async {
if let Some(task) = spirc_task.as_mut() {
task.await;
}
}, if spirc_task.is_some() => {
}, if spirc_task.is_some() && !connecting => {
spirc_task = None;
warn!("Spirc shut down unexpectedly");
@ -1695,14 +1682,9 @@ async fn main() {
};
match last_credentials.clone() {
Some(credentials) if !reconnect_exceeds_rate_limit() => {
Some(_) if !reconnect_exceeds_rate_limit() => {
auto_connect_times.push(Instant::now());
connecting = Box::pin(Session::connect(
setup.session_config.clone(),
credentials,
setup.cache.clone(),
).fuse());
connecting = true;
},
_ => {
error!("Spirc shut down too often. Not reconnecting automatically.");