[Core/connection] Refactor to async/await

This commit is contained in:
ashthespy 2021-01-23 22:21:42 +00:00
parent 47a1575c00
commit 94fc0a12da
5 changed files with 139 additions and 213 deletions

View file

@ -15,7 +15,7 @@ version = "0.1.3"
[dependencies]
base64 = "0.13"
byteorder = "1.3"
bytes = "0.4"
bytes = "0.5"
error-chain = { version = "0.12", default_features = false }
futures = {version = "0.3",features =["unstable","bilock"]}
httparse = "1.3"

View file

@ -45,7 +45,7 @@ impl Encoder<APCodecItem> for APCodec {
buf.reserve(3 + payload.len());
buf.put_u8(cmd);
buf.put_u16_be(payload.len() as u16);
buf.put_u16(payload.len() as u16);
buf.extend_from_slice(&payload);
self.encode_cipher.nonce_u32(self.encode_nonce);

View file

@ -1,97 +1,71 @@
use byteorder::{BigEndian, ByteOrder, WriteBytesExt};
use super::codec::APCodec;
use crate::{
diffie_hellman::DHLocalKeys,
protocol,
protocol::keyexchange::{APResponseMessage, ClientHello, ClientResponsePlaintext},
util,
};
use hmac::{Hmac, Mac};
use protobuf::{self, Message};
use rand::thread_rng;
use sha1::Sha1;
use std::io::{self, Read};
use std::marker::PhantomData;
// use tokio_codec::{Decoder, Framed};
// use tokio_io::io::{read_exact, write_all, ReadExact, Window, WriteAll};
// use tokio_io::{AsyncRead, AsyncWrite};
use super::codec::APCodec;
use crate::diffie_hellman::DHLocalKeys;
use crate::protocol;
use crate::protocol::keyexchange::{APResponseMessage, ClientHello, ClientResponsePlaintext};
use crate::util;
use futures::{
io::{ReadExact, Window, WriteAll},
Future,
};
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use std::{io, marker::Unpin};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio_util::codec::{Decoder, Framed};
pub struct Handshake<'a, T> {
keys: DHLocalKeys,
state: HandshakeState<'a, T>,
}
// struct handshake {
// keys: DHLocalKeys,
// connection: T,
// accumulator: Vec<u8>,
// }
enum HandshakeState<'a, T> {
ClientHello(WriteAll<'a, T>),
APResponse(RecvPacket<'a, T, APResponseMessage>),
ClientResponse(Option<APCodec>, WriteAll<'a, T>),
}
pub fn handshake<'a, T: AsyncRead + AsyncWrite>(connection: T) -> Handshake<'a, T> {
pub async fn handshake<T: AsyncRead + AsyncWrite + Unpin>(
mut connection: T,
) -> Result<Framed<T, APCodec>, io::Error> {
let local_keys = DHLocalKeys::random(&mut thread_rng());
let client_hello = client_hello(connection, local_keys.public_key());
// Send ClientHello
let client_hello: Vec<u8> = client_hello(local_keys.public_key()).await?;
connection.write_all(&client_hello).await?;
Handshake {
keys: local_keys,
state: HandshakeState::ClientHello(client_hello),
}
// Receive APResponseMessage
let size = connection.read_u32().await?;
let mut buffer = Vec::with_capacity(size as usize - 4);
let bytes = connection.read_buf(&mut buffer).await?;
let message = protobuf::parse_from_bytes::<APResponseMessage>(&buffer[..bytes])?;
let mut accumulator = client_hello.clone();
accumulator.extend_from_slice(&size.to_be_bytes());
accumulator.extend_from_slice(&buffer);
let remote_key = message
.get_challenge()
.get_login_crypto_challenge()
.get_diffie_hellman()
.get_gs()
.to_owned();
// Solve the challenge
let shared_secret = local_keys.shared_secret(&remote_key);
let (challenge, send_key, recv_key) = compute_keys(&shared_secret, &accumulator);
let codec = APCodec::new(&send_key, &recv_key);
let buffer: Vec<u8> = client_response(challenge).await?;
connection.write_all(&buffer).await?;
let framed = codec.framed(connection);
Ok(framed)
}
impl<'a, T: AsyncRead + AsyncWrite> Future for Handshake<'a, T> {
type Output = Result<Framed<T, APCodec>, io::Error>;
// async fn recv_packet<T: AsyncRead + Unpin, Message: protobuf::Message>(
// mut connection: T,
// ) -> Result<(Message, &Vec<u8>), io::Error> {
// let size = connection.read_u32().await?;
// let mut buffer = Vec::with_capacity(size as usize - 4);
// let bytes = connection.read_buf(&mut buffer).await?;
// let proto = protobuf::parse_from_bytes(&buffer[..bytes])?;
// Ok(proto)
// }
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
use self::HandshakeState::*;
loop {
self.state = match self.state {
ClientHello(ref mut write) => {
let (connection, accumulator) = ready!(write.poll());
let read = recv_packet(connection, accumulator);
APResponse(read)
}
APResponse(ref mut read) => {
let (connection, message, accumulator) = ready!(read.poll());
let remote_key = message
.get_challenge()
.get_login_crypto_challenge()
.get_diffie_hellman()
.get_gs()
.to_owned();
let shared_secret = self.keys.shared_secret(&remote_key);
let (challenge, send_key, recv_key) =
compute_keys(&shared_secret, &accumulator);
let codec = APCodec::new(&send_key, &recv_key);
let write = client_response(connection, challenge);
ClientResponse(Some(codec), write)
}
ClientResponse(ref mut codec, ref mut write) => {
let (connection, _) = ready!(write.poll());
let codec = codec.take().unwrap();
let framed = codec.framed(connection);
return Poll::Ready(Ok(framed));
}
}
}
}
}
fn client_hello<'a, T: AsyncWrite>(connection: T, gc: Vec<u8>) -> WriteAll<'a, T> {
async fn client_hello(gc: Vec<u8>) -> Result<Vec<u8>, io::Error> {
let mut packet = ClientHello::new();
packet
.mut_build_info()
@ -99,7 +73,7 @@ fn client_hello<'a, T: AsyncWrite>(connection: T, gc: Vec<u8>) -> WriteAll<'a, T
packet
.mut_build_info()
.set_platform(protocol::keyexchange::Platform::PLATFORM_LINUX_X86);
packet.mut_build_info().set_version(109800078);
packet.mut_build_info().set_version(109_800_078);
packet
.mut_cryptosuites_supported()
.push(protocol::keyexchange::Cryptosuite::CRYPTO_SUITE_SHANNON);
@ -114,16 +88,15 @@ fn client_hello<'a, T: AsyncWrite>(connection: T, gc: Vec<u8>) -> WriteAll<'a, T
packet.set_client_nonce(util::rand_vec(&mut thread_rng(), 0x10));
packet.set_padding(vec![0x1e]);
let mut buffer = vec![0, 4];
let size = 2 + 4 + packet.compute_size();
buffer.write_u32::<BigEndian>(size).unwrap();
packet.write_to_vec(&mut buffer).unwrap();
// write_all(connection, buffer)
connection.write_all(&buffer)
let mut buffer = Vec::with_capacity(size as usize);
buffer.extend(&[0, 4]);
buffer.write_u32(size).await?;
buffer.extend(packet.write_to_bytes()?);
Ok(buffer)
}
fn client_response<'a, T: AsyncWrite>(connection: T, challenge: Vec<u8>) -> WriteAll<'a, T> {
async fn client_response(challenge: Vec<u8>) -> Result<Vec<u8>, io::Error> {
let mut packet = ClientResponsePlaintext::new();
packet
.mut_login_crypto_response()
@ -132,73 +105,14 @@ fn client_response<'a, T: AsyncWrite>(connection: T, challenge: Vec<u8>) -> Writ
packet.mut_pow_response();
packet.mut_crypto_response();
let mut buffer = vec![];
// let mut buffer = vec![];
let size = 4 + packet.compute_size();
buffer.write_u32::<BigEndian>(size).unwrap();
packet.write_to_vec(&mut buffer).unwrap();
// write_all(connection, buffer)
connection.write_all(&buffer)
}
enum RecvPacket<'a, T, M: Message> {
Header(ReadExact<'a, T>, PhantomData<M>),
Body(ReadExact<'a, T>, PhantomData<M>),
}
fn recv_packet<'a, T: AsyncRead, M>(connection: T, acc: Vec<u8>) -> RecvPacket<'a, T, M>
where
T: Read,
M: Message,
{
RecvPacket::Header(read_into_accumulator(connection, 4, acc), PhantomData)
}
impl<'a, T: AsyncRead, M> Future for RecvPacket<'a, T, M>
where
T: Read,
M: Message,
{
type Output = Result<(T, M, Vec<u8>), io::Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
use self::RecvPacket::*;
loop {
*self = match *self {
Header(ref mut read, _) => {
let (connection, header) = ready!(read.poll());
let size = BigEndian::read_u32(header.as_ref()) as usize;
let acc = header.into_inner();
let read = read_into_accumulator(connection, size - 4, acc);
RecvPacket::Body(read, PhantomData)
}
Body(ref mut read, _) => {
let (connection, data) = ready!(read.poll());
let message = protobuf::parse_from_bytes(data.as_ref()).unwrap();
let acc = data.into_inner();
return Poll::Ready(Ok((connection, message, acc)));
}
}
}
}
}
fn read_into_accumulator<'a, T: AsyncRead>(
connection: T,
size: usize,
mut acc: Vec<u8>,
) -> ReadExact<'a, T> {
let offset = acc.len();
acc.resize(offset + size, 0);
let mut window = Window::new(acc);
window.set_start(offset);
// read_exact(connection, window)
connection.read_exact(window)
let mut buffer = Vec::with_capacity(size as usize);
buffer.write_u32(size).await?;
// This seems to reallocate
// packet.write_to_vec(&mut buffer)?;
buffer.extend(packet.write_to_bytes()?);
Ok(buffer)
}
fn compute_keys(shared_secret: &[u8], packets: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {

View file

@ -1,48 +1,39 @@
mod codec;
mod handshake;
pub use self::codec::APCodec;
pub use self::handshake::handshake;
use tokio::net::TcpStream;
pub use self::{codec::APCodec, handshake::handshake};
use crate::{authentication::Credentials, version};
use futures::{AsyncRead, AsyncWrite, Future, Sink, SinkExt, Stream, StreamExt};
use futures::{SinkExt, StreamExt};
use protobuf::{self, Message};
use std::io;
use std::net::ToSocketAddrs;
use std::{io, net::ToSocketAddrs};
use tokio::net::TcpStream;
use tokio_util::codec::Framed;
// use futures::compat::{AsyncWrite01CompatExt, AsyncRead01CompatExt};
// use tokio_util::compat::{self, Tokio02AsyncReadCompatExt, Tokio02AsyncWriteCompatExt};
// use tokio_codec::Framed;
// use tokio_core::net::TcpStream;
// use tokio_core::reactor::Handle;
use url::Url;
use crate::authentication::Credentials;
use crate::version;
use crate::proxytunnel;
// use crate::proxytunnel;
pub type Transport = Framed<TcpStream, APCodec>;
pub async fn connect(addr: String, proxy: &Option<Url>) -> Result<Transport, io::Error> {
let (addr, connect_url): (_, Option<String>) = match *proxy {
Some(ref url) => {
unimplemented!()
// info!("Using proxy \"{}\"", url);
//
// let mut iter = url.to_socket_addrs()?;
// let socket_addr = iter.next().ok_or(io::Error::new(
// io::ErrorKind::NotFound,
// "Can't resolve proxy server address",
// ))?;
// (socket_addr, Some(addr))
info!("Using proxy \"{}\"", url);
let mut iter = url.to_socket_addrs()?;
let socket_addr = iter.next().ok_or_else(|| {
io::Error::new(
io::ErrorKind::NotFound,
"Can't resolve proxy server address",
)
})?;
(socket_addr, Some(addr))
}
None => {
let mut iter = addr.to_socket_addrs()?;
let socket_addr = iter.next().ok_or(io::Error::new(
io::ErrorKind::NotFound,
"Can't resolve server address",
))?;
let socket_addr = iter.next().ok_or_else(|| {
io::Error::new(io::ErrorKind::NotFound, "Can't resolve server address")
})?;
(socket_addr, None)
}
};
@ -54,8 +45,7 @@ pub async fn connect(addr: String, proxy: &Option<Url>) -> Result<Transport, io:
// let connection = handshake(connection).await?;
// Ok(connection)
} else {
let connection = handshake(connection).await?;
Ok(connection)
handshake(connection).await
}
}
@ -64,8 +54,10 @@ pub async fn authenticate(
credentials: Credentials,
device_id: String,
) -> Result<(Transport, Credentials), io::Error> {
use crate::protocol::authentication::{APWelcome, ClientResponseEncrypted, CpuFamily, Os};
use crate::protocol::keyexchange::APLoginFailed;
use crate::protocol::{
authentication::{APWelcome, ClientResponseEncrypted, CpuFamily, Os},
keyexchange::APLoginFailed,
};
let mut packet = ClientResponseEncrypted::new();
packet
@ -94,13 +86,11 @@ pub async fn authenticate(
let cmd: u8 = 0xab;
let data = packet.write_to_bytes().unwrap();
transport.send((cmd, data)).await;
transport.send((cmd, data)).await?;
let packet = transport.next().await;
// let (packet, transport) = transport
// .into_future()
// .map_err(|(err, _stream)| err)
// .await?;
// TODO: Don't panic?
match packet {
Some(Ok((0xac, data))) => {
let welcome_data: APWelcome = protobuf::parse_from_bytes(data.as_ref()).unwrap();

View file

@ -1,23 +1,45 @@
use env_logger;
use std::env;
use tokio::runtime::Runtime;
use futures::future::TryFutureExt;
use librespot_core::*;
use tokio::runtime;
use librespot_core::{apresolve::apresolve_or_fallback, connection};
// TODO: Rewrite this into an actual test instead of this wonder
fn main() {
env_logger::init();
let mut rt = Runtime::new().unwrap();
let args: Vec<_> = env::args().collect();
if args.len() != 4 {
println!("Usage: {} USERNAME PASSWORD PLAYLIST", args[0]);
#[cfg(test)]
mod tests {
use super::*;
// Test AP Resolve
use apresolve::apresolve_or_fallback;
#[test]
fn test_ap_resolve() {
let mut rt = runtime::Runtime::new().unwrap();
let ap = rt.block_on(apresolve_or_fallback(&None, &Some(80)));
println!("AP: {:?}", ap);
}
// let username = args[1].to_owned();
// let password = args[2].to_owned();
let ap = rt.block_on(apresolve_or_fallback(&None, &Some(80)));
// Test connect
use authentication::Credentials;
use config::SessionConfig;
use connection;
#[test]
fn test_connection() {
println!("Running connection test");
let mut rt = runtime::Runtime::new().unwrap();
let access_point_addr = rt.block_on(apresolve_or_fallback(&None, &None)).unwrap();
let credentials = Credentials::with_password(String::from("test"), String::from("test"));
let session_config = SessionConfig::default();
let proxy = None;
println!("AP: {:?}", ap);
let connection = rt.block_on(connection::connect(&None));
println!("Connecting to AP \"{}\"", access_point_addr);
let connection = connection::connect(access_point_addr, &proxy);
let device_id = session_config.device_id.clone();
let authentication = connection.and_then(move |connection| {
connection::authenticate(connection, credentials, device_id)
});
match rt.block_on(authentication) {
Ok((_transport, reusable_credentials)) => {
println!("Authenticated as \"{}\" !", reusable_credentials.username)
}
// TODO assert that we get BadCredentials once we don't panic
Err(e) => println!("ConnectError: {:?}", e),
}
}
}