[Core/connection] Refactor to async/await

This commit is contained in:
ashthespy 2021-01-23 22:21:42 +00:00
parent 47a1575c00
commit 94fc0a12da
5 changed files with 139 additions and 213 deletions

View file

@ -15,7 +15,7 @@ version = "0.1.3"
[dependencies] [dependencies]
base64 = "0.13" base64 = "0.13"
byteorder = "1.3" byteorder = "1.3"
bytes = "0.4" bytes = "0.5"
error-chain = { version = "0.12", default_features = false } error-chain = { version = "0.12", default_features = false }
futures = {version = "0.3",features =["unstable","bilock"]} futures = {version = "0.3",features =["unstable","bilock"]}
httparse = "1.3" httparse = "1.3"

View file

@ -45,7 +45,7 @@ impl Encoder<APCodecItem> for APCodec {
buf.reserve(3 + payload.len()); buf.reserve(3 + payload.len());
buf.put_u8(cmd); buf.put_u8(cmd);
buf.put_u16_be(payload.len() as u16); buf.put_u16(payload.len() as u16);
buf.extend_from_slice(&payload); buf.extend_from_slice(&payload);
self.encode_cipher.nonce_u32(self.encode_nonce); self.encode_cipher.nonce_u32(self.encode_nonce);

View file

@ -1,97 +1,71 @@
use byteorder::{BigEndian, ByteOrder, WriteBytesExt}; use super::codec::APCodec;
use crate::{
diffie_hellman::DHLocalKeys,
protocol,
protocol::keyexchange::{APResponseMessage, ClientHello, ClientResponsePlaintext},
util,
};
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use protobuf::{self, Message}; use protobuf::{self, Message};
use rand::thread_rng; use rand::thread_rng;
use sha1::Sha1; use sha1::Sha1;
use std::io::{self, Read}; use std::{io, marker::Unpin};
use std::marker::PhantomData; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// use tokio_codec::{Decoder, Framed};
// use tokio_io::io::{read_exact, write_all, ReadExact, Window, WriteAll};
// use tokio_io::{AsyncRead, AsyncWrite};
use super::codec::APCodec;
use crate::diffie_hellman::DHLocalKeys;
use crate::protocol;
use crate::protocol::keyexchange::{APResponseMessage, ClientHello, ClientResponsePlaintext};
use crate::util;
use futures::{
io::{ReadExact, Window, WriteAll},
Future,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::{Decoder, Framed}; use tokio_util::codec::{Decoder, Framed};
pub struct Handshake<'a, T> { // struct handshake {
keys: DHLocalKeys, // keys: DHLocalKeys,
state: HandshakeState<'a, T>, // connection: T,
} // accumulator: Vec<u8>,
// }
enum HandshakeState<'a, T> { pub async fn handshake<T: AsyncRead + AsyncWrite + Unpin>(
ClientHello(WriteAll<'a, T>), mut connection: T,
APResponse(RecvPacket<'a, T, APResponseMessage>), ) -> Result<Framed<T, APCodec>, io::Error> {
ClientResponse(Option<APCodec>, WriteAll<'a, T>),
}
pub fn handshake<'a, T: AsyncRead + AsyncWrite>(connection: T) -> Handshake<'a, T> {
let local_keys = DHLocalKeys::random(&mut thread_rng()); let local_keys = DHLocalKeys::random(&mut thread_rng());
let client_hello = client_hello(connection, local_keys.public_key()); // Send ClientHello
let client_hello: Vec<u8> = client_hello(local_keys.public_key()).await?;
connection.write_all(&client_hello).await?;
Handshake { // Receive APResponseMessage
keys: local_keys, let size = connection.read_u32().await?;
state: HandshakeState::ClientHello(client_hello), let mut buffer = Vec::with_capacity(size as usize - 4);
} let bytes = connection.read_buf(&mut buffer).await?;
let message = protobuf::parse_from_bytes::<APResponseMessage>(&buffer[..bytes])?;
let mut accumulator = client_hello.clone();
accumulator.extend_from_slice(&size.to_be_bytes());
accumulator.extend_from_slice(&buffer);
let remote_key = message
.get_challenge()
.get_login_crypto_challenge()
.get_diffie_hellman()
.get_gs()
.to_owned();
// Solve the challenge
let shared_secret = local_keys.shared_secret(&remote_key);
let (challenge, send_key, recv_key) = compute_keys(&shared_secret, &accumulator);
let codec = APCodec::new(&send_key, &recv_key);
let buffer: Vec<u8> = client_response(challenge).await?;
connection.write_all(&buffer).await?;
let framed = codec.framed(connection);
Ok(framed)
} }
impl<'a, T: AsyncRead + AsyncWrite> Future for Handshake<'a, T> { // async fn recv_packet<T: AsyncRead + Unpin, Message: protobuf::Message>(
type Output = Result<Framed<T, APCodec>, io::Error>; // mut connection: T,
// ) -> Result<(Message, &Vec<u8>), io::Error> {
// let size = connection.read_u32().await?;
// let mut buffer = Vec::with_capacity(size as usize - 4);
// let bytes = connection.read_buf(&mut buffer).await?;
// let proto = protobuf::parse_from_bytes(&buffer[..bytes])?;
// Ok(proto)
// }
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { async fn client_hello(gc: Vec<u8>) -> Result<Vec<u8>, io::Error> {
use self::HandshakeState::*;
loop {
self.state = match self.state {
ClientHello(ref mut write) => {
let (connection, accumulator) = ready!(write.poll());
let read = recv_packet(connection, accumulator);
APResponse(read)
}
APResponse(ref mut read) => {
let (connection, message, accumulator) = ready!(read.poll());
let remote_key = message
.get_challenge()
.get_login_crypto_challenge()
.get_diffie_hellman()
.get_gs()
.to_owned();
let shared_secret = self.keys.shared_secret(&remote_key);
let (challenge, send_key, recv_key) =
compute_keys(&shared_secret, &accumulator);
let codec = APCodec::new(&send_key, &recv_key);
let write = client_response(connection, challenge);
ClientResponse(Some(codec), write)
}
ClientResponse(ref mut codec, ref mut write) => {
let (connection, _) = ready!(write.poll());
let codec = codec.take().unwrap();
let framed = codec.framed(connection);
return Poll::Ready(Ok(framed));
}
}
}
}
}
fn client_hello<'a, T: AsyncWrite>(connection: T, gc: Vec<u8>) -> WriteAll<'a, T> {
let mut packet = ClientHello::new(); let mut packet = ClientHello::new();
packet packet
.mut_build_info() .mut_build_info()
@ -99,7 +73,7 @@ fn client_hello<'a, T: AsyncWrite>(connection: T, gc: Vec<u8>) -> WriteAll<'a, T
packet packet
.mut_build_info() .mut_build_info()
.set_platform(protocol::keyexchange::Platform::PLATFORM_LINUX_X86); .set_platform(protocol::keyexchange::Platform::PLATFORM_LINUX_X86);
packet.mut_build_info().set_version(109800078); packet.mut_build_info().set_version(109_800_078);
packet packet
.mut_cryptosuites_supported() .mut_cryptosuites_supported()
.push(protocol::keyexchange::Cryptosuite::CRYPTO_SUITE_SHANNON); .push(protocol::keyexchange::Cryptosuite::CRYPTO_SUITE_SHANNON);
@ -114,16 +88,15 @@ fn client_hello<'a, T: AsyncWrite>(connection: T, gc: Vec<u8>) -> WriteAll<'a, T
packet.set_client_nonce(util::rand_vec(&mut thread_rng(), 0x10)); packet.set_client_nonce(util::rand_vec(&mut thread_rng(), 0x10));
packet.set_padding(vec![0x1e]); packet.set_padding(vec![0x1e]);
let mut buffer = vec![0, 4];
let size = 2 + 4 + packet.compute_size(); let size = 2 + 4 + packet.compute_size();
buffer.write_u32::<BigEndian>(size).unwrap(); let mut buffer = Vec::with_capacity(size as usize);
packet.write_to_vec(&mut buffer).unwrap(); buffer.extend(&[0, 4]);
buffer.write_u32(size).await?;
// write_all(connection, buffer) buffer.extend(packet.write_to_bytes()?);
connection.write_all(&buffer) Ok(buffer)
} }
fn client_response<'a, T: AsyncWrite>(connection: T, challenge: Vec<u8>) -> WriteAll<'a, T> { async fn client_response(challenge: Vec<u8>) -> Result<Vec<u8>, io::Error> {
let mut packet = ClientResponsePlaintext::new(); let mut packet = ClientResponsePlaintext::new();
packet packet
.mut_login_crypto_response() .mut_login_crypto_response()
@ -132,73 +105,14 @@ fn client_response<'a, T: AsyncWrite>(connection: T, challenge: Vec<u8>) -> Writ
packet.mut_pow_response(); packet.mut_pow_response();
packet.mut_crypto_response(); packet.mut_crypto_response();
let mut buffer = vec![]; // let mut buffer = vec![];
let size = 4 + packet.compute_size(); let size = 4 + packet.compute_size();
buffer.write_u32::<BigEndian>(size).unwrap(); let mut buffer = Vec::with_capacity(size as usize);
packet.write_to_vec(&mut buffer).unwrap(); buffer.write_u32(size).await?;
// This seems to reallocate
// write_all(connection, buffer) // packet.write_to_vec(&mut buffer)?;
connection.write_all(&buffer) buffer.extend(packet.write_to_bytes()?);
} Ok(buffer)
enum RecvPacket<'a, T, M: Message> {
Header(ReadExact<'a, T>, PhantomData<M>),
Body(ReadExact<'a, T>, PhantomData<M>),
}
fn recv_packet<'a, T: AsyncRead, M>(connection: T, acc: Vec<u8>) -> RecvPacket<'a, T, M>
where
T: Read,
M: Message,
{
RecvPacket::Header(read_into_accumulator(connection, 4, acc), PhantomData)
}
impl<'a, T: AsyncRead, M> Future for RecvPacket<'a, T, M>
where
T: Read,
M: Message,
{
type Output = Result<(T, M, Vec<u8>), io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
use self::RecvPacket::*;
loop {
*self = match *self {
Header(ref mut read, _) => {
let (connection, header) = ready!(read.poll());
let size = BigEndian::read_u32(header.as_ref()) as usize;
let acc = header.into_inner();
let read = read_into_accumulator(connection, size - 4, acc);
RecvPacket::Body(read, PhantomData)
}
Body(ref mut read, _) => {
let (connection, data) = ready!(read.poll());
let message = protobuf::parse_from_bytes(data.as_ref()).unwrap();
let acc = data.into_inner();
return Poll::Ready(Ok((connection, message, acc)));
}
}
}
}
}
fn read_into_accumulator<'a, T: AsyncRead>(
connection: T,
size: usize,
mut acc: Vec<u8>,
) -> ReadExact<'a, T> {
let offset = acc.len();
acc.resize(offset + size, 0);
let mut window = Window::new(acc);
window.set_start(offset);
// read_exact(connection, window)
connection.read_exact(window)
} }
fn compute_keys(shared_secret: &[u8], packets: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) { fn compute_keys(shared_secret: &[u8], packets: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {

View file

@ -1,48 +1,39 @@
mod codec; mod codec;
mod handshake; mod handshake;
pub use self::codec::APCodec; pub use self::{codec::APCodec, handshake::handshake};
pub use self::handshake::handshake; use crate::{authentication::Credentials, version};
use tokio::net::TcpStream;
use futures::{AsyncRead, AsyncWrite, Future, Sink, SinkExt, Stream, StreamExt}; use futures::{SinkExt, StreamExt};
use protobuf::{self, Message}; use protobuf::{self, Message};
use std::io; use std::{io, net::ToSocketAddrs};
use std::net::ToSocketAddrs; use tokio::net::TcpStream;
use tokio_util::codec::Framed; use tokio_util::codec::Framed;
// use futures::compat::{AsyncWrite01CompatExt, AsyncRead01CompatExt};
// use tokio_util::compat::{self, Tokio02AsyncReadCompatExt, Tokio02AsyncWriteCompatExt};
// use tokio_codec::Framed;
// use tokio_core::net::TcpStream;
// use tokio_core::reactor::Handle;
use url::Url; use url::Url;
use crate::authentication::Credentials; // use crate::proxytunnel;
use crate::version;
use crate::proxytunnel;
pub type Transport = Framed<TcpStream, APCodec>; pub type Transport = Framed<TcpStream, APCodec>;
pub async fn connect(addr: String, proxy: &Option<Url>) -> Result<Transport, io::Error> { pub async fn connect(addr: String, proxy: &Option<Url>) -> Result<Transport, io::Error> {
let (addr, connect_url): (_, Option<String>) = match *proxy { let (addr, connect_url): (_, Option<String>) = match *proxy {
Some(ref url) => { Some(ref url) => {
unimplemented!() info!("Using proxy \"{}\"", url);
// info!("Using proxy \"{}\"", url);
// let mut iter = url.to_socket_addrs()?;
// let mut iter = url.to_socket_addrs()?; let socket_addr = iter.next().ok_or_else(|| {
// let socket_addr = iter.next().ok_or(io::Error::new( io::Error::new(
// io::ErrorKind::NotFound, io::ErrorKind::NotFound,
// "Can't resolve proxy server address", "Can't resolve proxy server address",
// ))?; )
// (socket_addr, Some(addr)) })?;
(socket_addr, Some(addr))
} }
None => { None => {
let mut iter = addr.to_socket_addrs()?; let mut iter = addr.to_socket_addrs()?;
let socket_addr = iter.next().ok_or(io::Error::new( let socket_addr = iter.next().ok_or_else(|| {
io::ErrorKind::NotFound, io::Error::new(io::ErrorKind::NotFound, "Can't resolve server address")
"Can't resolve server address", })?;
))?;
(socket_addr, None) (socket_addr, None)
} }
}; };
@ -54,8 +45,7 @@ pub async fn connect(addr: String, proxy: &Option<Url>) -> Result<Transport, io:
// let connection = handshake(connection).await?; // let connection = handshake(connection).await?;
// Ok(connection) // Ok(connection)
} else { } else {
let connection = handshake(connection).await?; handshake(connection).await
Ok(connection)
} }
} }
@ -64,8 +54,10 @@ pub async fn authenticate(
credentials: Credentials, credentials: Credentials,
device_id: String, device_id: String,
) -> Result<(Transport, Credentials), io::Error> { ) -> Result<(Transport, Credentials), io::Error> {
use crate::protocol::authentication::{APWelcome, ClientResponseEncrypted, CpuFamily, Os}; use crate::protocol::{
use crate::protocol::keyexchange::APLoginFailed; authentication::{APWelcome, ClientResponseEncrypted, CpuFamily, Os},
keyexchange::APLoginFailed,
};
let mut packet = ClientResponseEncrypted::new(); let mut packet = ClientResponseEncrypted::new();
packet packet
@ -94,13 +86,11 @@ pub async fn authenticate(
let cmd: u8 = 0xab; let cmd: u8 = 0xab;
let data = packet.write_to_bytes().unwrap(); let data = packet.write_to_bytes().unwrap();
transport.send((cmd, data)).await; transport.send((cmd, data)).await?;
let packet = transport.next().await; let packet = transport.next().await;
// let (packet, transport) = transport
// .into_future() // TODO: Don't panic?
// .map_err(|(err, _stream)| err)
// .await?;
match packet { match packet {
Some(Ok((0xac, data))) => { Some(Ok((0xac, data))) => {
let welcome_data: APWelcome = protobuf::parse_from_bytes(data.as_ref()).unwrap(); let welcome_data: APWelcome = protobuf::parse_from_bytes(data.as_ref()).unwrap();

View file

@ -1,23 +1,45 @@
use env_logger; use futures::future::TryFutureExt;
use std::env; use librespot_core::*;
use tokio::runtime::Runtime; use tokio::runtime;
use librespot_core::{apresolve::apresolve_or_fallback, connection}; #[cfg(test)]
mod tests {
// TODO: Rewrite this into an actual test instead of this wonder use super::*;
fn main() { // Test AP Resolve
env_logger::init(); use apresolve::apresolve_or_fallback;
let mut rt = Runtime::new().unwrap(); #[test]
fn test_ap_resolve() {
let args: Vec<_> = env::args().collect(); let mut rt = runtime::Runtime::new().unwrap();
if args.len() != 4 { let ap = rt.block_on(apresolve_or_fallback(&None, &Some(80)));
println!("Usage: {} USERNAME PASSWORD PLAYLIST", args[0]); println!("AP: {:?}", ap);
} }
// let username = args[1].to_owned();
// let password = args[2].to_owned();
let ap = rt.block_on(apresolve_or_fallback(&None, &Some(80))); // Test connect
use authentication::Credentials;
use config::SessionConfig;
use connection;
#[test]
fn test_connection() {
println!("Running connection test");
let mut rt = runtime::Runtime::new().unwrap();
let access_point_addr = rt.block_on(apresolve_or_fallback(&None, &None)).unwrap();
let credentials = Credentials::with_password(String::from("test"), String::from("test"));
let session_config = SessionConfig::default();
let proxy = None;
println!("AP: {:?}", ap); println!("Connecting to AP \"{}\"", access_point_addr);
let connection = rt.block_on(connection::connect(&None)); let connection = connection::connect(access_point_addr, &proxy);
let device_id = session_config.device_id.clone();
let authentication = connection.and_then(move |connection| {
connection::authenticate(connection, credentials, device_id)
});
match rt.block_on(authentication) {
Ok((_transport, reusable_credentials)) => {
println!("Authenticated as \"{}\" !", reusable_credentials.username)
}
// TODO assert that we get BadCredentials once we don't panic
Err(e) => println!("ConnectError: {:?}", e),
}
}
} }