From 4d277e5b757f2bd4f8512f43316a28402410075d Mon Sep 17 00:00:00 2001 From: Paul Lietar Date: Mon, 9 May 2016 12:22:51 +0100 Subject: [PATCH] stream: refactor into a reactor pattern. --- src/album_cover.rs | 89 +++++++++++++---- src/audio_file.rs | 198 +++++++++++++++++++++----------------- src/audio_file2.rs | 136 ++++++++++++++++++++++++++ src/audio_key.rs | 5 +- src/connection.rs | 4 - src/lib.in.rs | 3 +- src/mercury.rs | 5 +- src/session.rs | 30 +++--- src/stream.rs | 233 ++++++++++++++++++++++++--------------------- 9 files changed, 465 insertions(+), 238 deletions(-) create mode 100644 src/audio_file2.rs diff --git a/src/album_cover.rs b/src/album_cover.rs index 9b5c1b9f..d453ccb5 100644 --- a/src/album_cover.rs +++ b/src/album_cover.rs @@ -4,24 +4,75 @@ use byteorder::{WriteBytesExt, BigEndian}; use session::Session; use util::FileId; -use stream::StreamEvent; +use stream; -pub fn get_album_cover(session: &Session, file_id: FileId) - -> eventual::Future, ()> { - - let (channel_id, rx) = session.allocate_stream(); - - let mut req: Vec = Vec::new(); - req.write_u16::(channel_id).unwrap(); - req.write_u16::(0).unwrap(); - req.write(&file_id.0).unwrap(); - session.send_packet(0x19, &req).unwrap(); - - rx.map_err(|_| ()) - .reduce(Vec::new(), |mut current, event| { - if let StreamEvent::Data(data) = event { - current.extend_from_slice(&data) - } - current - }) +pub struct AlbumCover { + file_id: FileId, + data: Vec, + cover_tx: eventual::Complete, ()>, +} + +impl stream::Handler for AlbumCover { + fn on_create(self, channel_id: stream::ChannelId, session: &Session) -> stream::Response { + let mut req: Vec = Vec::new(); + req.write_u16::(channel_id).unwrap(); + req.write_u16::(0).unwrap(); + req.write(&self.file_id.0).unwrap(); + session.send_packet(0x19, &req).unwrap(); + + stream::Response::Continue(self) + } + + fn on_header(self, _header_id: u8, _header_data: &[u8], _session: &Session) -> stream::Response { + stream::Response::Continue(self) + } + + fn on_data(mut self, data: &[u8], _session: &Session) -> stream::Response { + self.data.extend_from_slice(data); + stream::Response::Continue(self) + } + + fn on_close(self, _session: &Session) -> stream::Response { + // End of chunk, request a new one + self.cover_tx.complete(self.data); + stream::Response::Close + } + + fn on_error(self, _session: &Session) -> stream::Response { + self.cover_tx.fail(()); + stream::Response::Close + } + + fn box_on_create(self: Box, channel_id: stream::ChannelId, session: &Session) -> stream::Response> { + self.on_create(channel_id, session).boxed() + } + + fn box_on_header(self: Box, header_id: u8, header_data: &[u8], session: &Session) -> stream::Response> { + self.on_header(header_id, header_data, session).boxed() + } + + fn box_on_data(self: Box, data: &[u8], session: &Session) -> stream::Response> { + self.on_data(data, session).boxed() + } + + fn box_on_error(self: Box, session: &Session) -> stream::Response> { + self.on_error(session).boxed() + } + + fn box_on_close(self: Box, session: &Session) -> stream::Response> { + self.on_close(session).boxed() + } +} + +impl AlbumCover { + pub fn get(file_id: FileId, session: &Session) -> eventual::Future, ()> { + let (tx, rx) = eventual::Future::pair(); + session.stream(Box::new(AlbumCover { + file_id: file_id, + data: Vec::new(), + cover_tx: tx, + })); + + rx + } } diff --git a/src/audio_file.rs b/src/audio_file.rs index 370af1b5..1b4aa81a 100644 --- a/src/audio_file.rs +++ b/src/audio_file.rs @@ -4,14 +4,13 @@ use eventual; use std::cmp::min; use std::sync::{Arc, Condvar, Mutex}; use std::sync::mpsc::{self, TryRecvError}; -use std::thread; use std::fs; use std::io::{self, Read, Write, Seek, SeekFrom}; use tempfile::NamedTempFile; use util::{FileId, IgnoreExt}; use session::Session; -use stream::StreamEvent; +use audio_file2; const CHUNK_SIZE: usize = 0x20000; @@ -24,127 +23,152 @@ pub struct AudioFile { shared: Arc, } -struct AudioFileShared { - file_id: FileId, +struct AudioFileInternal { + partial_tx: Option>, + complete_tx: eventual::Complete, + write_file: NamedTempFile, + seek_rx: mpsc::Receiver, + shared: Arc, chunk_count: usize, +} + +struct AudioFileShared { cond: Condvar, bitmap: Mutex, } impl AudioFile { pub fn new(session: &Session, file_id: FileId) - -> (AudioFile, eventual::Future) { - - let size = session.stream(file_id, 0, 1) - .iter() - .filter_map(|event| { - match event { - StreamEvent::Header(id, ref data) if id == 0x3 => { - Some(BigEndian::read_u32(data) as usize * 4) - } - _ => None, - } - }) - .next() - .unwrap(); - - let chunk_count = (size + CHUNK_SIZE - 1) / CHUNK_SIZE; + -> (eventual::Future, eventual::Future) { let shared = Arc::new(AudioFileShared { - file_id: file_id, - chunk_count: chunk_count, cond: Condvar::new(), - bitmap: Mutex::new(BitSet::with_capacity(chunk_count)), + bitmap: Mutex::new(BitSet::new()), }); - let write_file = NamedTempFile::new().unwrap(); - write_file.set_len(size as u64).unwrap(); - let read_file = write_file.reopen().unwrap(); - let (seek_tx, seek_rx) = mpsc::channel(); + let (partial_tx, partial_rx) = eventual::Future::pair(); let (complete_tx, complete_rx) = eventual::Future::pair(); - { - let shared = shared.clone(); - let session = session.clone(); - thread::spawn(move || AudioFile::fetch(&session, shared, write_file, seek_rx, complete_tx)); + let internal = AudioFileInternal { + shared: shared.clone(), + write_file: NamedTempFile::new().unwrap(), + seek_rx: seek_rx, + partial_tx: Some(partial_tx), + complete_tx: complete_tx, + chunk_count: 0, + }; + + audio_file2::AudioFile::new(file_id, 0, internal, session); + + let file_rx = partial_rx.map(|read_file| { + AudioFile { + read_file: read_file, + + position: 0, + seek: seek_tx, + + shared: shared, + } + }); + + (file_rx, complete_rx) + } +} + +impl audio_file2::Handler for AudioFileInternal { + fn on_header(mut self, header_id: u8, header_data: &[u8], _session: &Session) -> audio_file2::Response { + if header_id == 0x3 { + if let Some(tx) = self.partial_tx.take() { + let size = BigEndian::read_u32(header_data) as usize * 4; + self.write_file.set_len(size as u64).unwrap(); + let read_file = self.write_file.reopen().unwrap(); + + self.chunk_count = (size + CHUNK_SIZE - 1) / CHUNK_SIZE; + self.shared.bitmap.lock().unwrap().reserve_len(self.chunk_count); + + tx.complete(read_file) + } } - (AudioFile { - read_file: read_file, - - position: 0, - seek: seek_tx, - - shared: shared, - }, complete_rx) + audio_file2::Response::Continue(self) } - fn fetch(session: &Session, - shared: Arc, - mut write_file: NamedTempFile, - seek_rx: mpsc::Receiver, - complete_tx: eventual::Complete) { - let mut index = 0; + fn on_data(mut self, offset: usize, data: &[u8], _session: &Session) -> audio_file2::Response { + self.write_file.seek(SeekFrom::Start(offset as u64)).unwrap(); + self.write_file.write_all(&data).unwrap(); - loop { - match seek_rx.try_recv() { - Ok(position) => { - index = position as usize / CHUNK_SIZE; - } - Err(TryRecvError::Disconnected) => break, - Err(TryRecvError::Empty) => (), - } + // We've crossed a chunk boundary + // Mark the previous one as complete in the bitmap and notify the reader + let seek = if (offset + data.len()) % CHUNK_SIZE < data.len() { + let mut index = offset / CHUNK_SIZE; + let mut bitmap = self.shared.bitmap.lock().unwrap(); + bitmap.insert(index); + self.shared.cond.notify_all(); - let bitmap = shared.bitmap.lock().unwrap(); - if bitmap.len() >= shared.chunk_count { + println!("{}/{} {:?}", bitmap.len(), self.chunk_count, *bitmap); + + // If all blocks are complete when can stop + if bitmap.len() >= self.chunk_count { + println!("All good"); drop(bitmap); - write_file.seek(SeekFrom::Start(0)).unwrap(); - complete_tx.complete(write_file); - break; + self.write_file.seek(SeekFrom::Start(0)).unwrap(); + self.complete_tx.complete(self.write_file); + return audio_file2::Response::Close; } + // Find the next undownloaded block + index = (index + 1) % self.chunk_count; while bitmap.contains(index) { - index = (index + 1) % shared.chunk_count; + index = (index + 1) % self.chunk_count; } - drop(bitmap); - AudioFile::fetch_chunk(session, &shared, &mut write_file, index); + Some(index) + } else { + None + }; + + match self.seek_rx.try_recv() { + Ok(seek_offset) => audio_file2::Response::Seek(self, seek_offset as usize / CHUNK_SIZE * CHUNK_SIZE), + Err(TryRecvError::Disconnected) => audio_file2::Response::Close, + Err(TryRecvError::Empty) => match seek { + Some(index) => audio_file2::Response::Seek(self, index * CHUNK_SIZE), + None => audio_file2::Response::Continue(self), + }, } } - fn fetch_chunk(session: &Session, - shared: &Arc, - write_file: &mut NamedTempFile, - index: usize) { + fn on_eof(mut self, _session: &Session) -> audio_file2::Response { + let index = { + let mut index = self.chunk_count - 1; + let mut bitmap = self.shared.bitmap.lock().unwrap(); + bitmap.insert(index); + self.shared.cond.notify_all(); - let rx = session.stream(shared.file_id, - (index * CHUNK_SIZE / 4) as u32, - (CHUNK_SIZE / 4) as u32); + println!("{:?}", *bitmap); - debug!("Fetch chunk {} / {}", index + 1, shared.chunk_count); + println!("{} {}", bitmap.len(), self.chunk_count); - write_file.seek(SeekFrom::Start((index * CHUNK_SIZE) as u64)).unwrap(); - - let mut size = 0usize; - for event in rx.iter() { - match event { - StreamEvent::Header(..) => (), - StreamEvent::Data(data) => { - write_file.write_all(&data).unwrap(); - - size += data.len(); - if size >= CHUNK_SIZE { - break; - } - } + // If all blocks are complete when can stop + if bitmap.len() >= self.chunk_count { + drop(bitmap); + self.write_file.seek(SeekFrom::Start(0)).unwrap(); + self.complete_tx.complete(self.write_file); + return audio_file2::Response::Close; } - } - let mut bitmap = shared.bitmap.lock().unwrap(); - bitmap.insert(index as usize); + // Find the next undownloaded block + index = (index + 1) % self.chunk_count; + while bitmap.contains(index) { + index = (index + 1) % self.chunk_count; + } + index + }; - shared.cond.notify_all(); + audio_file2::Response::Seek(self, index * CHUNK_SIZE) + } + + fn on_error(self, _session: &Session) { } } diff --git a/src/audio_file2.rs b/src/audio_file2.rs new file mode 100644 index 00000000..873a53af --- /dev/null +++ b/src/audio_file2.rs @@ -0,0 +1,136 @@ +use session::Session; +use stream; +use util::FileId; + +use byteorder::{BigEndian, WriteBytesExt}; +use std::io::Write; + +const CHUNK_SIZE: usize = 0x20000; + +pub enum Response { +// Wait(H), + Continue(H), + Seek(H, usize), + Close, +} + +pub trait Handler : Sized + Send + 'static { + fn on_header(self, header_id: u8, header_data: &[u8], session: &Session) -> Response; + fn on_data(self, offset: usize, data: &[u8], session: &Session) -> Response; + fn on_eof(self, session: &Session) -> Response; + fn on_error(self, session: &Session); +} + +pub struct AudioFile { + handler: H, + file_id: FileId, + offset: usize, +} + +impl AudioFile { + pub fn new(file_id: FileId, offset: usize, handler: H, session: &Session) { + let handler = AudioFile { + handler: handler, + file_id: file_id, + offset: offset, + }; + + session.stream(Box::new(handler)); + } +} + +impl stream::Handler for AudioFile { + fn on_create(self, channel_id: stream::ChannelId, session: &Session) -> stream::Response { + debug!("Got channel {}", channel_id); + + let mut data: Vec = Vec::new(); + data.write_u16::(channel_id).unwrap(); + data.write_u8(0).unwrap(); + data.write_u8(1).unwrap(); + data.write_u16::(0x0000).unwrap(); + data.write_u32::(0x00000000).unwrap(); + data.write_u32::(0x00009C40).unwrap(); + data.write_u32::(0x00020000).unwrap(); + data.write(&self.file_id.0).unwrap(); + data.write_u32::(self.offset as u32 / 4).unwrap(); + data.write_u32::((self.offset + CHUNK_SIZE) as u32 / 4).unwrap(); + + session.send_packet(0x8, &data).unwrap(); + + stream::Response::Continue(self) + } + + fn on_header(mut self, header_id: u8, header_data: &[u8], session: &Session) -> stream::Response { + //println!("on_header"); + match self.handler.on_header(header_id, header_data, session) { + Response::Continue(handler) => { + self.handler = handler; + stream::Response::Continue(self) + } + Response::Seek(handler, offset) => { + self.handler = handler; + self.offset = offset; + stream::Response::Spawn(self) + } + Response::Close => stream::Response::Close, + } + } + + fn on_data(mut self, data: &[u8], session: &Session) -> stream::Response { + //println!("on_data"); + match self.handler.on_data(self.offset, data, session) { + Response::Continue(handler) => { + self.handler = handler; + self.offset += data.len(); + stream::Response::Continue(self) + } + Response::Seek(handler, offset) => { + println!("seek request {}", offset); + self.handler = handler; + self.offset = offset; + stream::Response::Spawn(self) + } + Response::Close => stream::Response::Close, + } + } + + fn on_close(self, _session: &Session) -> stream::Response { + // End of chunk, request a new one + stream::Response::Spawn(self) + } + + fn on_error(mut self, session: &Session) -> stream::Response { + println!("on_error"); + match self.handler.on_eof(session) { + Response::Continue(_) => stream::Response::Close, + Response::Seek(handler, offset) => { + println!("seek request {}", offset); + self.handler = handler; + self.offset = offset; + stream::Response::Spawn(self) + } + Response::Close => stream::Response::Close, + } + } + + fn box_on_create(self: Box, channel_id: stream::ChannelId, session: &Session) -> stream::Response> { + self.on_create(channel_id, session).boxed() + } + + fn box_on_header(self: Box, header_id: u8, header_data: &[u8], session: &Session) -> stream::Response> { + self.on_header(header_id, header_data, session).boxed() + } + + fn box_on_data(self: Box, data: &[u8], session: &Session) -> stream::Response> { + self.on_data(data, session).boxed() + } + + fn box_on_error(self: Box, session: &Session) -> stream::Response> { + self.on_error(session).boxed() + } + + fn box_on_close(self: Box, session: &Session) -> stream::Response> { + self.on_close(session).boxed() + } +} + diff --git a/src/audio_key.rs b/src/audio_key.rs index d48c7ec3..64355568 100644 --- a/src/audio_key.rs +++ b/src/audio_key.rs @@ -4,8 +4,7 @@ use std::collections::HashMap; use std::io::{Cursor, Read, Write}; use util::{SpotifyId, FileId}; -use session::Session; -use connection::PacketHandler; +use session::{Session, PacketHandler}; pub type AudioKey = [u8; 16]; #[derive(Debug,Hash,PartialEq,Eq,Copy,Clone)] @@ -70,7 +69,7 @@ impl AudioKeyManager { } impl PacketHandler for AudioKeyManager { - fn handle(&mut self, cmd: u8, data: Vec) { + fn handle(&mut self, cmd: u8, data: Vec, _session: &Session) { let mut data = Cursor::new(data); let seq = data.read_u32::().unwrap(); diff --git a/src/connection.rs b/src/connection.rs index acaa6fd2..20bacd53 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -96,7 +96,3 @@ impl CipherConnection { Ok((cmd, data)) } } - -pub trait PacketHandler { - fn handle(&mut self, cmd: u8, data: Vec); -} diff --git a/src/lib.in.rs b/src/lib.in.rs index e44c9e69..93926ea0 100644 --- a/src/lib.in.rs +++ b/src/lib.in.rs @@ -17,8 +17,7 @@ pub mod spirc; pub mod link; pub mod stream; pub mod main_helper; +mod audio_file2; #[cfg(feature = "facebook")] pub mod spotilocal; - -pub use album_cover::get_album_cover; diff --git a/src/mercury.rs b/src/mercury.rs index 69b5c3bd..37a908c6 100644 --- a/src/mercury.rs +++ b/src/mercury.rs @@ -7,8 +7,7 @@ use std::mem::replace; use std::sync::mpsc; use protocol; -use session::Session; -use connection::PacketHandler; +use session::{Session, PacketHandler}; #[derive(Debug, PartialEq, Eq)] pub enum MercuryMethod { @@ -186,7 +185,7 @@ impl MercuryManager { } impl PacketHandler for MercuryManager { - fn handle(&mut self, cmd: u8, data: Vec) { + fn handle(&mut self, cmd: u8, data: Vec, _session: &Session) { let mut packet = Cursor::new(data); let seq = { diff --git a/src/session.rs b/src/session.rs index 21b91a29..41340c2c 100644 --- a/src/session.rs +++ b/src/session.rs @@ -12,21 +12,23 @@ use std::io::{Read, Write, Cursor}; use std::result::Result; use std::sync::{Mutex, RwLock, Arc, mpsc}; -use album_cover::get_album_cover; +use album_cover::AlbumCover; use apresolve::apresolve; use audio_key::{AudioKeyManager, AudioKey, AudioKeyError}; use audio_file::AudioFile; use authentication::Credentials; use cache::Cache; -use connection::{self, PlainConnection, CipherConnection, PacketHandler}; +use connection::{self, PlainConnection, CipherConnection}; use diffie_hellman::DHLocalKeys; use mercury::{MercuryManager, MercuryRequest, MercuryResponse}; use metadata::{MetadataManager, MetadataRef, MetadataTrait}; use protocol; -use stream::{ChannelId, StreamManager, StreamEvent, StreamError}; +use stream::StreamManager; use util::{self, SpotifyId, FileId, ReadSeek}; use version; +use stream; + pub enum Bitrate { Bitrate96, Bitrate160, @@ -249,12 +251,12 @@ impl Session { match cmd { 0x4 => self.send_packet(0x49, &data).unwrap(), 0x4a => (), - 0x9 | 0xa => self.0.stream.lock().unwrap().handle(cmd, data), - 0xd | 0xe => self.0.audio_key.lock().unwrap().handle(cmd, data), + 0x9 | 0xa => self.0.stream.lock().unwrap().handle(cmd, data, self), + 0xd | 0xe => self.0.audio_key.lock().unwrap().handle(cmd, data, self), 0x1b => { self.0.data.write().unwrap().country = String::from_utf8(data).unwrap(); } - 0xb2...0xb6 => self.0.mercury.lock().unwrap().handle(cmd, data), + 0xb2...0xb6 => self.0.mercury.lock().unwrap().handle(cmd, data, self), _ => (), } } @@ -293,7 +295,7 @@ impl Session { self_.0.cache.put_file(file_id, &mut complete_file) }).fire(); - Box::new(audio_file) + Box::new(audio_file.await().unwrap()) }) } @@ -307,7 +309,7 @@ impl Session { }) .unwrap_or_else(|| { let self_ = self.clone(); - get_album_cover(self, file_id) + AlbumCover::get(file_id, self) .map(move |data| { self_.0.cache.put_file(file_id, &mut Cursor::new(&data)); data @@ -315,12 +317,8 @@ impl Session { }) } - pub fn stream(&self, file: FileId, offset: u32, size: u32) -> eventual::Stream { - self.0.stream.lock().unwrap().request(self, file, offset, size) - } - - pub fn allocate_stream(&self) -> (ChannelId, eventual::Stream) { - self.0.stream.lock().unwrap().allocate_stream() + pub fn stream(&self, handler: Box) { + self.0.stream.lock().unwrap().create(handler, self) } pub fn metadata(&self, id: SpotifyId) -> MetadataRef { @@ -355,3 +353,7 @@ impl Session { &self.0.device_id } } + +pub trait PacketHandler { + fn handle(&mut self, cmd: u8, data: Vec, session: &Session); +} diff --git a/src/stream.rs b/src/stream.rs index 946aa3f8..7239ee56 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,37 +1,113 @@ -use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt}; +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use std::collections::HashMap; use std::collections::hash_map::Entry; -use std::io::{Cursor, Seek, SeekFrom, Write}; -use eventual::{self, Async}; +use std::io::{Cursor, Seek, SeekFrom}; +use session::{Session, PacketHandler}; -use util::{ArcVec, FileId}; -use connection::PacketHandler; -use session::Session; - -#[derive(Debug)] -pub enum StreamEvent { - Header(u8, ArcVec), - Data(ArcVec), +pub enum Response { + Continue(H), + Spawn(S), + Close, } -#[derive(Debug,Hash,PartialEq,Eq,Copy,Clone)] -pub struct StreamError; +impl Response { + pub fn boxed(self) -> Response> { + match self { + Response::Continue(handler) => Response::Continue(Box::new(handler)), + Response::Spawn(handler) => Response::Spawn(Box::new(handler)), + Response::Close => Response::Close, + } + } +} + +pub trait Handler: Send { + fn on_create(self, channel_id: ChannelId, session: &Session) -> Response where Self: Sized; + fn on_header(self, header_id: u8, header_data: &[u8], session: &Session) -> Response where Self: Sized; + fn on_data(self, data: &[u8], session: &Session) -> Response where Self: Sized; + fn on_error(self, session: &Session) -> Response where Self: Sized; + fn on_close(self, session: &Session) -> Response where Self: Sized; + + fn box_on_create(self: Box, channel_id: ChannelId, session: &Session) -> Response>; + fn box_on_header(self: Box, header_id: u8, header_data: &[u8], session: &Session) -> Response>; + fn box_on_data(self: Box, data: &[u8], session: &Session) -> Response>; + fn box_on_error(self: Box, session: &Session) -> Response>; + fn box_on_close(self: Box, session: &Session) -> Response>; +} pub type ChannelId = u16; enum ChannelMode { Header, - Data, + Data } -struct Channel { - mode: ChannelMode, - callback: Option>, +struct Channel(ChannelMode, Box); + +impl Channel { + fn handle_packet(self, cmd: u8, data: Vec, session: &Session) -> Response> { + let Channel(mode, mut handler) = self; + + let mut packet = Cursor::new(&data as &[u8]); + packet.read_u16::().unwrap(); // Skip channel id + + if cmd == 0xa { + println!("error: {} {}", data.len(), packet.read_u16::().unwrap()); + return match handler.box_on_error(session) { + Response::Continue(_) => Response::Close, + Response::Spawn(f) => Response::Spawn(f), + Response::Close => Response::Close, + }; + } + + match mode { + ChannelMode::Header => { + let mut length = 0; + + while packet.position() < data.len() as u64 { + length = packet.read_u16::().unwrap(); + if length > 0 { + let header_id = packet.read_u8().unwrap(); + let header_data = &data[packet.position() as usize .. packet.position() as usize + length as usize - 1]; + + handler = match handler.box_on_header(header_id, header_data, session) { + Response::Continue(handler) => handler, + Response::Spawn(f) => return Response::Spawn(f), + Response::Close => return Response::Close, + }; + + packet.seek(SeekFrom::Current(length as i64 - 1)).unwrap(); + } + } + + if length == 0 { + Response::Continue(Channel(ChannelMode::Data, handler)) + } else { + Response::Continue(Channel(ChannelMode::Header, handler)) + } + } + ChannelMode::Data => { + if packet.position() < data.len() as u64 { + let event_data = &data[packet.position() as usize..]; + match handler.box_on_data(event_data, session) { + Response::Continue(handler) => Response::Continue(Channel(ChannelMode::Data, handler)), + Response::Spawn(f) => Response::Spawn(f), + Response::Close => Response::Close, + } + } else { + match handler.box_on_close(session) { + Response::Continue(_) => Response::Close, + Response::Spawn(f) => Response::Spawn(f), + Response::Close => Response::Close, + } + } + } + } + } } pub struct StreamManager { next_id: ChannelId, - channels: HashMap, + channels: HashMap>, } impl StreamManager { @@ -42,107 +118,52 @@ impl StreamManager { } } - pub fn allocate_stream(&mut self) -> (ChannelId, eventual::Stream) { - let (tx, rx) = eventual::Stream::pair(); - + pub fn create(&mut self, handler: Box, session: &Session) { let channel_id = self.next_id; self.next_id += 1; - self.channels.insert(channel_id, - Channel { - mode: ChannelMode::Header, - callback: Some(tx), - }); + trace!("allocated stream {}", channel_id); - (channel_id, rx) - } - - pub fn request(&mut self, - session: &Session, - file: FileId, - offset: u32, - size: u32) - -> eventual::Stream { - - let (channel_id, rx) = self.allocate_stream(); - - let mut data: Vec = Vec::new(); - data.write_u16::(channel_id).unwrap(); - data.write_u8(0).unwrap(); - data.write_u8(1).unwrap(); - data.write_u16::(0x0000).unwrap(); - data.write_u32::(0x00000000).unwrap(); - data.write_u32::(0x00009C40).unwrap(); - data.write_u32::(0x00020000).unwrap(); - data.write(&file.0).unwrap(); - data.write_u32::(offset).unwrap(); - data.write_u32::(offset + size).unwrap(); - - session.send_packet(0x8, &data).unwrap(); - - rx - } -} - -impl Channel { - fn handle_packet(&mut self, cmd: u8, data: Vec) { - let data = ArcVec::new(data); - let mut packet = Cursor::new(&data as &[u8]); - packet.read_u16::().unwrap(); // Skip channel id - - if cmd == 0xa { - self.callback.take().map(|c| c.fail(StreamError)); - } else { - match self.mode { - ChannelMode::Header => { - let mut length = 0; - - while packet.position() < data.len() as u64 { - length = packet.read_u16::().unwrap(); - if length > 0 { - let header_id = packet.read_u8().unwrap(); - let header_data = data.clone() - .offset(packet.position() as usize) - .limit(length as usize - 1); - - let header = StreamEvent::Header(header_id, header_data); - - self.callback = self.callback.take().and_then(|c| c.send(header).await().ok()); - - packet.seek(SeekFrom::Current(length as i64 - 1)).unwrap(); - } - } - - if length == 0 { - self.mode = ChannelMode::Data; - } - } - - ChannelMode::Data => { - if packet.position() < data.len() as u64 { - let event_data = data.clone().offset(packet.position() as usize); - let event = StreamEvent::Data(event_data); - - self.callback = self.callback.take().and_then(|c| c.send(event).await().ok()); - } else { - self.callback = None; - } - } + match handler.box_on_create(channel_id, session) { + Response::Continue(handler) => { + self.channels.insert(channel_id, Some(Channel(ChannelMode::Header, handler))); } + Response::Spawn(handler) => self.create(handler, session), + Response::Close => (), } } } impl PacketHandler for StreamManager { - fn handle(&mut self, cmd: u8, data: Vec) { + fn handle(&mut self, cmd: u8, data: Vec, session: &Session) { let id: ChannelId = BigEndian::read_u16(&data[0..2]); - if let Entry::Occupied(mut entry) = self.channels.entry(id) { - entry.get_mut().handle_packet(cmd, data); - - if entry.get().callback.is_none() { - entry.remove(); + let spawn = if let Entry::Occupied(mut entry) = self.channels.entry(id) { + if let Some(channel) = entry.get_mut().take() { + match channel.handle_packet(cmd, data, session) { + Response::Continue(channel) => { + entry.insert(Some(channel)); + None + } + Response::Spawn(f) => { + entry.remove(); + Some(f) + } + Response::Close => { + entry.remove(); + None + } + } + } else { + None } + } else { + None + }; + + + if let Some(s) = spawn { + self.create(s, session); } } }