From 2fe91d5dd34e8d8e8fc51b0bcbc6575c3ded8ee9 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Tue, 20 Jun 2023 18:22:34 +0200 Subject: [PATCH 1/8] Give the user a tracker for conn interactions Signed-off-by: Luca Fulchir --- src/connection/handshake/tracker.rs | 54 +++++++++++--------- src/connection/mod.rs | 54 ++++++++++++++++---- src/inner/worker.rs | 76 ++++++++++++++++++++++------- src/lib.rs | 53 ++++++++++---------- src/tests.rs | 2 +- 5 files changed, 162 insertions(+), 77 deletions(-) diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index 30a6a8f..c40158e 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -3,8 +3,9 @@ use crate::{ auth::{Domain, ServiceID}, connection::{ + self, handshake::{self, Error, Handshake}, - Conn, IDRecv, IDSend, + Connection, IDRecv, IDSend, }, enc::{ self, @@ -18,20 +19,27 @@ use crate::{ use ::tokio::sync::oneshot; pub(crate) struct Server { - pub id: KeyID, - pub key: PrivKey, - pub domains: Vec, + pub(crate) id: KeyID, + pub(crate) key: PrivKey, + pub(crate) domains: Vec, } -pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>; +pub(crate) type ConnectAnswer = Result; +#[derive(Debug)] +pub(crate) struct ConnectOk { + pub(crate) auth_key_id: KeyID, + pub(crate) auth_id_send: IDSend, + pub(crate) authsrv_conn: connection::AuthSrvConn, + pub(crate) service_conn: Option, +} pub(crate) struct Client { - pub service_id: ServiceID, - pub service_conn_id: IDRecv, - pub connection: Conn, - pub timeout: Option<::tokio::task::JoinHandle<()>>, - pub answer: oneshot::Sender, - pub srv_key_id: KeyID, + pub(crate) service_id: ServiceID, + pub(crate) service_conn_id: IDRecv, + pub(crate) connection: Connection, + pub(crate) timeout: Option<::tokio::task::JoinHandle<()>>, + pub(crate) answer: oneshot::Sender, + pub(crate) srv_key_id: KeyID, } /// Tracks the keys used by the client and the handshake @@ -78,7 +86,7 @@ impl ClientList { pub_key: PubKey, service_id: ServiceID, service_conn_id: IDRecv, - connection: Conn, + connection: Connection, answer: oneshot::Sender, srv_key_id: KeyID, ) -> Result<(KeyID, &mut Client), oneshot::Sender> { @@ -128,26 +136,26 @@ impl ClientList { #[derive(Debug, Clone)] pub(crate) struct AuthNeededInfo { /// Parsed handshake packet - pub handshake: Handshake, + pub(crate) handshake: Handshake, /// hkdf generated from the handshake - pub hkdf: Hkdf, + pub(crate) hkdf: Hkdf, } /// Client information needed to fully establish the conenction #[derive(Debug)] pub(crate) struct ClientConnectInfo { /// The service ID that we are connecting to - pub service_id: ServiceID, + pub(crate) service_id: ServiceID, /// The service ID that we are connecting to - pub service_connection_id: IDRecv, + pub(crate) service_connection_id: IDRecv, /// Parsed handshake packet - pub handshake: Handshake, - /// Conn - pub connection: Conn, + pub(crate) handshake: Handshake, + /// Connection + pub(crate) connection: Connection, /// where to wake up the waiting client - pub answer: oneshot::Sender, - /// server public key id that we used on the handshake - pub srv_key_id: KeyID, + pub(crate) answer: oneshot::Sender, + /// server pub(crate)lic key id that we used on the handshake + pub(crate) srv_key_id: KeyID, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] @@ -231,7 +239,7 @@ impl Tracker { pub_key: PubKey, service_id: ServiceID, service_conn_id: IDRecv, - connection: Conn, + connection: Connection, answer: oneshot::Sender, srv_key_id: KeyID, ) -> Result<(KeyID, &mut Client), oneshot::Sender> { diff --git a/src/connection/mod.rs b/src/connection/mod.rs index f6c0a86..58853c7 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -5,7 +5,7 @@ pub mod packet; pub mod socket; pub mod stream; -use ::std::{rc::Rc, vec::Vec}; +use ::std::{collections::HashMap, rc::Rc, vec::Vec}; pub use crate::connection::{handshake::Handshake, packet::Packet}; @@ -17,9 +17,8 @@ use crate::{ sym::{self, CipherRecv, CipherSend}, Random, }, - inner::ThreadTracker, + inner::{worker, ThreadTracker}, }; -use ::std::rc; /// Fenrir Connection ID /// @@ -126,13 +125,40 @@ impl ProtocolVersion { } } +#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)] +pub(crate) struct UserConnTracker(usize); +impl UserConnTracker { + fn advance(&mut self) -> Self { + let old = self.0; + self.0 = self.0 + 1; + UserConnTracker(old) + } +} + +/// Connection to an Authentication Server +#[derive(Debug)] +pub struct AuthSrvConn(pub(crate) Conn); +/// Connection to a service +#[derive(Debug)] +pub struct ServiceConn(pub(crate) Conn); + /// The connection, as seen from a user of libFenrir #[derive(Debug)] -pub struct Connection(rc::Weak); +pub struct Conn { + pub(crate) queue: ::async_channel::Sender, + pub(crate) conn: UserConnTracker, +} + +impl Conn { + /// Queue some data to be sent in this connection + pub fn send(&mut self, stream: stream::ID, _data: Vec) { + todo!() + } +} /// A single connection and its data #[derive(Debug)] -pub(crate) struct Conn { +pub(crate) struct Connection { /// Receiving Conn ID pub id_recv: IDRecv, /// Sending Conn ID @@ -160,7 +186,7 @@ pub enum Role { Client, } -impl Conn { +impl Connection { pub(crate) fn new( hkdf: Hkdf, cipher: sym::Kind, @@ -190,7 +216,9 @@ impl Conn { pub(crate) struct ConnList { thread_id: ThreadTracker, - connections: Vec>>, + connections: Vec>>, + user_tracker: HashMap, + last_tracked: UserConnTracker, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } @@ -206,6 +234,8 @@ impl ConnList { let mut ret = Self { thread_id, connections: Vec::with_capacity(INITIAL_CAP), + user_tracker: HashMap::with_capacity(INITIAL_CAP), + last_tracked: UserConnTracker(0), ids_used: vec![bitmap_id], }; ret.connections.resize_with(INITIAL_CAP, || None); @@ -261,7 +291,10 @@ impl ConnList { new_id } /// NOTE: does NOT check if the connection has been previously reserved! - pub(crate) fn track(&mut self, conn: Rc) -> Result<(), ()> { + pub(crate) fn track( + &mut self, + conn: Rc, + ) -> Result { let conn_id = match conn.id_recv { IDRecv(ID::Handshake) => { return Err(()); @@ -271,7 +304,9 @@ impl ConnList { let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; self.connections[id_in_thread] = Some(conn); - Ok(()) + let tracked = self.last_tracked.advance(); + let _ = self.user_tracker.insert(tracked, id_in_thread); + Ok(tracked) } pub(crate) fn remove(&mut self, id: IDRecv) { if let IDRecv(ID::ID(raw_id)) = id { @@ -303,7 +338,6 @@ enum MapEntry { Present(IDSend), Reserved, } -use ::std::collections::HashMap; /// Link the public key of the authentication server to a connection id /// so that we can reuse that connection to ask for more authentications diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 27dd92a..29ad7f7 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,7 @@ use crate::{ }, packet::{self, Packet}, socket::{UdpClient, UdpServer}, - Conn, ConnList, IDSend, + AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, }, dnssec, enc::{ @@ -64,6 +64,7 @@ pub struct Worker { token_check: Option>>, sockets: Vec>, queue: ::async_channel::Receiver, + queue_sender: ::async_channel::Sender, queue_timeouts_recv: mpsc::UnboundedReceiver, queue_timeouts_send: mpsc::UnboundedSender, thread_channels: Vec<::async_channel::Sender>, @@ -82,6 +83,7 @@ impl Worker { token_check: Option>>, sockets: Vec>, queue: ::async_channel::Receiver, + queue_sender: ::async_channel::Sender, ) -> ::std::io::Result { let (queue_timeouts_send, queue_timeouts_recv) = mpsc::unbounded_channel(); @@ -118,6 +120,7 @@ impl Worker { token_check, sockets, queue, + queue_sender, queue_timeouts_recv, queue_timeouts_send, thread_channels: Vec::new(), @@ -293,7 +296,7 @@ impl Worker { // are PubKey::Exchange unreachable!() } - let mut conn = Conn::new( + let mut conn = Connection::new( hkdf, cipher_selected, connection::Role::Client, @@ -515,7 +518,7 @@ impl Worker { let head_len = req.cipher.nonce_len(); let tag_len = req.cipher.tag_len(); - let mut auth_conn = Conn::new( + let mut auth_conn = Connection::new( authinfo.hkdf, req.cipher, connection::Role::Server, @@ -587,20 +590,32 @@ impl Worker { ); unreachable!(); } - let auth_srv_conn = IDSend(resp_data.id); + let auth_id_send = IDSend(resp_data.id); let mut conn = cci.connection; - conn.id_send = auth_srv_conn; + conn.id_send = auth_id_send; let id_recv = conn.id_recv; let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server - if self.connections.track(conn.into()).is_err() { - ::tracing::error!("Could not track new connection"); - self.connections.remove(id_recv); - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } + let track_auth_conn = + match self.connections.track(conn.into()) { + Ok(track_auth_conn) => track_auth_conn, + Err(e) => { + ::tracing::error!( + "Could not track new auth srv connection" + ); + self.connections.remove(id_recv); + // FIXME: proper connection closing + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + let authsrv_conn = AuthSrvConn(connection::Conn { + queue: self.queue_sender.clone(), + conn: track_auth_conn, + }); + let mut service_conn = None; if cci.service_id != auth::SERVICEID_AUTH { // create and track the connection to the service // SECURITY: xor with secrets @@ -611,7 +626,7 @@ impl Worker { cci.service_id.as_bytes(), resp_data.service_key, ); - let mut service_connection = Conn::new( + let mut service_connection = Connection::new( hkdf, cipher, connection::Role::Client, @@ -620,11 +635,38 @@ impl Worker { service_connection.id_recv = cci.service_connection_id; service_connection.id_send = IDSend(resp_data.service_connection_id); - let _ = - self.connections.track(service_connection.into()); + let track_serv_conn = match self + .connections + .track(service_connection.into()) + { + Ok(track_serv_conn) => track_serv_conn, + Err(e) => { + ::tracing::error!( + "Could not track new service connection" + ); + self.connections + .remove(cci.service_connection_id); + // FIXME: proper connection closing + // FIXME: drop auth srv connection if we just + // established it + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + service_conn = Some(ServiceConn(connection::Conn { + queue: self.queue_sender.clone(), + conn: track_serv_conn, + })); } let _ = - cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn))); + cci.answer.send(Ok(handshake::tracker::ConnectOk { + auth_key_id: cci.srv_key_id, + auth_id_send, + authsrv_conn, + service_conn, + })); } handshake::Action::Nothing => {} }; diff --git a/src/lib.rs b/src/lib.rs index 57da614..a0ee7ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,7 +39,7 @@ use crate::{ }, }; pub use config::Config; -pub use connection::Connection; +pub use connection::{AuthSrvConn, ServiceConn}; /// Main fenrir library errors #[derive(::thiserror::Error, Debug)] @@ -382,7 +382,7 @@ impl Fenrir { &self, domain: &Domain, service: ServiceID, - ) -> Result<(), Error> { + ) -> Result<(AuthSrvConn, Option), Error> { let resolved = self.resolv(domain).await?; self.connect_resolved(resolved, domain, service).await } @@ -392,7 +392,7 @@ impl Fenrir { resolved: dnssec::Record, domain: &Domain, service: ServiceID, - ) -> Result<(), Error> { + ) -> Result<(AuthSrvConn, Option), Error> { loop { // check if we already have a connection to that auth. srv let is_reserved = { @@ -460,29 +460,28 @@ impl Fenrir { .await; match recv.await { - Ok(res) => { - match res { - Err(e) => { - let mut conn_auth_lock = - self.conn_auth_srv.lock().await; - conn_auth_lock.remove_reserved(&resolved); - Err(e) - } - Ok((key_id, id_send)) => { - let key = resolved - .public_keys - .iter() - .find(|k| k.0 == key_id) - .unwrap(); - let mut conn_auth_lock = - self.conn_auth_srv.lock().await; - conn_auth_lock.add(&key.1, id_send, &resolved); - - //FIXME: user needs to somehow track the connection - Ok(()) - } + Ok(res) => match res { + Err(e) => { + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.remove_reserved(&resolved); + Err(e) } - } + Ok(connections) => { + let key = resolved + .public_keys + .iter() + .find(|k| k.0 == connections.auth_key_id) + .unwrap(); + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.add( + &key.1, + connections.auth_id_send, + &resolved, + ); + + Ok((connections.authsrv_conn, connections.service_conn)) + } + }, Err(e) => { // Thread dropped the sender. no more thread? let mut conn_auth_lock = self.conn_auth_srv.lock().await; @@ -524,6 +523,7 @@ impl Fenrir { self.token_check.clone(), socks, work_recv, + work_send.clone(), ) .await?; // don't keep around private keys too much @@ -547,7 +547,6 @@ impl Fenrir { } Ok(worker) } - // needs to be called before add_sockets /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. @@ -589,6 +588,7 @@ impl Fenrir { let th_tokio_rt = tokio_rt.clone(); let th_config = self.cfg.clone(); let (work_send, work_recv) = ::async_channel::unbounded::(); + let th_work_send = work_send.clone(); let th_stop_working = self.stop_working.subscribe(); let th_token_check = self.token_check.clone(); let th_sockets = sockets.clone(); @@ -629,6 +629,7 @@ impl Fenrir { th_token_check, th_sockets, work_recv, + th_work_send, ) .await { diff --git a/src/tests.rs b/src/tests.rs index 4e2a0b6..ff28ae6 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -88,7 +88,7 @@ async fn test_connection_dirsync() { .connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH) .await { - Ok(()) => {} + Ok((_, _)) => {} Err(e) => { assert!(false, "Err on client connection: {:?}", e); } -- 2.47.2 From 9c67210e3e6c63a22369b1d5060893ab52de18bc Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 22 Jun 2023 12:50:47 +0200 Subject: [PATCH 2/8] User conn tracking, enqueue data, timers Signed-off-by: Luca Fulchir --- src/connection/mod.rs | 75 ++++++++++++++++++++++-------- src/connection/stream/mod.rs | 2 +- src/inner/mod.rs | 84 +++++++++++++++++++++++++++++++++ src/inner/worker.rs | 90 +++++++++++++++++++++--------------- src/lib.rs | 2 + 5 files changed, 198 insertions(+), 55 deletions(-) diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 58853c7..ea5dfcb 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -5,7 +5,11 @@ pub mod packet; pub mod socket; pub mod stream; -use ::std::{collections::HashMap, rc::Rc, vec::Vec}; +use ::core::num::Wrapping; +use ::std::{ + collections::{BTreeMap, HashMap}, + vec::Vec, +}; pub use crate::connection::{handshake::Handshake, packet::Packet}; @@ -125,12 +129,12 @@ impl ProtocolVersion { } } -#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)] -pub(crate) struct UserConnTracker(usize); +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +pub(crate) struct UserConnTracker(Wrapping); impl UserConnTracker { fn advance(&mut self) -> Self { let old = self.0; - self.0 = self.0 + 1; + self.0 = self.0 + Wrapping(1); UserConnTracker(old) } } @@ -151,8 +155,13 @@ pub struct Conn { impl Conn { /// Queue some data to be sent in this connection - pub fn send(&mut self, stream: stream::ID, _data: Vec) { - todo!() + // TODO: send_and_wait, that wait for recipient ACK + pub async fn send(&mut self, stream: stream::ID, data: Vec) { + use crate::inner::worker::Work; + let _ = self + .queue + .send(Work::UserSend((self.conn, stream, data))) + .await; } } @@ -160,15 +169,17 @@ impl Conn { #[derive(Debug)] pub(crate) struct Connection { /// Receiving Conn ID - pub id_recv: IDRecv, + pub(crate) id_recv: IDRecv, /// Sending Conn ID - pub id_send: IDSend, + pub(crate) id_send: IDSend, /// The main hkdf used for all secrets in this connection - pub hkdf: Hkdf, + hkdf: Hkdf, /// Cipher for decrypting data - pub cipher_recv: CipherRecv, + pub(crate) cipher_recv: CipherRecv, /// Cipher for encrypting data - pub cipher_send: CipherSend, + pub(crate) cipher_send: CipherSend, + /// send queue for each Stream + send_queue: BTreeMap>>, } /// Role: track the connection direction @@ -210,14 +221,22 @@ impl Connection { hkdf, cipher_recv, cipher_send, + send_queue: BTreeMap::new(), } } + pub(crate) fn send(&mut self, stream: stream::ID, data: Vec) { + let stream = match self.send_queue.get_mut(&stream) { + None => return, + Some(stream) => stream, + }; + stream.push(data); + } } pub(crate) struct ConnList { thread_id: ThreadTracker, - connections: Vec>>, - user_tracker: HashMap, + connections: Vec>, + user_tracker: BTreeMap, last_tracked: UserConnTracker, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, @@ -234,13 +253,27 @@ impl ConnList { let mut ret = Self { thread_id, connections: Vec::with_capacity(INITIAL_CAP), - user_tracker: HashMap::with_capacity(INITIAL_CAP), - last_tracked: UserConnTracker(0), + user_tracker: BTreeMap::new(), + last_tracked: UserConnTracker(Wrapping(0)), ids_used: vec![bitmap_id], }; ret.connections.resize_with(INITIAL_CAP, || None); ret } + pub fn get_mut( + &mut self, + tracker: UserConnTracker, + ) -> Option<&mut Connection> { + let idx = if let Some(idx) = self.user_tracker.get(&tracker) { + *idx + } else { + return None; + }; + match &mut self.connections[idx] { + None => None, + Some(conn) => Some(conn), + } + } pub fn len(&self) -> usize { let mut total: usize = 0; for bitmap in self.ids_used.iter() { @@ -293,7 +326,7 @@ impl ConnList { /// NOTE: does NOT check if the connection has been previously reserved! pub(crate) fn track( &mut self, - conn: Rc, + conn: Connection, ) -> Result { let conn_id = match conn.id_recv { IDRecv(ID::Handshake) => { @@ -304,8 +337,14 @@ impl ConnList { let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; self.connections[id_in_thread] = Some(conn); - let tracked = self.last_tracked.advance(); - let _ = self.user_tracker.insert(tracked, id_in_thread); + let mut tracked; + loop { + tracked = self.last_tracked.advance(); + if self.user_tracker.get(&tracked).is_none() { + let _ = self.user_tracker.insert(tracked, id_in_thread); + break; + } + } Ok(tracked) } pub(crate) fn remove(&mut self, id: IDRecv) { diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 1b56a2b..58dd76e 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -19,7 +19,7 @@ pub enum Kind { } /// Id of the stream -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct ID(pub u16); impl ID { diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 001ca16..e23d614 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -4,6 +4,10 @@ pub(crate) mod worker; +use crate::inner::worker::Work; +use ::std::{collections::BTreeMap, vec::Vec}; +use ::tokio::time::Instant; + /// Track the total number of threads and our index /// 65K cpus should be enough for anybody #[derive(Debug, Clone, Copy)] @@ -12,3 +16,83 @@ pub(crate) struct ThreadTracker { /// Note: starts from 1 pub id: u16, } + +pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration = + if cfg!(linux) || cfg!(macos) { + ::std::time::Duration::from_millis(1) + } else { + // windows + ::std::time::Duration::from_millis(16) + }; + +pub(crate) async fn set_minimum_sleep_resolution() { + let nanosleep = ::std::time::Duration::from_nanos(1); + let mut tests: usize = 3; + + while tests > 0 { + let pre_sleep = ::std::time::Instant::now(); + ::tokio::time::sleep(nanosleep).await; + let post_sleep = ::std::time::Instant::now(); + let slept_for = post_sleep - pre_sleep; + #[allow(unsafe_code)] + unsafe { + if slept_for < SLEEP_RESOLUTION { + SLEEP_RESOLUTION = slept_for; + } + } + tests = tests - 1; + } +} + +/// Sleeping has a higher resolution that we would like for packet pacing. +/// So we sleep for however log we need, then chunk up all the work here +/// we will end up chunking the work in SLEEP_RESOLUTION, then we will busy wait +/// for more precise timing +pub(crate) struct Timers { + times: BTreeMap, +} + +impl Timers { + pub(crate) fn new() -> Self { + Self { + times: BTreeMap::new(), + } + } + pub(crate) fn get_next(&self) -> ::tokio::time::Sleep { + match self.times.keys().next() { + Some(entry) => ::tokio::time::sleep_until((*entry).into()), + None => { + ::tokio::time::sleep(::std::time::Duration::from_secs(3600)) + } + } + } + /// Get all the work from now up until now + SLEEP_RESOLUTION + pub(crate) fn get_work(&mut self) -> Vec { + let now: ::tokio::time::Instant = ::std::time::Instant::now().into(); + let mut ret = Vec::with_capacity(4); + let mut count_rm = 0; + #[allow(unsafe_code)] + let next_instant = unsafe { now + SLEEP_RESOLUTION }; + let mut iter = self.times.iter_mut().peekable(); + loop { + match iter.peek() { + None => break, + Some(next) => { + if *next.0 > next_instant { + break; + } + } + } + let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0)); + let mut entry = iter.next().unwrap(); + ::core::mem::swap(&mut work, &mut entry.1); + ret.push(work); + count_rm = count_rm + 1; + } + while count_rm > 0 { + self.times.pop_first(); + count_rm = count_rm - 1; + } + ret + } +} diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 29ad7f7..0bb41d0 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,8 @@ use crate::{ }, packet::{self, Packet}, socket::{UdpClient, UdpServer}, - AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, + stream, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, + UserConnTracker, }, dnssec, enc::{ @@ -51,6 +52,7 @@ pub(crate) enum Work { Connect(ConnectInfo), DropHandshake(KeyID), Recv(RawUdp), + UserSend((UserConnTracker, stream::ID, Vec)), } /// Actual worker implementation. @@ -70,6 +72,7 @@ pub struct Worker { thread_channels: Vec<::async_channel::Sender>, connections: ConnList, handshakes: handshake::Tracker, + work_timers: super::Timers, } #[allow(unsafe_code)] @@ -126,12 +129,15 @@ impl Worker { thread_channels: Vec::new(), connections: ConnList::new(thread_id), handshakes, + work_timers: super::Timers::new(), }) } /// Continuously loop and process work as needed pub async fn work_loop(&mut self) { 'mainloop: loop { + let next_timer = self.work_timers.get_next(); + ::tokio::pin!(next_timer); let work = ::tokio::select! { tell_stopped = self.stop_working.recv() => { if let Ok(stop_ch) = tell_stopped { @@ -140,6 +146,13 @@ impl Worker { } break; } + () = &mut next_timer => { + let work_list = self.work_timers.get_work(); + for w in work_list.into_iter() { + let _ = self.queue_sender.send(w).await; + } + continue 'mainloop; + } maybe_timeout = self.queue.recv() => { match maybe_timeout { Ok(work) => work, @@ -419,6 +432,13 @@ impl Worker { Work::Recv(pkt) => { self.recv(pkt).await; } + Work::UserSend((tracker, stream, data)) => { + let conn = match self.connections.get_mut(tracker) { + None => return, + Some(conn) => conn, + }; + conn.send(stream, data); + } } } } @@ -596,21 +616,20 @@ impl Worker { let id_recv = conn.id_recv; let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server - let track_auth_conn = - match self.connections.track(conn.into()) { - Ok(track_auth_conn) => track_auth_conn, - Err(e) => { - ::tracing::error!( - "Could not track new auth srv connection" - ); - self.connections.remove(id_recv); - // FIXME: proper connection closing - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } - }; + let track_auth_conn = match self.connections.track(conn) { + Ok(track_auth_conn) => track_auth_conn, + Err(_) => { + ::tracing::error!( + "Could not track new auth srv connection" + ); + self.connections.remove(id_recv); + // FIXME: proper connection closing + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; let authsrv_conn = AuthSrvConn(connection::Conn { queue: self.queue_sender.clone(), conn: track_auth_conn, @@ -635,26 +654,25 @@ impl Worker { service_connection.id_recv = cci.service_connection_id; service_connection.id_send = IDSend(resp_data.service_connection_id); - let track_serv_conn = match self - .connections - .track(service_connection.into()) - { - Ok(track_serv_conn) => track_serv_conn, - Err(e) => { - ::tracing::error!( - "Could not track new service connection" - ); - self.connections - .remove(cci.service_connection_id); - // FIXME: proper connection closing - // FIXME: drop auth srv connection if we just - // established it - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } - }; + let track_serv_conn = + match self.connections.track(service_connection) { + Ok(track_serv_conn) => track_serv_conn, + Err(_) => { + ::tracing::error!( + "Could not track new service connection" + ); + self.connections + .remove(cci.service_connection_id); + // FIXME: proper connection closing + // FIXME: drop auth srv connection if we just + // established it + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking + .into(), + )); + return; + } + }; service_conn = Some(ServiceConn(connection::Conn { queue: self.queue_sender.clone(), conn: track_serv_conn, diff --git a/src/lib.rs b/src/lib.rs index a0ee7ae..8ac887a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -176,6 +176,7 @@ impl Fenrir { config: &Config, tokio_rt: Arc<::tokio::runtime::Runtime>, ) -> Result { + inner::set_minimum_sleep_resolution().await; let (sender, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; // bind sockets early so we can change "port 0" (aka: random) @@ -214,6 +215,7 @@ impl Fenrir { pub async fn with_workers( config: &Config, ) -> Result<(Self, Vec), Error> { + inner::set_minimum_sleep_resolution().await; let (stop_working, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; // bind sockets early so we can change "port 0" (aka: random) -- 2.47.2 From a810fc9a9e3c54a1f8d633cca7cf4ddc680d81ac Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 22 Jun 2023 20:12:50 +0200 Subject: [PATCH 3/8] Stream enqueue and serialize to the packet Signed-off-by: Luca Fulchir --- src/connection/mod.rs | 158 ++++++++++++++++++++++++++++++----- src/connection/stream/mod.rs | 77 +++++++++++++++++ src/enc/sym.rs | 8 ++ src/inner/worker.rs | 39 ++++++++- 4 files changed, 261 insertions(+), 21 deletions(-) diff --git a/src/connection/mod.rs b/src/connection/mod.rs index ea5dfcb..337b48f 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -14,8 +14,10 @@ use ::std::{ pub use crate::connection::{handshake::Handshake, packet::Packet}; use crate::{ + connection::socket::{UdpClient, UdpServer}, dnssec, enc::{ + self, asym::PubKey, hkdf::Hkdf, sym::{self, CipherRecv, CipherSend}, @@ -165,23 +167,6 @@ impl Conn { } } -/// A single connection and its data -#[derive(Debug)] -pub(crate) struct Connection { - /// Receiving Conn ID - pub(crate) id_recv: IDRecv, - /// Sending Conn ID - pub(crate) id_send: IDSend, - /// The main hkdf used for all secrets in this connection - hkdf: Hkdf, - /// Cipher for decrypting data - pub(crate) cipher_recv: CipherRecv, - /// Cipher for encrypting data - pub(crate) cipher_send: CipherSend, - /// send queue for each Stream - send_queue: BTreeMap>>, -} - /// Role: track the connection direction /// /// The Role is used to select the correct secrets, and track the direction @@ -197,6 +182,41 @@ pub enum Role { Client, } +#[derive(Debug)] +enum TimerKind { + None, + SendData(::tokio::time::Instant), + Keepalive(::tokio::time::Instant), +} + +pub(crate) enum Enqueue { + NoSuchStream, + TimerWait, + Immediate, +} + +/// A single connection and its data +#[derive(Debug)] +pub(crate) struct Connection { + /// Receiving Conn ID + pub(crate) id_recv: IDRecv, + /// Sending Conn ID + pub(crate) id_send: IDSend, + /// Sending address + pub(crate) send_addr: UdpClient, + /// The main hkdf used for all secrets in this connection + hkdf: Hkdf, + /// Cipher for decrypting data + pub(crate) cipher_recv: CipherRecv, + /// Cipher for encrypting data + pub(crate) cipher_send: CipherSend, + mtu: usize, + next_timer: TimerKind, + /// send queue for each Stream + send_queue: BTreeMap, + last_stream_sent: stream::ID, +} + impl Connection { pub(crate) fn new( hkdf: Hkdf, @@ -215,21 +235,119 @@ impl Connection { let cipher_recv = CipherRecv::new(cipher, secret_recv); let cipher_send = CipherSend::new(cipher, secret_send, rand); + use ::std::net::{IpAddr, Ipv4Addr, SocketAddr}; Self { id_recv: IDRecv(ID::Handshake), id_send: IDSend(ID::Handshake), + send_addr: UdpClient(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 31337, + )), hkdf, cipher_recv, cipher_send, + mtu: 1280, + next_timer: TimerKind::None, send_queue: BTreeMap::new(), + last_stream_sent: stream::ID(0), } } - pub(crate) fn send(&mut self, stream: stream::ID, data: Vec) { + pub(crate) fn enqueue( + &mut self, + stream: stream::ID, + data: Vec, + ) -> Enqueue { let stream = match self.send_queue.get_mut(&stream) { - None => return, + None => return Enqueue::NoSuchStream, Some(stream) => stream, }; - stream.push(data); + stream.enqueue(data); + let ret; + self.next_timer = match self.next_timer { + TimerKind::None | TimerKind::Keepalive(_) => { + ret = Enqueue::Immediate; + TimerKind::SendData(::tokio::time::Instant::now()) + } + TimerKind::SendData(old_timer) => { + // There already is some data to be sent + // wait for this timer, + // or risk going over max transmission rate + ret = Enqueue::TimerWait; + TimerKind::SendData(old_timer) + } + }; + ret + } + pub(crate) fn write_pkt<'a>( + &mut self, + raw: &'a mut [u8], + ) -> Result<&'a [u8], enc::Error> { + assert!(raw.len() >= 1200, "I should have at least 1200 MTU"); + if self.send_queue.len() == 0 { + return Err(enc::Error::NotEnoughData(0)); + } + raw[..ID::len()] + .copy_from_slice(&self.id_send.0.as_u64().to_le_bytes()); + let data_from = ID::len() + self.cipher_send.nonce_len().0; + let data_max_to = raw.len() - self.cipher_send.tag_len().0; + let mut chunk_from = data_from; + let mut available_len = data_max_to - data_from; + + use std::ops::Bound::{Excluded, Included}; + let last_stream = self.last_stream_sent; + + // Loop over our streams, write them to the packet. + // Notes: + // * to avoid starvation, just round-robin them all for now + // * we can enqueue multiple times the same stream + // This is useful especially for Datagram streams + 'queueloop: { + for (id, stream) in self + .send_queue + .range_mut((Included(last_stream), Included(stream::ID::max()))) + { + if available_len < stream::Chunk::headers_len() + 1 { + break 'queueloop; + } + let bytes = + stream.serialize(*id, &mut raw[chunk_from..data_max_to]); + if bytes == 0 { + break 'queueloop; + } + available_len = available_len - bytes; + chunk_from = chunk_from + bytes; + self.last_stream_sent = *id; + } + if available_len > 0 { + for (id, stream) in self.send_queue.range_mut(( + Included(stream::ID::min()), + Excluded(last_stream), + )) { + if available_len < stream::Chunk::headers_len() + 1 { + break 'queueloop; + } + let bytes = stream + .serialize(*id, &mut raw[chunk_from..data_max_to]); + if bytes == 0 { + break 'queueloop; + } + available_len = available_len - bytes; + chunk_from = chunk_from + bytes; + self.last_stream_sent = *id; + } + } + } + if chunk_from == data_from { + return Err(enc::Error::NotEnoughData(0)); + } + let data_to = chunk_from + self.cipher_send.tag_len().0; + + // encrypt + let aad = sym::AAD(&[]); + match self.cipher_send.encrypt(aad, &mut raw[data_from..data_to]) { + Ok(_) => Ok(&raw[..data_to]), + Err(e) => Err(e), + } } } diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 58dd76e..778c9c5 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -27,6 +27,14 @@ impl ID { pub const fn len() -> usize { 2 } + /// Minimum possible Stream ID (u16::MIN) + pub const fn min() -> Self { + Self(u16::MIN) + } + /// Maximum possible Stream ID (u16::MAX) + pub const fn max() -> Self { + Self(u16::MAX) + } } /// length of the chunk @@ -79,6 +87,10 @@ impl<'a> Chunk<'a> { const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F; const FLAG_START_BITMASK: u8 = 0x80; const FLAG_END_BITMASK: u8 = 0x40; + /// Return the length of the header of a Chunk + pub const fn headers_len() -> usize { + ID::len() + ChunkLen::len() + Sequence::len() + } /// Returns the total length of the chunk, including headers pub fn len(&self) -> usize { ID::len() + ChunkLen::len() + Sequence::len() + self.data.len() @@ -181,3 +193,68 @@ impl Stream { } } } + +/// Track what has been sent and what has been ACK'd from a stream +#[derive(Debug)] +pub(crate) struct SendTracker { + queue: Vec>, + sent: Vec, + ackd: Vec, + chunk_started: bool, + is_datagram: bool, + next_sequence: Sequence, +} +impl SendTracker { + pub(crate) fn new(rand: &Random) -> Self { + Self { + queue: Vec::with_capacity(4), + sent: Vec::with_capacity(4), + ackd: Vec::with_capacity(4), + chunk_started: false, + is_datagram: false, + next_sequence: Sequence::new(rand), + } + } + /// Enqueue user data to be sent + pub(crate) fn enqueue(&mut self, data: Vec) { + self.queue.push(data); + self.sent.push(0); + self.ackd.push(0); + } + /// Write the user data to the buffer and mark it as sent + pub(crate) fn get(&mut self, out: &mut [u8]) -> usize { + let data = match self.queue.get(0) { + Some(data) => data, + None => return 0, + }; + let len = ::std::cmp::min(out.len(), data.len()); + out[..len].copy_from_slice(&data[self.sent[0]..len]); + self.sent[0] = self.sent[0] + len; + len + } + /// Mark the sent data as successfully received from the receiver + pub(crate) fn ack(&mut self, size: usize) { + todo!() + } + pub(crate) fn serialize(&mut self, id: ID, raw: &mut [u8]) -> usize { + let max_data_len = raw.len() - Chunk::headers_len(); + let data_len = ::std::cmp::min(max_data_len, self.queue[0].len()); + let flag_start = !self.chunk_started; + let flag_end = self.is_datagram && data_len == self.queue[0].len(); + let chunk = Chunk { + id, + flag_start, + flag_end, + sequence: self.next_sequence, + data: &self.queue[0][..data_len], + }; + self.next_sequence = Sequence( + self.next_sequence.0 + ::core::num::Wrapping(data_len as u32), + ); + if chunk.flag_end { + self.chunk_started = false; + } + chunk.serialize(raw); + data_len + } +} diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 14d712d..6c48f73 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -241,6 +241,14 @@ impl CipherSend { pub fn kind(&self) -> Kind { self.cipher.kind() } + /// Get the length of the nonce for this cipher + pub fn nonce_len(&self) -> NonceLen { + self.cipher.nonce_len() + } + /// Get the length of the nonce for this cipher + pub fn tag_len(&self) -> TagLen { + self.cipher.tag_len() + } } /// XChaCha20Poly1305 cipher diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 0bb41d0..ba9ec3a 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -16,6 +16,7 @@ use crate::{ }, dnssec, enc::{ + self, asym::{self, KeyID, PrivKey, PubKey}, hkdf::{self, Hkdf}, sym, Random, Secret, @@ -53,6 +54,7 @@ pub(crate) enum Work { DropHandshake(KeyID), Recv(RawUdp), UserSend((UserConnTracker, stream::ID, Vec)), + SendData(UserConnTracker), } /// Actual worker implementation. @@ -437,7 +439,42 @@ impl Worker { None => return, Some(conn) => conn, }; - conn.send(stream, data); + use connection::Enqueue; + match conn.enqueue(stream, data) { + Enqueue::Immediate => { + let _ = self + .queue_sender + .send(Work::SendData(tracker)) + .await; + } + Enqueue::TimerWait => {} + Enqueue::NoSuchStream => { + ::tracing::error!( + "Trying to send on unknown stream" + ); + } + } + } + Work::SendData(tracker) => { + let mut raw: Vec = Vec::with_capacity(1280); + raw.resize(raw.capacity(), 0); + let conn = match self.connections.get_mut(tracker) { + None => return, + Some(conn) => conn, + }; + let pkt = match conn.write_pkt(&mut raw) { + Ok(pkt) => pkt, + Err(enc::Error::NotEnoughData(0)) => return, + Err(e) => { + ::tracing::error!("Packet generation: {:?}", e); + return; + } + }; + let dest = conn.send_addr; + let src = UdpServer(self.sockets[0].local_addr().unwrap()); + let len = pkt.len(); + raw.truncate(len); + let _ = self.send_packet(raw, dest, src); } } } -- 2.47.2 From c3c8238730a8f344ba728892956896b444a5b3a2 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sun, 25 Jun 2023 19:22:40 +0200 Subject: [PATCH 4/8] Stream ROB: Reconstruct data TCP-like this was more convoluted than I thought. maybe someone will simplify this. Signed-off-by: Luca Fulchir --- flake.lock | 24 +- src/connection/handshake/tracker.rs | 8 +- src/connection/mod.rs | 71 ++++- src/connection/stream/errors.rs | 4 +- src/connection/stream/mod.rs | 71 ++++- src/connection/stream/rob.rs | 29 -- src/connection/stream/rob/mod.rs | 204 ++++++++++++ src/connection/stream/rob/tests.rs | 249 +++++++++++++++ src/inner/mod.rs | 26 ++ src/inner/worker.rs | 471 +++++++++++++++------------- 10 files changed, 878 insertions(+), 279 deletions(-) delete mode 100644 src/connection/stream/rob.rs create mode 100644 src/connection/stream/rob/mod.rs create mode 100644 src/connection/stream/rob/tests.rs diff --git a/flake.lock b/flake.lock index 8eb84e1..85c58c2 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1685518550, - "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=", + "lastModified": 1687171271, + "narHash": "sha256-BJlq+ozK2B1sJDQXS3tzJM5a+oVZmi1q0FlBK/Xqv7M=", "owner": "numtide", "repo": "flake-utils", - "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef", + "rev": "abfb11bd1aec8ced1c9bb9adfe68018230f4fb3c", "type": "github" }, "original": { @@ -38,11 +38,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1686921029, - "narHash": "sha256-J1bX9plPCFhTSh6E3TWn9XSxggBh/zDD4xigyaIQBy8=", + "lastModified": 1687555006, + "narHash": "sha256-GD2Kqb/DXQBRJcHqkM2qFZqbVenyO7Co/80JHRMg2U0=", "owner": "nixos", "repo": "nixpkgs", - "rev": "c7ff1b9b95620ce8728c0d7bd501c458e6da9e04", + "rev": "33223d479ffde3d05ac16c6dff04ae43cc27e577", "type": "github" }, "original": { @@ -54,11 +54,11 @@ }, "nixpkgs-unstable": { "locked": { - "lastModified": 1686960236, - "narHash": "sha256-AYCC9rXNLpUWzD9hm+askOfpliLEC9kwAo7ITJc4HIw=", + "lastModified": 1687502512, + "narHash": "sha256-dBL/01TayOSZYxtY4cMXuNCBk8UMLoqRZA+94xiFpJA=", "owner": "nixos", "repo": "nixpkgs", - "rev": "04af42f3b31dba0ef742d254456dc4c14eedac86", + "rev": "3ae20aa58a6c0d1ca95c9b11f59a2d12eebc511f", "type": "github" }, "original": { @@ -98,11 +98,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1687055571, - "narHash": "sha256-UvLoO6u5n9TzY80BpM4DaacxvyJl7u9mm9CA72d309g=", + "lastModified": 1687660699, + "narHash": "sha256-crI/CA/OJc778I5qJhwhhl8/PKKzc0D7vvVxOtjfvSo=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "2de557c780dcb127128ae987fca9d6c2b0d7dc0f", + "rev": "b3bd1d49f1ae609c1d68a66bba7a95a9a4256031", "type": "github" }, "original": { diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index c40158e..fbedc34 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -37,7 +37,7 @@ pub(crate) struct Client { pub(crate) service_id: ServiceID, pub(crate) service_conn_id: IDRecv, pub(crate) connection: Connection, - pub(crate) timeout: Option<::tokio::task::JoinHandle<()>>, + pub(crate) timeout: Option<::tokio::time::Instant>, pub(crate) answer: oneshot::Sender, pub(crate) srv_key_id: KeyID, } @@ -150,6 +150,8 @@ pub(crate) struct ClientConnectInfo { pub(crate) service_connection_id: IDRecv, /// Parsed handshake packet pub(crate) handshake: Handshake, + /// Old timeout for the handshake completion + pub(crate) old_timeout: ::tokio::time::Instant, /// Connection pub(crate) connection: Connection, /// where to wake up the waiting client @@ -374,13 +376,11 @@ impl Tracker { } let hshake = self.hshake_cli.remove(resp.client_key_id).unwrap(); - if let Some(timeout) = hshake.timeout { - timeout.abort(); - } return Ok(Action::ClientConnect(ClientConnectInfo { service_id: hshake.service_id, service_connection_id: hshake.service_conn_id, handshake, + old_timeout: hshake.timeout.unwrap(), connection: hshake.connection, answer: hshake.answer, srv_key_id: hshake.srv_key_id, diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 337b48f..dbe62e1 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -14,7 +14,7 @@ use ::std::{ pub use crate::connection::{handshake::Handshake, packet::Packet}; use crate::{ - connection::socket::{UdpClient, UdpServer}, + connection::socket::UdpClient, dnssec, enc::{ self, @@ -26,6 +26,16 @@ use crate::{ inner::{worker, ThreadTracker}, }; +/// Connaction errors +#[derive(::thiserror::Error, Debug, Copy, Clone)] +pub(crate) enum Error { + /// Can't decrypt packet + #[error("Decrypt error: {0}")] + Decrypt(#[from] crate::enc::Error), + #[error("Chunk parsing: {0}")] + Parse(#[from] stream::Error), +} + /// Fenrir Connection ID /// /// 0 is special as it represents the handshake @@ -192,7 +202,7 @@ enum TimerKind { pub(crate) enum Enqueue { NoSuchStream, TimerWait, - Immediate, + Immediate(::tokio::time::Instant), } /// A single connection and its data @@ -215,6 +225,8 @@ pub(crate) struct Connection { /// send queue for each Stream send_queue: BTreeMap, last_stream_sent: stream::ID, + /// receive queue for each Stream + recv_queue: BTreeMap, } impl Connection { @@ -246,12 +258,46 @@ impl Connection { hkdf, cipher_recv, cipher_send, - mtu: 1280, + mtu: 1200, next_timer: TimerKind::None, send_queue: BTreeMap::new(), last_stream_sent: stream::ID(0), + recv_queue: BTreeMap::new(), } } + pub(crate) fn recv(&mut self, mut udp: crate::RawUdp) -> Result<(), Error> { + let mut data = &mut udp.data[ID::len()..]; + let aad = enc::sym::AAD(&[]); + self.cipher_recv.decrypt(aad, &mut data)?; + let mut bytes_parsed = 0; + let mut chunks = Vec::with_capacity(2); + loop { + let chunk = match stream::Chunk::deserialize(&data[bytes_parsed..]) + { + Ok(chunk) => chunk, + Err(e) => { + return Err(e.into()); + } + }; + bytes_parsed = bytes_parsed + chunk.len(); + chunks.push(chunk); + if bytes_parsed == data.len() { + break; + } + } + for chunk in chunks.into_iter() { + let stream = match self.recv_queue.get_mut(&chunk.id) { + Some(stream) => stream, + None => { + ::tracing::debug!("Ignoring chunk for unknown stream::ID"); + continue; + } + }; + stream.recv(chunk); + } + // FIXME: report if we need to return data to the user + Ok(()) + } pub(crate) fn enqueue( &mut self, stream: stream::ID, @@ -262,11 +308,13 @@ impl Connection { Some(stream) => stream, }; stream.enqueue(data); + let instant; let ret; self.next_timer = match self.next_timer { TimerKind::None | TimerKind::Keepalive(_) => { - ret = Enqueue::Immediate; - TimerKind::SendData(::tokio::time::Instant::now()) + instant = ::tokio::time::Instant::now(); + ret = Enqueue::Immediate(instant); + TimerKind::SendData(instant) } TimerKind::SendData(old_timer) => { // There already is some data to be sent @@ -282,7 +330,7 @@ impl Connection { &mut self, raw: &'a mut [u8], ) -> Result<&'a [u8], enc::Error> { - assert!(raw.len() >= 1200, "I should have at least 1200 MTU"); + assert!(raw.len() >= self.mtu, "I should have at least 1200 MTU"); if self.send_queue.len() == 0 { return Err(enc::Error::NotEnoughData(0)); } @@ -378,6 +426,17 @@ impl ConnList { ret.connections.resize_with(INITIAL_CAP, || None); ret } + pub fn get_id_mut(&mut self, id: ID) -> Option<&mut Connection> { + let conn_id = match id { + ID::Handshake => { + return None; + } + ID::ID(conn_id) => conn_id, + }; + let id_in_thread: usize = + (conn_id.get() / (self.thread_id.total as u64)) as usize; + (&mut self.connections[id_in_thread]).into() + } pub fn get_mut( &mut self, tracker: UserConnTracker, diff --git a/src/connection/stream/errors.rs b/src/connection/stream/errors.rs index 133d976..07dcbd8 100644 --- a/src/connection/stream/errors.rs +++ b/src/connection/stream/errors.rs @@ -1,10 +1,12 @@ //! Errors while parsing streams - /// Crypto errors #[derive(::thiserror::Error, Debug, Copy, Clone)] pub enum Error { /// Error while parsing key material #[error("Not enough data for stream chunk: {0}")] NotEnoughData(usize), + /// Sequence outside of the window + #[error("Sequence out of the sliding window")] + OutOfWindow, } diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 778c9c5..96971cc 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -48,6 +48,30 @@ impl ChunkLen { } } +//TODO: make pub? +#[derive(Debug, Copy, Clone)] +pub(crate) struct SequenceStart(pub(crate) Sequence); +impl SequenceStart { + pub(crate) fn plus_u32(&self, other: u32) -> Sequence { + self.0.plus_u32(other) + } + pub(crate) fn offset(&self, seq: Sequence) -> usize { + if self.0 .0 <= seq.0 { + (seq.0 - self.0 .0).0 as usize + } else { + (seq.0 + (Sequence::max().0 - self.0 .0)).0 as usize + } + } +} +// SequenceEnd is INCLUSIVE +#[derive(Debug, Copy, Clone)] +pub(crate) struct SequenceEnd(pub(crate) Sequence); +impl SequenceEnd { + pub(crate) fn plus_u32(&self, other: u32) -> Sequence { + self.0.plus_u32(other) + } +} + /// Sequence number to rebuild the stream correctly #[derive(Debug, Copy, Clone)] pub struct Sequence(pub ::core::num::Wrapping); @@ -56,14 +80,52 @@ impl Sequence { const SEQ_NOFLAG: u32 = 0x3FFFFFFF; /// return a new sequence number, starting at random pub fn new(rand: &Random) -> Self { - let seq: u32 = 0; - rand.fill(&mut seq.to_le_bytes()); + let mut raw_seq: [u8; 4] = [0; 4]; + rand.fill(&mut raw_seq); + let seq = u32::from_le_bytes(raw_seq); Self(::core::num::Wrapping(seq & Self::SEQ_NOFLAG)) } /// Length of the serialized field pub const fn len() -> usize { 4 } + /// Maximum possible sequence + pub const fn max() -> Self { + Self(::core::num::Wrapping(Self::SEQ_NOFLAG)) + } + pub(crate) fn is_between( + &self, + start: SequenceStart, + end: SequenceEnd, + ) -> bool { + if start.0 .0 < end.0 .0 { + start.0 .0 <= self.0 && self.0 <= end.0 .0 + } else { + start.0 .0 <= self.0 || self.0 <= end.0 .0 + } + } + pub(crate) fn remaining_window(&self, end: SequenceEnd) -> u32 { + if self.0 <= end.0 .0 { + (end.0 .0 .0 - self.0 .0) + 1 + } else { + end.0 .0 .0 + 1 + (Self::max().0 - self.0).0 + } + } + pub(crate) fn plus_u32(self, other: u32) -> Self { + Self(::core::num::Wrapping( + (self.0 .0 + other) & Self::SEQ_NOFLAG, + )) + } +} + +impl ::core::ops::Add for Sequence { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self(::core::num::Wrapping( + (self.0 + other.0).0 & Self::SEQ_NOFLAG, + )) + } } /// Chunk of data representing a stream @@ -192,6 +254,11 @@ impl Stream { data: Tracker::new(kind, rand), } } + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { + match &mut self.data { + Tracker::ROB(tracker) => tracker.recv(chunk), + } + } } /// Track what has been sent and what has been ACK'd from a stream diff --git a/src/connection/stream/rob.rs b/src/connection/stream/rob.rs deleted file mode 100644 index 5d28f59..0000000 --- a/src/connection/stream/rob.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Implementation of the Reliable, Ordered, Bytestream transmission model -//! AKA: TCP-like - -use crate::{ - connection::stream::{Chunk, Error, Sequence}, - enc::Random, -}; - -/// Reliable, Ordered, Bytestream stream tracker -/// AKA: TCP-like -#[derive(Debug, Clone)] -pub(crate) struct ReliableOrderedBytestream { - window_start: Sequence, - window_len: usize, - data: Vec, -} - -impl ReliableOrderedBytestream { - pub(crate) fn new(rand: &Random) -> Self { - Self { - window_start: Sequence::new(rand), - window_len: 1048576, // 1MB. should be enough for anybody. (lol) - data: Vec::new(), - } - } - pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { - todo!() - } -} diff --git a/src/connection/stream/rob/mod.rs b/src/connection/stream/rob/mod.rs new file mode 100644 index 0000000..1bfd159 --- /dev/null +++ b/src/connection/stream/rob/mod.rs @@ -0,0 +1,204 @@ +//! Implementation of the Reliable, Ordered, Bytestream transmission model +//! AKA: TCP-like + +use crate::{ + connection::stream::{Chunk, Error, Sequence, SequenceEnd, SequenceStart}, + enc::Random, +}; + +#[cfg(test)] +mod tests; + +/// Reliable, Ordered, Bytestream stream tracker +/// AKA: TCP-like +#[derive(Debug, Clone)] +pub(crate) struct ReliableOrderedBytestream { + pub(crate) window_start: SequenceStart, + window_end: SequenceEnd, + pivot: u32, + data: Vec, + missing: Vec<(Sequence, Sequence)>, +} + +impl ReliableOrderedBytestream { + pub(crate) fn new(rand: &Random) -> Self { + let window_len = 1048576; // 1MB. should be enough for anybody. (lol) + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + missing: [(window_start.0, window_end.0)].to_vec(), + } + } + pub(crate) fn with_window_size(rand: &Random, size: u32) -> Self { + assert!( + size < Sequence::max().0 .0, + "Max window size is {}", + Sequence::max().0 .0 + ); + let window_len = size; // 1MB. should be enough for anybody. (lol) + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + missing: [(window_start.0, window_end.0)].to_vec(), + } + } + pub(crate) fn window_size(&self) -> u32 { + self.data.len() as u32 + } + pub(crate) fn get(&mut self) -> Vec { + if self.missing.len() == 0 { + let (first, second) = self.data.split_at(self.pivot as usize); + let mut ret = Vec::with_capacity(self.data.len()); + ret.extend_from_slice(first); + ret.extend_from_slice(second); + self.window_start = + SequenceStart(self.window_start.plus_u32(ret.len() as u32)); + self.window_end = + SequenceEnd(self.window_end.plus_u32(ret.len() as u32)); + self.data.clear(); + return ret; + } + let data_len = self.window_start.offset(self.missing[0].0); + let last_missing_idx = self.missing.len() - 1; + let mut last_missing = &mut self.missing[last_missing_idx]; + last_missing.1 = last_missing.1.plus_u32(data_len as u32); + self.window_start = + SequenceStart(self.window_start.plus_u32(data_len as u32)); + self.window_end = + SequenceEnd(self.window_end.plus_u32(data_len as u32)); + + let mut ret = Vec::with_capacity(data_len); + let (first, second) = self.data[..].split_at(self.pivot as usize); + let first_len = ::core::cmp::min(data_len, first.len()); + let second_len = data_len - first_len; + + ret.extend_from_slice(&first[..first_len]); + ret.extend_from_slice(&second[..second_len]); + + self.pivot = + ((self.pivot as usize + data_len) % self.data.len()) as u32; + ret + } + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { + if !chunk + .sequence + .is_between(self.window_start, self.window_end) + { + return Err(Error::OutOfWindow); + } + // make sure we consider only the bytes inside the sliding window + let maxlen = ::std::cmp::min( + chunk.sequence.remaining_window(self.window_end) as usize, + chunk.data.len(), + ); + if maxlen == 0 { + // or empty chunk, but we don't care + return Err(Error::OutOfWindow); + } + // translate Sequences to offsets in self.data + let data = &chunk.data[..maxlen]; + let offset = self.window_start.offset(chunk.sequence); + let offset_end = offset + chunk.data.len() - 1; + + // Find the chunks we are missing that we can copy, + // and fix the missing tracker + let mut copy_ranges = Vec::new(); + let mut to_delete = Vec::new(); + let mut to_add = Vec::new(); + // note: te included ranges are (INCLUSIVE, INCLUSIVE) + for (idx, el) in self.missing.iter_mut().enumerate() { + let missing_from = self.window_start.offset(el.0); + if missing_from > offset_end { + break; + } + let missing_to = self.window_start.offset(el.1); + if missing_to < offset { + continue; + } + if missing_from >= offset && missing_from <= offset_end { + if missing_to <= offset_end { + // [.....chunk.....] + // [..missing..] + to_delete.push(idx); + copy_ranges.push((missing_from, missing_to)); + } else { + // [....chunk....] + // [...missing...] + copy_ranges.push((missing_from, offset_end)); + el.0 = + el.0.plus_u32(((offset_end - missing_from) + 1) as u32); + } + } else if missing_from < offset { + if missing_to > offset_end { + // [..chunk..] + // [....missing....] + // chunk is in the middle of a missing fragment + to_add.push(( + el.0.plus_u32(((offset_end - missing_from) + 1) as u32), + el.1, + )); + el.1 = el.0.plus_u32(((offset - missing_from) - 1) as u32); + copy_ranges.push((offset, offset_end)); + } else if offset <= missing_to { + // [....chunk....] + // [...missing...] + // chunk + copy_ranges.push((offset, (missing_to - 0))); + el.1 = + el.0.plus_u32(((offset_end - missing_from) - 1) as u32); + } + } + } + self.missing.append(&mut to_add); + self.missing + .sort_by(|(from_a, _), (from_b, _)| from_a.0 .0.cmp(&from_b.0 .0)); + { + let mut deleted = 0; + for idx in to_delete.into_iter() { + self.missing.remove(idx + deleted); + deleted = deleted + 1; + } + } + // copy only the missing data + let (first, second) = self.data[..].split_at_mut(self.pivot as usize); + for (from, to) in copy_ranges.into_iter() { + let to = to + 1; + if from <= first.len() { + let first_from = from; + let first_to = ::core::cmp::min(first.len(), to); + let data_first_from = from - offset; + let data_first_to = first_to - offset; + first[first_from..first_to] + .copy_from_slice(&data[data_first_from..data_first_to]); + + let second_to = to - first_to; + let data_second_to = data_first_to + second_to; + second[..second_to] + .copy_from_slice(&data[data_first_to..data_second_to]); + } else { + let second_from = from - first.len(); + let second_to = to - first.len(); + let data_from = from - offset; + let data_to = to - offset; + second[second_from..second_to] + .copy_from_slice(&data[data_from..data_to]); + } + } + + Ok(()) + } +} diff --git a/src/connection/stream/rob/tests.rs b/src/connection/stream/rob/tests.rs new file mode 100644 index 0000000..20cb508 --- /dev/null +++ b/src/connection/stream/rob/tests.rs @@ -0,0 +1,249 @@ +use crate::{ + connection::stream::{self, rob::*, Chunk}, + enc::Random, +}; + +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_sequential() { + let rand = Random::new(); + let mut rob = ReliableOrderedBytestream::with_window_size(&rand, 1048576); + + let mut data = Vec::with_capacity(1024); + data.resize(data.capacity(), 0); + rand.fill(&mut data[..]); + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..512], + }; + let got = rob.get(); + assert!(&got[..] == &[], "rob: got data?"); + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..512] == &got[..], + "ROB1: DIFF: {:?} {:?}", + &data[..512].len(), + &got[..].len() + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: true, + sequence: start.plus_u32(512), + data: &data[512..], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[512..] == &got[..], + "ROB2: DIFF: {:?} {:?}", + &data[512..].len(), + &got[..].len() + ); +} + +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_retransmit() { + let rand = Random::new(); + let max_window: usize = 100; + let mut rob = + ReliableOrderedBytestream::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..60], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(40), + data: &data[40..60], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(80), + data: &data[80..], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..90], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(max_window as u32), + data: &data[max_window..], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: true, + sequence: start.plus_u32(90), + data: &data[90..max_window], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..max_window] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..max_window], + &got[..], + ); +} +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_rolling() { + let rand = Random::new(); + let max_window: usize = 100; + let mut rob = + ReliableOrderedBytestream::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..100], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..40] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..40], + &got[..], + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(40), + data: &data[40..], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[40..] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[40..], + &got[..], + ); +} +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_rolling_second_case() { + let rand = Random::new(); + let max_window: usize = 100; + let mut rob = + ReliableOrderedBytestream::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..100], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..40] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..40], + &got[..], + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(40), + data: &data[40..100], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(100), + data: &data[100..], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[40..] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[40..], + &got[..], + ); +} diff --git a/src/inner/mod.rs b/src/inner/mod.rs index e23d614..6102fde 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -66,6 +66,32 @@ impl Timers { } } } + pub(crate) fn add( + &mut self, + duration: ::tokio::time::Duration, + work: Work, + ) -> ::tokio::time::Instant { + // the returned time is the key in the map. + // Make sure it is unique. + // + // We can be pretty sure we won't do a lot of stuff + // in a single nanosecond, so if we hit a time that is already present + // just add a nanosecond and retry + let mut time = ::tokio::time::Instant::now() + duration; + let mut work = work; + loop { + if let Some(old_val) = self.times.insert(time, work) { + work = self.times.insert(time, old_val).unwrap(); + time = time + ::std::time::Duration::from_nanos(1); + } else { + break; + } + } + time + } + pub(crate) fn remove(&mut self, time: ::tokio::time::Instant) { + let _ = self.times.remove(&time); + } /// Get all the work from now up until now + SLEEP_RESOLUTION pub(crate) fn get_work(&mut self) -> Vec { let now: ::tokio::time::Instant = ::std::time::Instant::now().into(); diff --git a/src/inner/worker.rs b/src/inner/worker.rs index ba9ec3a..725578d 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -54,7 +54,7 @@ pub(crate) enum Work { DropHandshake(KeyID), Recv(RawUdp), UserSend((UserConnTracker, stream::ID, Vec)), - SendData(UserConnTracker), + SendData((UserConnTracker, ::tokio::time::Instant)), } /// Actual worker implementation. @@ -317,6 +317,8 @@ impl Worker { connection::Role::Client, &self.rand, ); + let dest = UdpClient(addr.as_sockaddr().unwrap()); + conn.send_addr = dest; let auth_recv_id = self.connections.reserve_first(); let service_conn_id = self.connections.reserve_first(); @@ -407,15 +409,13 @@ impl Worker { // send always from the first socket // FIXME: select based on routing table let sender = self.sockets[0].local_addr().unwrap(); - let dest = UdpClient(addr.as_sockaddr().unwrap()); // start the timeout right before sending the packet - hshake.timeout = Some(::tokio::task::spawn_local( - Self::handshake_timeout( - self.queue_timeouts_send.clone(), - client_key_id, - ), - )); + let time_drop = self.work_timers.add( + ::tokio::time::Duration::from_secs(10), + Work::DropHandshake(client_key_id), + ); + hshake.timeout = Some(time_drop); // send packet self.send_packet(raw, dest, UdpServer(sender)).await; @@ -441,10 +441,10 @@ impl Worker { }; use connection::Enqueue; match conn.enqueue(stream, data) { - Enqueue::Immediate => { + Enqueue::Immediate(instant) => { let _ = self .queue_sender - .send(Work::SendData(tracker)) + .send(Work::SendData((tracker, instant))) .await; } Enqueue::TimerWait => {} @@ -455,8 +455,23 @@ impl Worker { } } } - Work::SendData(tracker) => { - let mut raw: Vec = Vec::with_capacity(1280); + Work::SendData((tracker, instant)) => { + // make sure we don't process events before they are + // actually needed. + // This is basically busy waiting with extra steps, + // but we don't want to spawn lots of timers and + // we don't really have a fine-grained sleep that is + // multiplatform + let now = ::tokio::time::Instant::now(); + if instant <= now { + let _ = self + .queue_sender + .send(Work::SendData((tracker, instant))) + .await; + return; + } + + let mut raw: Vec = Vec::with_capacity(1200); raw.resize(raw.capacity(), 0); let conn = match self.connections.get_mut(tracker) { None => return, @@ -479,13 +494,6 @@ impl Worker { } } } - async fn handshake_timeout( - timeout_queue: mpsc::UnboundedSender, - key_id: KeyID, - ) { - ::tokio::time::sleep(::std::time::Duration::from_secs(10)).await; - let _ = timeout_queue.send(Work::DropHandshake(key_id)); - } /// Read and do stuff with the raw udp packet async fn recv(&mut self, mut udp: RawUdp) { if udp.packet.id.is_handshake() { @@ -508,224 +516,237 @@ impl Worker { return; } }; - match action { - handshake::Action::AuthNeeded(authinfo) => { - let req; - if let handshake::Data::DirSync(DirSync::Req(r)) = - authinfo.handshake.data - { - req = r; - } else { - ::tracing::error!("AuthInfo on non DS::Req"); + self.recv_handshake(udp, action).await; + } else { + self.recv_packet(udp); + } + } + /// Receive a non-handshake packet + fn recv_packet(&mut self, udp: RawUdp) { + let conn = match self.connections.get_id_mut(udp.packet.id) { + None => return, + Some(conn) => conn, + }; + if let Err(e) = conn.recv(udp) { + ::tracing::trace!("Conn Recv: {:?}", e.to_string()); + } + } + /// Receive an handshake packet + async fn recv_handshake(&mut self, udp: RawUdp, action: handshake::Action) { + match action { + handshake::Action::AuthNeeded(authinfo) => { + let req; + if let handshake::Data::DirSync(DirSync::Req(r)) = + authinfo.handshake.data + { + req = r; + } else { + ::tracing::error!("AuthInfo on non DS::Req"); + return; + } + let req_data = match req.data { + dirsync::req::State::ClearText(req_data) => req_data, + _ => { + ::tracing::error!("AuthNeeded: expected ClearText"); + assert!(false, "AuthNeeded: unreachable"); return; } - let req_data = match req.data { - dirsync::req::State::ClearText(req_data) => req_data, - _ => { - ::tracing::error!("AuthNeeded: expected ClearText"); - assert!(false, "AuthNeeded: unreachable"); - return; - } - }; - // FIXME: This part can take a while, - // we should just spawn it probably - let maybe_auth_check = { - match &self.token_check { - None => { - if req_data.auth.user == auth::USERID_ANONYMOUS - { - Ok(true) - } else { - Ok(false) - } - } - Some(token_check) => { - let tk_check = token_check.lock().await; - tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await + }; + // FIXME: This part can take a while, + // we should just spawn it probably + let maybe_auth_check = { + match &self.token_check { + None => { + if req_data.auth.user == auth::USERID_ANONYMOUS { + Ok(true) + } else { + Ok(false) } } - }; - let is_authenticated = match maybe_auth_check { - Ok(is_authenticated) => is_authenticated, - Err(_) => { - ::tracing::error!("error in token auth"); - // TODO: retry? - return; + Some(token_check) => { + let tk_check = token_check.lock().await; + tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await } - }; - if !is_authenticated { - ::tracing::warn!( - "Wrong authentication for user {:?}", - req_data.auth.user - ); - // TODO: error response + } + }; + let is_authenticated = match maybe_auth_check { + Ok(is_authenticated) => is_authenticated, + Err(_) => { + ::tracing::error!("error in token auth"); + // TODO: retry? return; } - // Client has correctly authenticated - // TODO: contact the service, get the key and - // connection ID - let srv_conn_id = connection::ID::new_rand(&self.rand); - let srv_secret = Secret::new_rand(&self.rand); - let head_len = req.cipher.nonce_len(); - let tag_len = req.cipher.tag_len(); + }; + if !is_authenticated { + ::tracing::warn!( + "Wrong authentication for user {:?}", + req_data.auth.user + ); + // TODO: error response + return; + } + // Client has correctly authenticated + // TODO: contact the service, get the key and + // connection ID + let srv_conn_id = connection::ID::new_rand(&self.rand); + let srv_secret = Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); - let mut auth_conn = Connection::new( - authinfo.hkdf, - req.cipher, - connection::Role::Server, + let mut auth_conn = Connection::new( + authinfo.hkdf, + req.cipher, + connection::Role::Server, + &self.rand, + ); + auth_conn.id_send = IDSend(req_data.id); + auth_conn.send_addr = udp.src; + // track connection + let auth_id_recv = self.connections.reserve_first(); + auth_conn.id_recv = auth_id_recv; + + let resp_data = dirsync::resp::Data { + client_nonce: req_data.nonce, + id: auth_conn.id_recv.0, + service_connection_id: srv_conn_id, + service_key: srv_secret, + }; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + + let resp = dirsync::resp::Resp { + client_key_id: req_data.client_key_id, + data: dirsync::resp::State::ClearText(resp_data), + }; + let encrypt_from = + connection::ID::len() + resp.encrypted_offset(); + let encrypt_until = + encrypt_from + resp.encrypted_length(head_len, tag_len); + let resp_handshake = Handshake::new(handshake::Data::DirSync( + DirSync::Resp(resp), + )); + let packet = Packet { + id: connection::ID::new_handshake(), + data: packet::Data::Handshake(resp_handshake), + }; + let tot_len = packet.len(head_len, tag_len); + let mut raw_out = Vec::::with_capacity(tot_len); + raw_out.resize(tot_len, 0); + packet.serialize(head_len, tag_len, &mut raw_out); + + if let Err(e) = auth_conn + .cipher_send + .encrypt(aad, &mut raw_out[encrypt_from..encrypt_until]) + { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + self.send_packet(raw_out, udp.src, udp.dst).await; + } + handshake::Action::ClientConnect(cci) => { + self.work_timers.remove(cci.old_timeout); + let ds_resp; + if let handshake::Data::DirSync(DirSync::Resp(resp)) = + cci.handshake.data + { + ds_resp = resp; + } else { + ::tracing::error!("ClientConnect on non DS::Resp"); + return; + } + // track connection + let resp_data; + if let dirsync::resp::State::ClearText(r_data) = ds_resp.data { + resp_data = r_data; + } else { + ::tracing::error!( + "ClientConnect on non DS::Resp::ClearText" + ); + unreachable!(); + } + let auth_id_send = IDSend(resp_data.id); + let mut conn = cci.connection; + conn.id_send = auth_id_send; + let id_recv = conn.id_recv; + let cipher = conn.cipher_recv.kind(); + // track the connection to the authentication server + let track_auth_conn = match self.connections.track(conn) { + Ok(track_auth_conn) => track_auth_conn, + Err(_) => { + ::tracing::error!( + "Could not track new auth srv connection" + ); + self.connections.remove(id_recv); + // FIXME: proper connection closing + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + let authsrv_conn = AuthSrvConn(connection::Conn { + queue: self.queue_sender.clone(), + conn: track_auth_conn, + }); + let mut service_conn = None; + if cci.service_id != auth::SERVICEID_AUTH { + // create and track the connection to the service + // SECURITY: xor with secrets + //FIXME: the Secret should be XORed with the client + // stored secret (if any) + let hkdf = Hkdf::new( + hkdf::Kind::Sha3, + cci.service_id.as_bytes(), + resp_data.service_key, + ); + let mut service_connection = Connection::new( + hkdf, + cipher, + connection::Role::Client, &self.rand, ); - auth_conn.id_send = IDSend(req_data.id); - // track connection - let auth_id_recv = self.connections.reserve_first(); - auth_conn.id_recv = auth_id_recv; - - let resp_data = dirsync::resp::Data { - client_nonce: req_data.nonce, - id: auth_conn.id_recv.0, - service_connection_id: srv_conn_id, - service_key: srv_secret, - }; - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - - let resp = dirsync::resp::Resp { - client_key_id: req_data.client_key_id, - data: dirsync::resp::State::ClearText(resp_data), - }; - let encrypt_from = - connection::ID::len() + resp.encrypted_offset(); - let encrypt_until = - encrypt_from + resp.encrypted_length(head_len, tag_len); - let resp_handshake = Handshake::new( - handshake::Data::DirSync(DirSync::Resp(resp)), - ); - let packet = Packet { - id: connection::ID::new_handshake(), - data: packet::Data::Handshake(resp_handshake), - }; - let tot_len = packet.len(head_len, tag_len); - let mut raw_out = Vec::::with_capacity(tot_len); - raw_out.resize(tot_len, 0); - packet.serialize(head_len, tag_len, &mut raw_out); - - if let Err(e) = auth_conn - .cipher_send - .encrypt(aad, &mut raw_out[encrypt_from..encrypt_until]) - { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } - self.send_packet(raw_out, udp.src, udp.dst).await; - } - handshake::Action::ClientConnect(cci) => { - let ds_resp; - if let handshake::Data::DirSync(DirSync::Resp(resp)) = - cci.handshake.data - { - ds_resp = resp; - } else { - ::tracing::error!("ClientConnect on non DS::Resp"); - return; - } - // track connection - let resp_data; - if let dirsync::resp::State::ClearText(r_data) = - ds_resp.data - { - resp_data = r_data; - } else { - ::tracing::error!( - "ClientConnect on non DS::Resp::ClearText" - ); - unreachable!(); - } - let auth_id_send = IDSend(resp_data.id); - let mut conn = cci.connection; - conn.id_send = auth_id_send; - let id_recv = conn.id_recv; - let cipher = conn.cipher_recv.kind(); - // track the connection to the authentication server - let track_auth_conn = match self.connections.track(conn) { - Ok(track_auth_conn) => track_auth_conn, - Err(_) => { - ::tracing::error!( - "Could not track new auth srv connection" - ); - self.connections.remove(id_recv); - // FIXME: proper connection closing - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } - }; - let authsrv_conn = AuthSrvConn(connection::Conn { + service_connection.id_recv = cci.service_connection_id; + service_connection.id_send = + IDSend(resp_data.service_connection_id); + let track_serv_conn = + match self.connections.track(service_connection) { + Ok(track_serv_conn) => track_serv_conn, + Err(_) => { + ::tracing::error!( + "Could not track new service connection" + ); + self.connections + .remove(cci.service_connection_id); + // FIXME: proper connection closing + // FIXME: drop auth srv connection if we just + // established it + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + service_conn = Some(ServiceConn(connection::Conn { queue: self.queue_sender.clone(), - conn: track_auth_conn, - }); - let mut service_conn = None; - if cci.service_id != auth::SERVICEID_AUTH { - // create and track the connection to the service - // SECURITY: xor with secrets - //FIXME: the Secret should be XORed with the client - // stored secret (if any) - let hkdf = Hkdf::new( - hkdf::Kind::Sha3, - cci.service_id.as_bytes(), - resp_data.service_key, - ); - let mut service_connection = Connection::new( - hkdf, - cipher, - connection::Role::Client, - &self.rand, - ); - service_connection.id_recv = cci.service_connection_id; - service_connection.id_send = - IDSend(resp_data.service_connection_id); - let track_serv_conn = - match self.connections.track(service_connection) { - Ok(track_serv_conn) => track_serv_conn, - Err(_) => { - ::tracing::error!( - "Could not track new service connection" - ); - self.connections - .remove(cci.service_connection_id); - // FIXME: proper connection closing - // FIXME: drop auth srv connection if we just - // established it - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking - .into(), - )); - return; - } - }; - service_conn = Some(ServiceConn(connection::Conn { - queue: self.queue_sender.clone(), - conn: track_serv_conn, - })); - } - let _ = - cci.answer.send(Ok(handshake::tracker::ConnectOk { - auth_key_id: cci.srv_key_id, - auth_id_send, - authsrv_conn, - service_conn, - })); + conn: track_serv_conn, + })); } - handshake::Action::Nothing => {} - }; - } + let _ = cci.answer.send(Ok(handshake::tracker::ConnectOk { + auth_key_id: cci.srv_key_id, + auth_id_send, + authsrv_conn, + service_conn, + })); + } + handshake::Action::Nothing => {} + }; } async fn send_packet( &self, -- 2.47.2 From 9ca4123c3723a144fad22719fde3a200a26fff67 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 28 Jun 2023 18:49:33 +0200 Subject: [PATCH 5/8] Review conn tracking, data reporting Signed-off-by: Luca Fulchir --- TODO | 3 ++ src/connection/mod.rs | 70 +++++++++++++++++++++----------- src/connection/stream/mod.rs | 22 +++++++++- src/connection/stream/rob/mod.rs | 28 ++++++++----- src/inner/worker.rs | 43 +++++++++++++------- src/lib.rs | 13 +++++- 6 files changed, 127 insertions(+), 52 deletions(-) diff --git a/TODO b/TODO index 9531367..85a5331 100644 --- a/TODO +++ b/TODO @@ -1 +1,4 @@ * Wrapping for everything that wraps (sigh) +* track user connection (add u64 from user) +* split API in LocalThread and ThreadSafe + * split send/recv API in Centralized, Decentralized diff --git a/src/connection/mod.rs b/src/connection/mod.rs index dbe62e1..07af86e 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -14,7 +14,7 @@ use ::std::{ pub use crate::connection::{handshake::Handshake, packet::Packet}; use crate::{ - connection::socket::UdpClient, + connection::{socket::UdpClient, stream::StreamData}, dnssec, enc::{ self, @@ -141,28 +141,32 @@ impl ProtocolVersion { } } +/// Unique tracker of connections #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] -pub(crate) struct UserConnTracker(Wrapping); -impl UserConnTracker { - fn advance(&mut self) -> Self { +pub struct ConnTracker(Wrapping); +impl ConnTracker { + pub(crate) fn new(start: u16) -> Self { + Self(Wrapping(start as u64)) + } + pub(crate) fn advance(&mut self, amount: u16) -> Self { let old = self.0; - self.0 = self.0 + Wrapping(1); - UserConnTracker(old) + self.0 = self.0 + Wrapping(amount as u64); + ConnTracker(old) } } /// Connection to an Authentication Server #[derive(Debug)] -pub struct AuthSrvConn(pub(crate) Conn); +pub struct AuthSrvConn(pub Conn); /// Connection to a service #[derive(Debug)] -pub struct ServiceConn(pub(crate) Conn); +pub struct ServiceConn(pub Conn); /// The connection, as seen from a user of libFenrir #[derive(Debug)] pub struct Conn { pub(crate) queue: ::async_channel::Sender, - pub(crate) conn: UserConnTracker, + pub(crate) fast: ConnTracker, } impl Conn { @@ -172,9 +176,13 @@ impl Conn { use crate::inner::worker::Work; let _ = self .queue - .send(Work::UserSend((self.conn, stream, data))) + .send(Work::UserSend((self.tracker(), stream, data))) .await; } + /// Get the library tracking id + pub fn tracker(&self) -> ConnTracker { + self.fast + } } /// Role: track the connection direction @@ -205,6 +213,10 @@ pub(crate) enum Enqueue { Immediate(::tokio::time::Instant), } +/// Connection tracking id. Set by the user +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)] +pub struct UserTracker(pub ::core::num::NonZeroU64); + /// A single connection and its data #[derive(Debug)] pub(crate) struct Connection { @@ -212,6 +224,9 @@ pub(crate) struct Connection { pub(crate) id_recv: IDRecv, /// Sending Conn ID pub(crate) id_send: IDSend, + /// User-managed id to track this connection + /// the user can set this to better track this connection + pub(crate) user_tracker: Option, /// Sending address pub(crate) send_addr: UdpClient, /// The main hkdf used for all secrets in this connection @@ -251,6 +266,8 @@ impl Connection { Self { id_recv: IDRecv(ID::Handshake), id_send: IDSend(ID::Handshake), + user_tracker: None, + // will be overwritten send_addr: UdpClient(SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 31337, @@ -265,7 +282,10 @@ impl Connection { recv_queue: BTreeMap::new(), } } - pub(crate) fn recv(&mut self, mut udp: crate::RawUdp) -> Result<(), Error> { + pub(crate) fn recv( + &mut self, + mut udp: crate::RawUdp, + ) -> Result { let mut data = &mut udp.data[ID::len()..]; let aad = enc::sym::AAD(&[]); self.cipher_recv.decrypt(aad, &mut data)?; @@ -285,18 +305,22 @@ impl Connection { break; } } + let mut data_ready = StreamData::NotReady; for chunk in chunks.into_iter() { - let stream = match self.recv_queue.get_mut(&chunk.id) { + let stream_id = chunk.id; + let stream = match self.recv_queue.get_mut(&stream_id) { Some(stream) => stream, None => { ::tracing::debug!("Ignoring chunk for unknown stream::ID"); continue; } }; - stream.recv(chunk); + match stream.recv(chunk) { + Ok(status) => data_ready = data_ready | status, + Err(e) => ::tracing::debug!("stream: {:?}: {:?}", stream_id, e), + } } - // FIXME: report if we need to return data to the user - Ok(()) + Ok(data_ready) } pub(crate) fn enqueue( &mut self, @@ -402,8 +426,8 @@ impl Connection { pub(crate) struct ConnList { thread_id: ThreadTracker, connections: Vec>, - user_tracker: BTreeMap, - last_tracked: UserConnTracker, + user_tracker: BTreeMap, + last_tracked: ConnTracker, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } @@ -420,7 +444,7 @@ impl ConnList { thread_id, connections: Vec::with_capacity(INITIAL_CAP), user_tracker: BTreeMap::new(), - last_tracked: UserConnTracker(Wrapping(0)), + last_tracked: ConnTracker(Wrapping(0)), ids_used: vec![bitmap_id], }; ret.connections.resize_with(INITIAL_CAP, || None); @@ -437,10 +461,7 @@ impl ConnList { (conn_id.get() / (self.thread_id.total as u64)) as usize; (&mut self.connections[id_in_thread]).into() } - pub fn get_mut( - &mut self, - tracker: UserConnTracker, - ) -> Option<&mut Connection> { + pub fn get_mut(&mut self, tracker: ConnTracker) -> Option<&mut Connection> { let idx = if let Some(idx) = self.user_tracker.get(&tracker) { *idx } else { @@ -504,7 +525,7 @@ impl ConnList { pub(crate) fn track( &mut self, conn: Connection, - ) -> Result { + ) -> Result { let conn_id = match conn.id_recv { IDRecv(ID::Handshake) => { return Err(()); @@ -516,8 +537,9 @@ impl ConnList { self.connections[id_in_thread] = Some(conn); let mut tracked; loop { - tracked = self.last_tracked.advance(); + tracked = self.last_tracked.advance(self.thread_id.total); if self.user_tracker.get(&tracked).is_none() { + // like, never gonna happen, it's 64 bit let _ = self.user_tracker.insert(tracked, id_in_thread); break; } diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 96971cc..425cbd8 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -238,6 +238,26 @@ impl Tracker { } } +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum StreamData { + /// not enough data to return somthing to the user + NotReady = 0, + /// we can return something to the user + Ready, +} +impl ::core::ops::BitOr for StreamData { + type Output = Self; + + // Required method + fn bitor(self, other: Self) -> Self::Output { + if self == StreamData::Ready || other == StreamData::Ready { + StreamData::Ready + } else { + StreamData::NotReady + } + } +} + /// Actual stream-tracking structure #[derive(Debug, Clone)] pub(crate) struct Stream { @@ -254,7 +274,7 @@ impl Stream { data: Tracker::new(kind, rand), } } - pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result { match &mut self.data { Tracker::ROB(tracker) => tracker.recv(chunk), } diff --git a/src/connection/stream/rob/mod.rs b/src/connection/stream/rob/mod.rs index 1bfd159..21361bb 100644 --- a/src/connection/stream/rob/mod.rs +++ b/src/connection/stream/rob/mod.rs @@ -2,7 +2,9 @@ //! AKA: TCP-like use crate::{ - connection::stream::{Chunk, Error, Sequence, SequenceEnd, SequenceStart}, + connection::stream::{ + Chunk, Error, Sequence, SequenceEnd, SequenceStart, StreamData, + }, enc::Random, }; @@ -93,7 +95,7 @@ impl ReliableOrderedBytestream { ((self.pivot as usize + data_len) % self.data.len()) as u32; ret } - pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result { if !chunk .sequence .is_between(self.window_start, self.window_end) @@ -106,7 +108,7 @@ impl ReliableOrderedBytestream { chunk.data.len(), ); if maxlen == 0 { - // or empty chunk, but we don't care + // empty window or empty chunk, but we don't care return Err(Error::OutOfWindow); } // translate Sequences to offsets in self.data @@ -119,7 +121,7 @@ impl ReliableOrderedBytestream { let mut copy_ranges = Vec::new(); let mut to_delete = Vec::new(); let mut to_add = Vec::new(); - // note: te included ranges are (INCLUSIVE, INCLUSIVE) + // note: the ranges are (INCLUSIVE, INCLUSIVE) for (idx, el) in self.missing.iter_mut().enumerate() { let missing_from = self.window_start.offset(el.0); if missing_from > offset_end { @@ -146,7 +148,6 @@ impl ReliableOrderedBytestream { if missing_to > offset_end { // [..chunk..] // [....missing....] - // chunk is in the middle of a missing fragment to_add.push(( el.0.plus_u32(((offset_end - missing_from) + 1) as u32), el.1, @@ -156,16 +157,12 @@ impl ReliableOrderedBytestream { } else if offset <= missing_to { // [....chunk....] // [...missing...] - // chunk copy_ranges.push((offset, (missing_to - 0))); el.1 = el.0.plus_u32(((offset_end - missing_from) - 1) as u32); } } } - self.missing.append(&mut to_add); - self.missing - .sort_by(|(from_a, _), (from_b, _)| from_a.0 .0.cmp(&from_b.0 .0)); { let mut deleted = 0; for idx in to_delete.into_iter() { @@ -173,6 +170,10 @@ impl ReliableOrderedBytestream { deleted = deleted + 1; } } + self.missing.append(&mut to_add); + self.missing + .sort_by(|(from_a, _), (from_b, _)| from_a.0 .0.cmp(&from_b.0 .0)); + // copy only the missing data let (first, second) = self.data[..].split_at_mut(self.pivot as usize); for (from, to) in copy_ranges.into_iter() { @@ -198,7 +199,12 @@ impl ReliableOrderedBytestream { .copy_from_slice(&data[data_from..data_to]); } } - - Ok(()) + if self.missing.len() == 0 + || self.window_start.offset(self.missing[0].0) == 0 + { + Ok(StreamData::Ready) + } else { + Ok(StreamData::NotReady) + } } } diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 725578d..c71b6f2 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,8 +11,8 @@ use crate::{ }, packet::{self, Packet}, socket::{UdpClient, UdpServer}, - stream, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, - UserConnTracker, + stream, AuthSrvConn, ConnList, ConnTracker, Connection, IDSend, + ServiceConn, }, dnssec, enc::{ @@ -46,6 +46,16 @@ pub(crate) struct ConnectInfo { // TODO: UserID, Token information } +/// Connection event. Mostly used to give the data to the user +#[derive(Debug, Eq, PartialEq, Clone)] +#[non_exhaustive] +pub enum Event { + /// Work loop has exited. nothing more to do + End, + /// Data from a connection + Data(Vec), +} + pub(crate) enum Work { /// ask the thread to report to the main thread the total number of /// connections present @@ -53,8 +63,8 @@ pub(crate) enum Work { Connect(ConnectInfo), DropHandshake(KeyID), Recv(RawUdp), - UserSend((UserConnTracker, stream::ID, Vec)), - SendData((UserConnTracker, ::tokio::time::Instant)), + UserSend((ConnTracker, stream::ID, Vec)), + SendData((ConnTracker, ::tokio::time::Instant)), } /// Actual worker implementation. @@ -136,7 +146,7 @@ impl Worker { } /// Continuously loop and process work as needed - pub async fn work_loop(&mut self) { + pub async fn work_loop(&mut self) -> Result { 'mainloop: loop { let next_timer = self.work_timers.get_next(); ::tokio::pin!(next_timer); @@ -436,7 +446,7 @@ impl Worker { } Work::UserSend((tracker, stream, data)) => { let conn = match self.connections.get_mut(tracker) { - None => return, + None => continue, Some(conn) => conn, }; use connection::Enqueue; @@ -468,21 +478,21 @@ impl Worker { .queue_sender .send(Work::SendData((tracker, instant))) .await; - return; + continue; } let mut raw: Vec = Vec::with_capacity(1200); raw.resize(raw.capacity(), 0); let conn = match self.connections.get_mut(tracker) { - None => return, + None => continue, Some(conn) => conn, }; let pkt = match conn.write_pkt(&mut raw) { Ok(pkt) => pkt, - Err(enc::Error::NotEnoughData(0)) => return, + Err(enc::Error::NotEnoughData(0)) => continue, Err(e) => { ::tracing::error!("Packet generation: {:?}", e); - return; + continue; } }; let dest = conn.send_addr; @@ -493,6 +503,7 @@ impl Worker { } } } + Ok(Event::End) } /// Read and do stuff with the raw udp packet async fn recv(&mut self, mut udp: RawUdp) { @@ -527,8 +538,12 @@ impl Worker { None => return, Some(conn) => conn, }; - if let Err(e) = conn.recv(udp) { - ::tracing::trace!("Conn Recv: {:?}", e.to_string()); + match conn.recv(udp) { + Ok(stream::StreamData::NotReady) => {} + Ok(stream::StreamData::Ready) => { + // + } + Err(e) => ::tracing::trace!("Conn Recv: {:?}", e.to_string()), } } /// Receive an handshake packet @@ -693,7 +708,7 @@ impl Worker { }; let authsrv_conn = AuthSrvConn(connection::Conn { queue: self.queue_sender.clone(), - conn: track_auth_conn, + fast: track_auth_conn, }); let mut service_conn = None; if cci.service_id != auth::SERVICEID_AUTH { @@ -735,7 +750,7 @@ impl Worker { }; service_conn = Some(ServiceConn(connection::Conn { queue: self.queue_sender.clone(), - conn: track_serv_conn, + fast: track_serv_conn, })); } let _ = cci.answer.send(Ok(handshake::tracker::ConnectOk { diff --git a/src/lib.rs b/src/lib.rs index 8ac887a..61af27a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,7 @@ use crate::{ AuthServerConnections, Packet, }, inner::{ - worker::{ConnectInfo, RawUdp, Work, Worker}, + worker::{ConnectInfo, Event, RawUdp, Work, Worker}, ThreadTracker, }, }; @@ -638,7 +638,16 @@ impl Fenrir { Ok(worker) => worker, Err(_) => return, }; - worker.work_loop().await + loop { + match worker.work_loop().await { + Ok(_) => continue, + Ok(Event::End) => break, + Err(e) => { + ::tracing::error!("Worker: {:?}", e); + break; + } + } + } }); }); loop { -- 2.47.2 From 4ddfed358d1cec302b3786f43f48f65b452bf72d Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 6 Jul 2023 10:48:18 +0200 Subject: [PATCH 6/8] Connections: Track, send, receive Signed-off-by: Luca Fulchir --- TODO | 6 +- src/connection/mod.rs | 163 ++++++++++++++++++++++-------- src/connection/stream/mod.rs | 5 + src/inner/worker.rs | 189 +++++++++++++++++++++++++++-------- src/lib.rs | 10 +- 5 files changed, 281 insertions(+), 92 deletions(-) diff --git a/TODO b/TODO index 85a5331..33fa189 100644 --- a/TODO +++ b/TODO @@ -1,4 +1,6 @@ * Wrapping for everything that wraps (sigh) * track user connection (add u64 from user) -* split API in LocalThread and ThreadSafe - * split send/recv API in Centralized, Decentralized +* API plit + * split API in ThreadLocal, ThreadSafe + * split send/recv API in Centralized, Connection + * all re wrappers on ThreadLocal-Centralized diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 07af86e..e205e33 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -7,7 +7,7 @@ pub mod stream; use ::core::num::Wrapping; use ::std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, VecDeque}, vec::Vec, }; @@ -26,14 +26,21 @@ use crate::{ inner::{worker, ThreadTracker}, }; -/// Connaction errors +/// Connection errors #[derive(::thiserror::Error, Debug, Copy, Clone)] -pub(crate) enum Error { +pub enum Error { /// Can't decrypt packet #[error("Decrypt error: {0}")] Decrypt(#[from] crate::enc::Error), + /// Error in parsing a packet realated to the connection #[error("Chunk parsing: {0}")] Parse(#[from] stream::Error), + /// No such Connection + #[error("No suck connection")] + NoSuchConnection, + /// No such Stream + #[error("No suck Stream")] + NoSuchStream, } /// Fenrir Connection ID @@ -141,32 +148,53 @@ impl ProtocolVersion { } } +/// Connection tracking id. Set by the user +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +pub struct UserTracker(pub ::core::num::NonZeroU64); + /// Unique tracker of connections #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] -pub struct ConnTracker(Wrapping); -impl ConnTracker { +pub struct LibTracker(Wrapping); +impl LibTracker { pub(crate) fn new(start: u16) -> Self { Self(Wrapping(start as u64)) } pub(crate) fn advance(&mut self, amount: u16) -> Self { let old = self.0; self.0 = self.0 + Wrapping(amount as u64); - ConnTracker(old) + LibTracker(old) } } +/// Collection of connection tracking, but user-given and library generated +#[derive(Debug, Copy, Clone)] +pub struct ConnTracker { + /// Optional tracker set by the user + pub user: Option, + /// library generated tracker. Unique and non-repeating + pub(crate) lib: LibTracker, +} + +impl PartialEq for ConnTracker { + fn eq(&self, other: &Self) -> bool { + self.lib == other.lib + } +} +impl Eq for ConnTracker {} /// Connection to an Authentication Server -#[derive(Debug)] -pub struct AuthSrvConn(pub Conn); +#[derive(Debug, Copy, Clone)] +pub struct AuthSrvConn(pub ConnTracker); /// Connection to a service -#[derive(Debug)] -pub struct ServiceConn(pub Conn); +#[derive(Debug, Copy, Clone)] +pub struct ServiceConn(pub ConnTracker); +/* + * TODO: only on Thread{Local,Safe}::Connection oriented flows /// The connection, as seen from a user of libFenrir #[derive(Debug)] pub struct Conn { pub(crate) queue: ::async_channel::Sender, - pub(crate) fast: ConnTracker, + pub(crate) tracker: ConnTracker, } impl Conn { @@ -176,14 +204,15 @@ impl Conn { use crate::inner::worker::Work; let _ = self .queue - .send(Work::UserSend((self.tracker(), stream, data))) + .send(Work::UserSend((self.tracker.lib, stream, data))) .await; } /// Get the library tracking id pub fn tracker(&self) -> ConnTracker { - self.fast + self.tracker } } +*/ /// Role: track the connection direction /// @@ -208,15 +237,10 @@ enum TimerKind { } pub(crate) enum Enqueue { - NoSuchStream, TimerWait, Immediate(::tokio::time::Instant), } -/// Connection tracking id. Set by the user -#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)] -pub struct UserTracker(pub ::core::num::NonZeroU64); - /// A single connection and its data #[derive(Debug)] pub(crate) struct Connection { @@ -227,6 +251,7 @@ pub(crate) struct Connection { /// User-managed id to track this connection /// the user can set this to better track this connection pub(crate) user_tracker: Option, + pub(crate) lib_tracker: LibTracker, /// Sending address pub(crate) send_addr: UdpClient, /// The main hkdf used for all secrets in this connection @@ -242,6 +267,7 @@ pub(crate) struct Connection { last_stream_sent: stream::ID, /// receive queue for each Stream recv_queue: BTreeMap, + streams_ready: VecDeque, } impl Connection { @@ -267,6 +293,7 @@ impl Connection { id_recv: IDRecv(ID::Handshake), id_send: IDSend(ID::Handshake), user_tracker: None, + lib_tracker: LibTracker::new(0), // will be overwritten send_addr: UdpClient(SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), @@ -280,8 +307,24 @@ impl Connection { send_queue: BTreeMap::new(), last_stream_sent: stream::ID(0), recv_queue: BTreeMap::new(), + streams_ready: VecDeque::with_capacity(4), } } + pub(crate) fn get_data(&mut self) -> Option)>> { + if self.streams_ready.is_empty() { + return None; + } + let ret_len = self.streams_ready.len(); + let mut ret = Vec::with_capacity(ret_len); + while let Some(stream_id) = self.streams_ready.pop_front() { + let stream = match self.recv_queue.get_mut(&stream_id) { + Some(stream) => stream, + None => continue, + }; + ret.push((stream_id, stream.get())); + } + Some(ret) + } pub(crate) fn recv( &mut self, mut udp: crate::RawUdp, @@ -316,7 +359,12 @@ impl Connection { } }; match stream.recv(chunk) { - Ok(status) => data_ready = data_ready | status, + Ok(status) => { + if !self.streams_ready.contains(&stream_id) { + self.streams_ready.push_back(stream_id); + } + data_ready = data_ready | status; + } Err(e) => ::tracing::debug!("stream: {:?}: {:?}", stream_id, e), } } @@ -326,9 +374,9 @@ impl Connection { &mut self, stream: stream::ID, data: Vec, - ) -> Enqueue { + ) -> Result { let stream = match self.send_queue.get_mut(&stream) { - None => return Enqueue::NoSuchStream, + None => return Err(Error::NoSuchStream), Some(stream) => stream, }; stream.enqueue(data); @@ -348,7 +396,7 @@ impl Connection { TimerKind::SendData(old_timer) } }; - ret + Ok(ret) } pub(crate) fn write_pkt<'a>( &mut self, @@ -426,8 +474,8 @@ impl Connection { pub(crate) struct ConnList { thread_id: ThreadTracker, connections: Vec>, - user_tracker: BTreeMap, - last_tracked: ConnTracker, + user_tracker: BTreeMap, + last_tracked: LibTracker, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } @@ -444,32 +492,40 @@ impl ConnList { thread_id, connections: Vec::with_capacity(INITIAL_CAP), user_tracker: BTreeMap::new(), - last_tracked: ConnTracker(Wrapping(0)), + last_tracked: LibTracker(Wrapping(0)), ids_used: vec![bitmap_id], }; ret.connections.resize_with(INITIAL_CAP, || None); ret } - pub fn get_id_mut(&mut self, id: ID) -> Option<&mut Connection> { + pub fn get_id_mut(&mut self, id: ID) -> Result<&mut Connection, Error> { let conn_id = match id { - ID::Handshake => { - return None; - } ID::ID(conn_id) => conn_id, + ID::Handshake => { + return Err(Error::NoSuchConnection); + } }; let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; - (&mut self.connections[id_in_thread]).into() + if let Some(conn) = &mut self.connections[id_in_thread] { + Ok(conn) + } else { + return Err(Error::NoSuchConnection); + } } - pub fn get_mut(&mut self, tracker: ConnTracker) -> Option<&mut Connection> { + pub fn get_mut( + &mut self, + tracker: LibTracker, + ) -> Result<&mut Connection, Error> { let idx = if let Some(idx) = self.user_tracker.get(&tracker) { *idx } else { - return None; + return Err(Error::NoSuchConnection); }; - match &mut self.connections[idx] { - None => None, - Some(conn) => Some(conn), + if let Some(conn) = &mut self.connections[idx] { + Ok(conn) + } else { + return Err(Error::NoSuchConnection); } } pub fn len(&self) -> usize { @@ -481,7 +537,23 @@ impl ConnList { } /// Only *Reserve* a connection, /// without actually tracking it in self.connections + pub(crate) fn reserve_and_track<'a>( + &'a mut self, + mut conn: Connection, + ) -> (LibTracker, &'a mut Connection) { + let (id_conn, id_in_thread) = self.reserve_first_with_idx(); + conn.id_recv = id_conn; + let tracker = self.get_new_tracker(id_in_thread); + conn.lib_tracker = tracker; + self.connections[id_in_thread] = Some(conn); + (tracker, self.connections[id_in_thread].as_mut().unwrap()) + } + /// Only *Reserve* a connection, + /// without actually tracking it in self.connections pub(crate) fn reserve_first(&mut self) -> IDRecv { + self.reserve_first_with_idx().0 + } + fn reserve_first_with_idx(&mut self) -> (IDRecv, usize) { // uhm... bad things are going on here: // * id must be initialized, but only because: // * rust does not understand that after the `!found` id is always @@ -519,13 +591,13 @@ impl ConnList { let actual_id = ((id_in_thread as u64) * (self.thread_id.total as u64)) + (self.thread_id.id as u64); let new_id = IDRecv(ID::new_u64(actual_id)); - new_id + (new_id, id_in_thread) } /// NOTE: does NOT check if the connection has been previously reserved! pub(crate) fn track( &mut self, - conn: Connection, - ) -> Result { + mut conn: Connection, + ) -> Result { let conn_id = match conn.id_recv { IDRecv(ID::Handshake) => { return Err(()); @@ -534,17 +606,22 @@ impl ConnList { }; let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; + let tracker = self.get_new_tracker(id_in_thread); + conn.lib_tracker = tracker; self.connections[id_in_thread] = Some(conn); - let mut tracked; + Ok(tracker) + } + fn get_new_tracker(&mut self, id_in_thread: usize) -> LibTracker { + let mut tracker; loop { - tracked = self.last_tracked.advance(self.thread_id.total); - if self.user_tracker.get(&tracked).is_none() { + tracker = self.last_tracked.advance(self.thread_id.total); + if self.user_tracker.get(&tracker).is_none() { // like, never gonna happen, it's 64 bit - let _ = self.user_tracker.insert(tracked, id_in_thread); + let _ = self.user_tracker.insert(tracker, id_in_thread); break; } } - Ok(tracked) + tracker } pub(crate) fn remove(&mut self, id: IDRecv) { if let IDRecv(ID::ID(raw_id)) = id { diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 425cbd8..44ac3be 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -279,6 +279,11 @@ impl Stream { Tracker::ROB(tracker) => tracker.recv(chunk), } } + pub(crate) fn get(&mut self) -> Vec { + match &mut self.data { + Tracker::ROB(tracker) => tracker.get(), + } + } } /// Track what has been sent and what has been ACK'd from a stream diff --git a/src/inner/worker.rs b/src/inner/worker.rs index c71b6f2..f344bb6 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -12,7 +12,7 @@ use crate::{ packet::{self, Packet}, socket::{UdpClient, UdpServer}, stream, AuthSrvConn, ConnList, ConnTracker, Connection, IDSend, - ServiceConn, + LibTracker, ServiceConn, }, dnssec, enc::{ @@ -23,7 +23,7 @@ use crate::{ }, inner::ThreadTracker, }; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{collections::VecDeque, sync::Arc, vec::Vec}; /// This worker must be cpu-pinned use ::tokio::{ net::UdpSocket, @@ -46,14 +46,23 @@ pub(crate) struct ConnectInfo { // TODO: UserID, Token information } +/// return to the user the data received from a connection +#[derive(Debug, Clone)] +pub struct ConnData { + /// Connection tracking information + pub conn: ConnTracker, + /// received data, for each stream + pub data: Vec<(stream::ID, Vec)>, +} + /// Connection event. Mostly used to give the data to the user -#[derive(Debug, Eq, PartialEq, Clone)] +#[derive(Debug, Clone)] #[non_exhaustive] pub enum Event { /// Work loop has exited. nothing more to do End, /// Data from a connection - Data(Vec), + Data(ConnData), } pub(crate) enum Work { @@ -63,8 +72,8 @@ pub(crate) enum Work { Connect(ConnectInfo), DropHandshake(KeyID), Recv(RawUdp), - UserSend((ConnTracker, stream::ID, Vec)), - SendData((ConnTracker, ::tokio::time::Instant)), + UserSend((LibTracker, stream::ID, Vec)), + SendData((LibTracker, ::tokio::time::Instant)), } /// Actual worker implementation. @@ -83,6 +92,8 @@ pub struct Worker { queue_timeouts_send: mpsc::UnboundedSender, thread_channels: Vec<::async_channel::Sender>, connections: ConnList, + // connectsion untracker by the user. (users still needs to get(..) them) + untracked_connections: VecDeque, handshakes: handshake::Tracker, work_timers: super::Timers, } @@ -140,10 +151,82 @@ impl Worker { queue_timeouts_send, thread_channels: Vec::new(), connections: ConnList::new(thread_id), + untracked_connections: VecDeque::with_capacity(8), handshakes, work_timers: super::Timers::new(), }) } + /// return a handle to the worker that you can use to send data + /// The handle will enqueue work in the main worker and is thread-local safe + /// + /// While this does not require `&mut` on the `Worker`, everything + /// will be put in the work queue, + /// So you might have less immediate results in a few cases + pub fn handle(&self) -> Handle { + Handle { + queue: self.queue_sender.clone(), + } + } + /// change the UserTracker in the connection + /// + /// This is `unsafe` because you will be responsible for manually updating + /// any copy of the `ConnTracker` you might have cloned around + #[allow(unsafe_code)] + pub unsafe fn set_connection_tracker( + &mut self, + tracker: ConnTracker, + new_id: connection::UserTracker, + ) -> Result { + let conn = self.connections.get_mut(tracker.lib)?; + conn.user_tracker = Some(new_id); + Ok(ConnTracker { + lib: tracker.lib, + user: Some(new_id), + }) + } + /// Enqueue data to send + pub fn send( + &mut self, + tracker: LibTracker, + stream: stream::ID, + data: Vec, + ) -> Result<(), crate::Error> { + let conn = self.connections.get_mut(tracker)?; + conn.enqueue(stream, data)?; + Ok(()) + } + /// Returns new connections, if any + /// + /// You can provide an optional tracker, different from the library tracker. + /// + /// Differently from the library tracker, you can change this later on, + /// but you will be responsible to change it on every `ConnTracker` + /// you might have cloned elsewhere + pub fn try_get_connection( + &mut self, + tracker: Option, + ) -> Option { + let ret_tracker = ConnTracker { + lib: self.untracked_connections.pop_front()?, + user: None, + }; + match tracker { + Some(tracker) => { + #[allow(unsafe_code)] + match unsafe { + self.set_connection_tracker(ret_tracker, tracker) + } { + Ok(tracker) => Some(tracker), + Err(_) => { + // we had a connection, but it expired before the user + // remembered to get it. Just remove it from the queue. + None + } + } + } + None => Some(ret_tracker), + } + } /// Continuously loop and process work as needed pub async fn work_loop(&mut self) -> Result { @@ -441,27 +524,25 @@ impl Worker { } }; } - Work::Recv(pkt) => { - self.recv(pkt).await; - } + Work::Recv(pkt) => match self.recv(pkt).await { + Ok(event) => return Ok(event), + Err(_) => continue 'mainloop, + }, Work::UserSend((tracker, stream, data)) => { let conn = match self.connections.get_mut(tracker) { - None => continue, - Some(conn) => conn, + Ok(conn) => conn, + Err(_) => continue 'mainloop, }; use connection::Enqueue; - match conn.enqueue(stream, data) { - Enqueue::Immediate(instant) => { - let _ = self - .queue_sender - .send(Work::SendData((tracker, instant))) - .await; - } - Enqueue::TimerWait => {} - Enqueue::NoSuchStream => { - ::tracing::error!( - "Trying to send on unknown stream" - ); + if let Ok(enqueued) = conn.enqueue(stream, data) { + match enqueued { + Enqueue::Immediate(instant) => { + let _ = self + .queue_sender + .send(Work::SendData((tracker, instant))) + .await; + } + Enqueue::TimerWait => {} } } } @@ -484,8 +565,8 @@ impl Worker { let mut raw: Vec = Vec::with_capacity(1200); raw.resize(raw.capacity(), 0); let conn = match self.connections.get_mut(tracker) { - None => continue, - Some(conn) => conn, + Ok(conn) => conn, + Err(_) => continue, }; let pkt = match conn.write_pkt(&mut raw) { Ok(pkt) => pkt, @@ -506,7 +587,7 @@ impl Worker { Ok(Event::End) } /// Read and do stuff with the raw udp packet - async fn recv(&mut self, mut udp: RawUdp) { + async fn recv(&mut self, mut udp: RawUdp) -> Result { if udp.packet.id.is_handshake() { let handshake = match Handshake::deserialize( &udp.data[connection::ID::len()..], @@ -514,7 +595,7 @@ impl Worker { Ok(handshake) => handshake, Err(e) => { ::tracing::debug!("Handshake parsing: {}", e); - return; + return Err(()); } }; let action = match self.handshakes.recv_handshake( @@ -524,26 +605,34 @@ impl Worker { Ok(action) => action, Err(err) => { ::tracing::debug!("Handshake recv error {}", err); - return; + return Err(()); } }; self.recv_handshake(udp, action).await; + Err(()) } else { - self.recv_packet(udp); + self.recv_packet(udp) } } /// Receive a non-handshake packet - fn recv_packet(&mut self, udp: RawUdp) { + fn recv_packet(&mut self, udp: RawUdp) -> Result { let conn = match self.connections.get_id_mut(udp.packet.id) { - None => return, - Some(conn) => conn, + Ok(conn) => conn, + Err(_) => return Err(()), }; match conn.recv(udp) { - Ok(stream::StreamData::NotReady) => {} - Ok(stream::StreamData::Ready) => { - // + Ok(stream::StreamData::NotReady) => Err(()), + Ok(stream::StreamData::Ready) => Ok(Event::Data(ConnData { + conn: ConnTracker { + user: conn.user_tracker, + lib: conn.lib_tracker, + }, + data: conn.get_data().unwrap(), + })), + Err(e) => { + ::tracing::trace!("Conn Recv: {:?}", e.to_string()); + Err(()) } - Err(e) => ::tracing::trace!("Conn Recv: {:?}", e.to_string()), } } /// Receive an handshake packet @@ -625,6 +714,9 @@ impl Worker { // track connection let auth_id_recv = self.connections.reserve_first(); auth_conn.id_recv = auth_id_recv; + let (tracker, auth_conn) = + self.connections.reserve_and_track(auth_conn); + self.untracked_connections.push_back(tracker); let resp_data = dirsync::resp::Data { client_nonce: req_data.nonce, @@ -706,9 +798,9 @@ impl Worker { return; } }; - let authsrv_conn = AuthSrvConn(connection::Conn { - queue: self.queue_sender.clone(), - fast: track_auth_conn, + let authsrv_conn = AuthSrvConn(ConnTracker { + lib: track_auth_conn, + user: None, }); let mut service_conn = None; if cci.service_id != auth::SERVICEID_AUTH { @@ -748,9 +840,9 @@ impl Worker { return; } }; - service_conn = Some(ServiceConn(connection::Conn { - queue: self.queue_sender.clone(), - fast: track_serv_conn, + service_conn = Some(ServiceConn(ConnTracker { + lib: track_serv_conn, + user: None, })); } let _ = cci.answer.send(Ok(handshake::tracker::ConnectOk { @@ -786,3 +878,16 @@ impl Worker { let _res = src_sock.send_to(&data, client.0).await; } } + +/// Handle to send work asyncronously to the worker +#[derive(Debug, Clone)] +pub struct Handle { + queue: ::async_channel::Sender, +} + +impl Handle { + // TODO + // pub fn send(..) + // pub fn set_connection_id(..) + // try_get_new_connections() +} diff --git a/src/lib.rs b/src/lib.rs index 61af27a..347b646 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,15 +59,15 @@ pub enum Error { /// Handshake errors #[error("Handshake: {0:?}")] Handshake(#[from] handshake::Error), - /// Key error - #[error("key: {0:?}")] - Key(#[from] crate::enc::Error), /// Resolution problems. wrong or incomplete DNSSEC data #[error("DNSSEC resolution: {0}")] Resolution(String), /// Wrapper on encryption errors - #[error("Encrypt: {0}")] - Encrypt(enc::Error), + #[error("Crypto: {0}")] + Crypto(#[from] enc::Error), + /// Wrapper on connection errors + #[error("Connection: {0}")] + Connection(#[from] connection::Error), } pub(crate) enum StopWorking { -- 2.47.2 From 62a71a2af5cea7d4f0d482df8ad89b29932f94e6 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 21 Mar 2025 17:57:08 +0100 Subject: [PATCH 7/8] [transport] uud/uudl, first tests Signed-off-by: Luca Fulchir --- .gitignore | 1 + Cargo.toml | 34 +- flake.lock | 80 +--- flake.nix | 65 ++-- rustfmt.toml | 2 +- src/connection/handshake/tracker.rs | 16 +- src/connection/mod.rs | 7 +- src/connection/stream/errors.rs | 6 + src/connection/stream/mod.rs | 136 +++++-- src/connection/stream/rob/mod.rs | 25 +- src/connection/stream/rob/tests.rs | 24 +- src/connection/stream/uud/mod.rs | 557 ++++++++++++++++++++++++++++ src/connection/stream/uud/tests.rs | 249 +++++++++++++ src/connection/stream/uudl/mod.rs | 43 +++ src/connection/stream/uudl/tests.rs | 56 +++ src/dnssec/mod.rs | 7 +- src/enc/asym.rs | 3 +- src/enc/hkdf.rs | 5 +- src/enc/mod.rs | 2 +- src/enc/tests.rs | 4 +- src/inner/mod.rs | 2 +- src/inner/worker.rs | 2 +- src/tests.rs | 2 +- 23 files changed, 1143 insertions(+), 185 deletions(-) create mode 100644 src/connection/stream/uud/mod.rs create mode 100644 src/connection/stream/uud/tests.rs create mode 100644 src/connection/stream/uudl/mod.rs create mode 100644 src/connection/stream/uudl/tests.rs diff --git a/.gitignore b/.gitignore index 4ea67ad..3438345 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.swp /target /Cargo.lock +/flake.profile* diff --git a/Cargo.toml b/Cargo.toml index 0544e30..d857346 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,11 +2,11 @@ name = "fenrir" version = "0.1.0" -edition = "2021" +edition = "2024" # Fenrir won't be ready for a while, # we might as well use async fn in trait, which is nightly # remember to update this -rust-version = "1.67.0" +rust-version = "1.85.0" homepage = "https://git.runesauth.com/RunesAuth/libFenrir" repository = "https://git.runesauth.com/RunesAuth/libFenrir" license = "Apache-2.0 WITH LLVM-exception" @@ -21,16 +21,16 @@ publish = false [lib] -crate_type = [ "lib", "cdylib", "staticlib" ] +crate-type = [ "lib", "cdylib", "staticlib" ] [dependencies] # please keep these in alphabetical order -arc-swap = { version = "1.6" } +arc-swap = { version = "1.7" } arrayref = { version = "0.3" } -async-channel = { version = "1.8" } +async-channel = { version = "2.3" } # base85 repo has no tags, fix on a commit. v1.1.1 points to older, wrong version -base85 = { git = "https://gitlab.com/darkwyrm/base85", rev = "d98efbfd171dd9ba48e30a5c88f94db92fc7b3c6" } +base85 = { git = "https://gitlab.com/darkwyrm/base85", rev = "b5389888aca6208a7563c8dbf2af46a82e724fa1" } bitmaps = { version = "3.2" } chacha20poly1305 = { version = "0.10" } futures = { version = "0.3" } @@ -38,27 +38,27 @@ hkdf = { version = "0.12" } hwloc2 = {version = "2.2" } libc = { version = "0.2" } num-traits = { version = "0.2" } -num-derive = { version = "0.3" } +num-derive = { version = "0.4" } rand_core = {version = "0.6" } -ring = { version = "0.16" } +ring = { version = "0.17" } bincode = { version = "1.3" } sha3 = { version = "0.10" } -strum = { version = "0.24" } -strum_macros = { version = "0.24" } -thiserror = { version = "1.0" } +strum = { version = "0.26" } +strum_macros = { version = "0.26" } +thiserror = { version = "2.0" } tokio = { version = "1", features = ["full"] } # PERF: todo linux-only, behind "iouring" feature #tokio-uring = { version = "0.4" } tracing = { version = "0.1" } tracing-test = { version = "0.2" } -trust-dns-resolver = { version = "0.22", features = [ "dnssec-ring" ] } -trust-dns-client = { version = "0.22", features = [ "dnssec" ] } -trust-dns-proto = { version = "0.22" } +trust-dns-resolver = { version = "0.23", features = [ "dnssec-ring" ] } +trust-dns-client = { version = "0.23", features = [ "dnssec" ] } +trust-dns-proto = { version = "0.23" } # don't use stable dalek. forces zeroize 1.3, # breaks our and chacha20poly1305 # reason: zeroize is not pure rust, # so we can't have multiple versions of if -x25519-dalek = { version = "2.0.0-pre.1", features = [ "serde" ] } +x25519-dalek = { version = "2.0", features = [ "serde", "static_secrets" ] } zeroize = { version = "1" } [profile.dev] @@ -84,3 +84,7 @@ incremental = true codegen-units = 256 rpath = false +#[target.x86_64-unknown-linux-gnu] +#linker = "clang" +#rustflags = ["-C", "link-arg=--ld-path=mold"] + diff --git a/flake.lock b/flake.lock index 85c58c2..01b37d0 100644 --- a/flake.lock +++ b/flake.lock @@ -5,29 +5,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1687171271, - "narHash": "sha256-BJlq+ozK2B1sJDQXS3tzJM5a+oVZmi1q0FlBK/Xqv7M=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "abfb11bd1aec8ced1c9bb9adfe68018230f4fb3c", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "flake-utils_2": { - "inputs": { - "systems": "systems_2" - }, - "locked": { - "lastModified": 1681202837, - "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "cfacdce06f30d2b68473a46042957675eebb3401", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -38,27 +20,27 @@ }, "nixpkgs": { "locked": { - "lastModified": 1687555006, - "narHash": "sha256-GD2Kqb/DXQBRJcHqkM2qFZqbVenyO7Co/80JHRMg2U0=", + "lastModified": 1741862977, + "narHash": "sha256-prZ0M8vE/ghRGGZcflvxCu40ObKaB+ikn74/xQoNrGQ=", "owner": "nixos", "repo": "nixpkgs", - "rev": "33223d479ffde3d05ac16c6dff04ae43cc27e577", + "rev": "cdd2ef009676ac92b715ff26630164bb88fec4e0", "type": "github" }, "original": { "owner": "nixos", - "ref": "nixos-23.05", + "ref": "nixos-24.11", "repo": "nixpkgs", "type": "github" } }, "nixpkgs-unstable": { "locked": { - "lastModified": 1687502512, - "narHash": "sha256-dBL/01TayOSZYxtY4cMXuNCBk8UMLoqRZA+94xiFpJA=", + "lastModified": 1741851582, + "narHash": "sha256-cPfs8qMccim2RBgtKGF+x9IBCduRvd/N5F4nYpU0TVE=", "owner": "nixos", "repo": "nixpkgs", - "rev": "3ae20aa58a6c0d1ca95c9b11f59a2d12eebc511f", + "rev": "6607cf789e541e7873d40d3a8f7815ea92204f32", "type": "github" }, "original": { @@ -68,22 +50,6 @@ "type": "github" } }, - "nixpkgs_2": { - "locked": { - "lastModified": 1681358109, - "narHash": "sha256-eKyxW4OohHQx9Urxi7TQlFBTDWII+F+x2hklDOQPB50=", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "96ba1c52e54e74c3197f4d43026b3f3d92e83ff9", - "type": "github" - }, - "original": { - "owner": "NixOS", - "ref": "nixpkgs-unstable", - "repo": "nixpkgs", - "type": "github" - } - }, "root": { "inputs": { "flake-utils": "flake-utils", @@ -94,15 +60,16 @@ }, "rust-overlay": { "inputs": { - "flake-utils": "flake-utils_2", - "nixpkgs": "nixpkgs_2" + "nixpkgs": [ + "nixpkgs" + ] }, "locked": { - "lastModified": 1687660699, - "narHash": "sha256-crI/CA/OJc778I5qJhwhhl8/PKKzc0D7vvVxOtjfvSo=", + "lastModified": 1742005800, + "narHash": "sha256-6wuOGWkyW6R4A6Th9NMi6WK2jjddvZt7V2+rLPk6L3o=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "b3bd1d49f1ae609c1d68a66bba7a95a9a4256031", + "rev": "028cd247a6375f83b94adc33d83676480fc9c294", "type": "github" }, "original": { @@ -125,21 +92,6 @@ "repo": "default", "type": "github" } - }, - "systems_2": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 21e2442..9354115 100644 --- a/flake.nix +++ b/flake.nix @@ -2,9 +2,12 @@ description = "libFenrir"; inputs = { - nixpkgs.url = "github:nixos/nixpkgs/nixos-23.05"; + nixpkgs.url = "github:nixos/nixpkgs/nixos-24.11"; nixpkgs-unstable.url = "github:nixos/nixpkgs/nixos-unstable"; - rust-overlay.url = "github:oxalica/rust-overlay"; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "nixpkgs"; + }; flake-utils.url = "github:numtide/flake-utils"; }; @@ -18,35 +21,47 @@ pkgs-unstable = import nixpkgs-unstable { inherit system overlays; }; - RUST_VERSION="1.69.0"; + #RUST_VERSION="1.85.0"; + RUST_VERSION="2025-03-15"; in { devShells.default = pkgs.mkShell { + name = "libFenrir"; buildInputs = with pkgs; [ - git - gnupg - openssh - openssl - pkg-config - exa - fd - #(rust-bin.stable.latest.default.override { - # go with nightly to have async fn in traits - #(rust-bin.nightly."2023-02-01".default.override { - # #extensions = [ "rust-src" ]; - # #targets = [ "arm-unknown-linux-gnueabihf" ]; - #}) - clippy - cargo-watch - cargo-flamegraph - cargo-license - lld - rust-bin.stable.${RUST_VERSION}.default - rustfmt - rust-analyzer + # system deps + git + gnupg + openssh + openssl + pkg-config + fd + # rust deps + #(rust-bin.stable.latest.default.override { + # go with nightly to have async fn in traits + #(rust-bin.nightly."2023-02-01".default.override { + # #extensions = [ "rust-src" ]; + # #targets = [ "arm-unknown-linux-gnueabihf" ]; + #}) + clippy + cargo-watch + cargo-flamegraph + cargo-license + lld + #rust-bin.stable.${RUST_VERSION}.default + #rust-bin.beta.${RUST_VERSION}.default + rust-bin.nightly.${RUST_VERSION}.default + rustfmt + rust-analyzer + #clang_16 + #mold # fenrir deps - hwloc + hwloc ]; + # if you want to try the mold linker, add 'clang_16', 'mold', and append this to ~/.cargo/config.toml: + # [target.x86_64-unknown-linux-gnu] + # linker = "clang" + # rustflags = ["-C", "link-arg=--ld-path=mold"] + shellHook = '' # use zsh or other custom shell USER_SHELL="$(grep $USER /etc/passwd | cut -d ':' -f 7)" diff --git a/rustfmt.toml b/rustfmt.toml index 8008a3e..a562fa2 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,4 +1,4 @@ -edition = "2021" +edition = "2024" unstable_features = true format_strings = true max_width = 80 diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index fbedc34..ad4fd85 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -277,7 +277,7 @@ impl Tracker { use handshake::dirsync::DirSync; match handshake.data { handshake::Data::DirSync(ref mut ds) => match ds { - DirSync::Req(ref mut req) => { + &mut DirSync::Req(ref mut req) => { if !self.key_exchanges.contains(&req.exchange) { return Err(enc::Error::UnsupportedKeyExchange.into()); } @@ -298,21 +298,19 @@ impl Tracker { let ephemeral_key; match has_key { Some(s_k) => { - if let PrivKey::Exchange(ref k) = &s_k.key { + if let &PrivKey::Exchange(ref k) = &s_k.key { ephemeral_key = k; } else { unreachable!(); } } - None => { - return Err(handshake::Error::UnknownKeyID.into()) - } + None => return Err(Error::UnknownKeyID.into()), } let shared_key = match ephemeral_key .key_exchange(req.exchange, req.exchange_key) { Ok(shared_key) => shared_key, - Err(e) => return Err(handshake::Error::Key(e).into()), + Err(e) => return Err(Error::Key(e).into()), }; let hkdf = Hkdf::new(hkdf::Kind::Sha3, b"fenrir", shared_key); @@ -335,7 +333,7 @@ impl Tracker { req.data.deserialize_as_cleartext(cleartext)?; } Err(e) => { - return Err(handshake::Error::Key(e).into()); + return Err(Error::Key(e).into()); } } @@ -352,7 +350,7 @@ impl Tracker { "No such client key id: {:?}", resp.client_key_id ); - return Err(handshake::Error::UnknownKeyID.into()); + return Err(Error::UnknownKeyID.into()); } }; let cipher_recv = &hshake.connection.cipher_recv; @@ -371,7 +369,7 @@ impl Tracker { resp.data.deserialize_as_cleartext(&cleartext)?; } Err(e) => { - return Err(handshake::Error::Key(e).into()); + return Err(Error::Key(e).into()); } } let hshake = diff --git a/src/connection/mod.rs b/src/connection/mod.rs index e205e33..4b88d34 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -31,7 +31,7 @@ use crate::{ pub enum Error { /// Can't decrypt packet #[error("Decrypt error: {0}")] - Decrypt(#[from] crate::enc::Error), + Decrypt(#[from] enc::Error), /// Error in parsing a packet realated to the connection #[error("Chunk parsing: {0}")] Parse(#[from] stream::Error), @@ -321,7 +321,8 @@ impl Connection { Some(stream) => stream, None => continue, }; - ret.push((stream_id, stream.get())); + let data = stream.get(); // FIXME + ret.push((stream_id, data.1)); } Some(ret) } @@ -330,7 +331,7 @@ impl Connection { mut udp: crate::RawUdp, ) -> Result { let mut data = &mut udp.data[ID::len()..]; - let aad = enc::sym::AAD(&[]); + let aad = sym::AAD(&[]); self.cipher_recv.decrypt(aad, &mut data)?; let mut bytes_parsed = 0; let mut chunks = Vec::with_capacity(2); diff --git a/src/connection/stream/errors.rs b/src/connection/stream/errors.rs index 07dcbd8..8efdc49 100644 --- a/src/connection/stream/errors.rs +++ b/src/connection/stream/errors.rs @@ -9,4 +9,10 @@ pub enum Error { /// Sequence outside of the window #[error("Sequence out of the sliding window")] OutOfWindow, + /// Wrong start/end flags received, can't reconstruct data + #[error("Wrong start/end flags received")] + WrongFlags, + /// Can't reconstruct the data + #[error("Error in reconstructing the bytestream/datagrams")] + Reconstructing, } diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 44ac3be..c63c3c3 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -4,9 +4,17 @@ mod errors; mod rob; +mod uud; +mod uudl; pub use errors::Error; -use crate::{connection::stream::rob::ReliableOrderedBytestream, enc::Random}; +use crate::{ + connection::stream::{ + rob::ReliableOrderedBytestream, uud::UnreliableUnorderedDatagram, + uudl::UnreliableUnorderedDatagramLimited, + }, + enc::Random, +}; /// Kind of stream. any combination of: /// reliable/unreliable ordered/unordered, bytestream/datagram @@ -16,6 +24,9 @@ pub enum Kind { /// ROB: Reliable, Ordered, Bytestream /// AKA: TCP-like ROB = 0, + /// UUDL: Unreliable, Unordered, Datagram Limited + /// Aka: UDP-like. Data limited to the packet size + UUDL, } /// Id of the stream @@ -52,28 +63,49 @@ impl ChunkLen { #[derive(Debug, Copy, Clone)] pub(crate) struct SequenceStart(pub(crate) Sequence); impl SequenceStart { - pub(crate) fn plus_u32(&self, other: u32) -> Sequence { - self.0.plus_u32(other) - } pub(crate) fn offset(&self, seq: Sequence) -> usize { - if self.0 .0 <= seq.0 { - (seq.0 - self.0 .0).0 as usize + if self.0.0 <= seq.0 { + (seq.0 - self.0.0).0 as usize } else { - (seq.0 + (Sequence::max().0 - self.0 .0)).0 as usize + (seq.0 + (Sequence::max().0 - self.0.0)).0 as usize } } } -// SequenceEnd is INCLUSIVE -#[derive(Debug, Copy, Clone)] -pub(crate) struct SequenceEnd(pub(crate) Sequence); -impl SequenceEnd { - pub(crate) fn plus_u32(&self, other: u32) -> Sequence { - self.0.plus_u32(other) + +impl ::core::ops::Add for SequenceStart { + type Output = SequenceStart; + fn add(self, other: u32) -> SequenceStart { + SequenceStart(self.0 + other) } } -/// Sequence number to rebuild the stream correctly +impl ::core::ops::AddAssign for SequenceStart { + fn add_assign(&mut self, other: u32) { + self.0 += other; + } +} + +// SequenceEnd is INCLUSIVE #[derive(Debug, Copy, Clone)] +pub(crate) struct SequenceEnd(pub(crate) Sequence); + +impl ::core::ops::Add for SequenceEnd { + type Output = SequenceEnd; + fn add(self, other: u32) -> SequenceEnd { + SequenceEnd(self.0 + other) + } +} + +impl ::core::ops::AddAssign for SequenceEnd { + fn add_assign(&mut self, other: u32) { + self.0 += other; + } +} + +// TODO: how to tell the compiler we don't use the two most significant bits? +// maybe NonZero + always using 2nd most significant bit? +/// Sequence number to rebuild the stream correctly +#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)] pub struct Sequence(pub ::core::num::Wrapping); impl Sequence { @@ -90,6 +122,10 @@ impl Sequence { 4 } /// Maximum possible sequence + pub const fn min() -> Self { + Self(::core::num::Wrapping(0)) + } + /// Maximum possible sequence pub const fn max() -> Self { Self(::core::num::Wrapping(Self::SEQ_NOFLAG)) } @@ -98,27 +134,48 @@ impl Sequence { start: SequenceStart, end: SequenceEnd, ) -> bool { - if start.0 .0 < end.0 .0 { - start.0 .0 <= self.0 && self.0 <= end.0 .0 + if start.0 < end.0 { + start.0.0 <= self.0 && self.0 <= end.0.0 } else { - start.0 .0 <= self.0 || self.0 <= end.0 .0 + start.0.0 <= self.0 || self.0 <= end.0.0 } } + pub(crate) fn cmp_in_window( + &self, + window_start: SequenceStart, + compare: Sequence, + ) -> ::core::cmp::Ordering { + let offset_self = self.0 - window_start.0.0; + let offset_compare = compare.0 - window_start.0.0; + return offset_self.cmp(&offset_compare); + } pub(crate) fn remaining_window(&self, end: SequenceEnd) -> u32 { - if self.0 <= end.0 .0 { - (end.0 .0 .0 - self.0 .0) + 1 + if self.0 <= end.0.0 { + (end.0.0.0 - self.0.0) + 1 } else { - end.0 .0 .0 + 1 + (Self::max().0 - self.0).0 + end.0.0.0 + 1 + (Self::max().0 - self.0).0 } } - pub(crate) fn plus_u32(self, other: u32) -> Self { + pub(crate) fn diff_from(self, other: Sequence) -> u32 { + assert!( + self.0.0 > other.0.0, + "Sequence::diff_from inverted parameters" + ); + self.0.0 - other.0.0 + } +} + +impl ::core::ops::Sub for Sequence { + type Output = Self; + + fn sub(self, other: u32) -> Self { Self(::core::num::Wrapping( - (self.0 .0 + other) & Self::SEQ_NOFLAG, + (self.0 - ::core::num::Wrapping::(other)).0 & Self::SEQ_NOFLAG, )) } } -impl ::core::ops::Add for Sequence { +impl ::core::ops::Add for Sequence { type Output = Self; fn add(self, other: Self) -> Self { @@ -128,10 +185,24 @@ impl ::core::ops::Add for Sequence { } } +impl ::core::ops::Add for Sequence { + type Output = Sequence; + fn add(self, other: u32) -> Sequence { + Sequence(self.0 + ::core::num::Wrapping::(other)) + } +} + +impl ::core::ops::AddAssign for Sequence { + fn add_assign(&mut self, other: u32) { + self.0 += ::core::num::Wrapping::(other); + } +} + /// Chunk of data representing a stream /// Every chunk is as follows: /// | id (2 bytes) | length (2 bytes) | /// | flag_start (1 BIT) | flag_end (1 BIT) | sequence (30 bits) | +/// | ...data... | #[derive(Debug, Clone)] pub struct Chunk<'a> { /// Id of the stream this chunk is part of @@ -203,7 +274,7 @@ impl<'a> Chunk<'a> { let bytes = bytes_next; bytes_next = bytes_next + Sequence::len(); raw_out[bytes..bytes_next] - .copy_from_slice(&self.sequence.0 .0.to_le_bytes()); + .copy_from_slice(&self.sequence.0.0.to_le_bytes()); let mut flag_byte = raw_out[bytes] & Self::FLAGS_EXCLUDED_BITMASK; if self.flag_start { flag_byte = flag_byte | Self::FLAG_START_BITMASK; @@ -223,17 +294,21 @@ impl<'a> Chunk<'a> { /// differences from Kind: /// * not public /// * has actual data -#[derive(Debug, Clone)] +#[derive(Debug)] pub(crate) enum Tracker { /// ROB: Reliable, Ordered, Bytestream /// AKA: TCP-like ROB(ReliableOrderedBytestream), + UUDL(UnreliableUnorderedDatagramLimited), } impl Tracker { pub(crate) fn new(kind: Kind, rand: &Random) -> Self { match kind { Kind::ROB => Tracker::ROB(ReliableOrderedBytestream::new(rand)), + Kind::UUDL => { + Tracker::UUDL(UnreliableUnorderedDatagramLimited::new()) + } } } } @@ -259,7 +334,7 @@ impl ::core::ops::BitOr for StreamData { } /// Actual stream-tracking structure -#[derive(Debug, Clone)] +#[derive(Debug)] pub(crate) struct Stream { id: ID, data: Tracker, @@ -277,11 +352,16 @@ impl Stream { pub(crate) fn recv(&mut self, chunk: Chunk) -> Result { match &mut self.data { Tracker::ROB(tracker) => tracker.recv(chunk), + Tracker::UUDL(tracker) => tracker.recv(chunk), } } - pub(crate) fn get(&mut self) -> Vec { + pub(crate) fn get(&mut self) -> (SequenceStart, Vec) { match &mut self.data { - Tracker::ROB(tracker) => tracker.get(), + // FIXME + Tracker::ROB(tracker) => { + (SequenceStart(Sequence::min()), tracker.get()) + } + Tracker::UUDL(tracker) => tracker.get(), } } } diff --git a/src/connection/stream/rob/mod.rs b/src/connection/stream/rob/mod.rs index 21361bb..901dc69 100644 --- a/src/connection/stream/rob/mod.rs +++ b/src/connection/stream/rob/mod.rs @@ -26,7 +26,7 @@ impl ReliableOrderedBytestream { pub(crate) fn new(rand: &Random) -> Self { let window_len = 1048576; // 1MB. should be enough for anybody. (lol) let window_start = SequenceStart(Sequence::new(rand)); - let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1)); + let window_end = SequenceEnd(window_start.0 +(window_len - 1)); let mut data = Vec::with_capacity(window_len as usize); data.resize(data.capacity(), 0); @@ -44,9 +44,9 @@ impl ReliableOrderedBytestream { "Max window size is {}", Sequence::max().0 .0 ); - let window_len = size; // 1MB. should be enough for anybody. (lol) + let window_len = size; let window_start = SequenceStart(Sequence::new(rand)); - let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1)); + let window_end = SequenceEnd(window_start.0 +(window_len - 1)); let mut data = Vec::with_capacity(window_len as usize); data.resize(data.capacity(), 0); @@ -68,20 +68,20 @@ impl ReliableOrderedBytestream { ret.extend_from_slice(first); ret.extend_from_slice(second); self.window_start = - SequenceStart(self.window_start.plus_u32(ret.len() as u32)); + self.window_start + (ret.len() as u32); self.window_end = - SequenceEnd(self.window_end.plus_u32(ret.len() as u32)); + self.window_end + (ret.len() as u32); self.data.clear(); return ret; } let data_len = self.window_start.offset(self.missing[0].0); let last_missing_idx = self.missing.len() - 1; let mut last_missing = &mut self.missing[last_missing_idx]; - last_missing.1 = last_missing.1.plus_u32(data_len as u32); + last_missing.1 = last_missing.1 + (data_len as u32); self.window_start = - SequenceStart(self.window_start.plus_u32(data_len as u32)); + self.window_start + (data_len as u32); self.window_end = - SequenceEnd(self.window_end.plus_u32(data_len as u32)); + self.window_end + (data_len as u32); let mut ret = Vec::with_capacity(data_len); let (first, second) = self.data[..].split_at(self.pivot as usize); @@ -141,25 +141,24 @@ impl ReliableOrderedBytestream { // [....chunk....] // [...missing...] copy_ranges.push((missing_from, offset_end)); - el.0 = - el.0.plus_u32(((offset_end - missing_from) + 1) as u32); + el.0 +=((offset_end - missing_from) + 1) as u32; } } else if missing_from < offset { if missing_to > offset_end { // [..chunk..] // [....missing....] to_add.push(( - el.0.plus_u32(((offset_end - missing_from) + 1) as u32), + el.0 + (((offset_end - missing_from) + 1) as u32), el.1, )); - el.1 = el.0.plus_u32(((offset - missing_from) - 1) as u32); + el.1 = el.0 + (((offset - missing_from) - 1) as u32); copy_ranges.push((offset, offset_end)); } else if offset <= missing_to { // [....chunk....] // [...missing...] copy_ranges.push((offset, (missing_to - 0))); el.1 = - el.0.plus_u32(((offset_end - missing_from) - 1) as u32); + el.0 + (((offset_end - missing_from) - 1) as u32); } } } diff --git a/src/connection/stream/rob/tests.rs b/src/connection/stream/rob/tests.rs index 20cb508..f599d09 100644 --- a/src/connection/stream/rob/tests.rs +++ b/src/connection/stream/rob/tests.rs @@ -36,7 +36,7 @@ fn test_stream_rob_sequential() { id: stream::ID(42), flag_start: false, flag_end: true, - sequence: start.plus_u32(512), + sequence: start + 512, data: &data[512..], }; let _ = rob.recv(chunk); @@ -77,7 +77,7 @@ fn test_stream_rob_retransmit() { id: stream::ID(42), flag_start: false, flag_end: false, - sequence: start.plus_u32(50), + sequence: start +50, data: &data[50..60], }; let _ = rob.recv(chunk); @@ -85,7 +85,7 @@ fn test_stream_rob_retransmit() { id: stream::ID(42), flag_start: false, flag_end: false, - sequence: start.plus_u32(40), + sequence: start + 40, data: &data[40..60], }; let _ = rob.recv(chunk); @@ -93,7 +93,7 @@ fn test_stream_rob_retransmit() { id: stream::ID(42), flag_start: false, flag_end: false, - sequence: start.plus_u32(80), + sequence: start + 80, data: &data[80..], }; let _ = rob.recv(chunk); @@ -101,7 +101,7 @@ fn test_stream_rob_retransmit() { id: stream::ID(42), flag_start: false, flag_end: false, - sequence: start.plus_u32(50), + sequence: start + 50, data: &data[50..90], }; let _ = rob.recv(chunk); @@ -109,7 +109,7 @@ fn test_stream_rob_retransmit() { id: stream::ID(42), flag_start: false, flag_end: false, - sequence: start.plus_u32(max_window as u32), + sequence: start +(max_window as u32), data: &data[max_window..], }; let _ = rob.recv(chunk); @@ -117,7 +117,7 @@ fn test_stream_rob_retransmit() { id: stream::ID(42), flag_start: false, flag_end: true, - sequence: start.plus_u32(90), + sequence: start +90, data: &data[90..max_window], }; let _ = rob.recv(chunk); @@ -157,7 +157,7 @@ fn test_stream_rob_rolling() { id: stream::ID(42), flag_start: true, flag_end: false, - sequence: start.plus_u32(50), + sequence: start + 50, data: &data[50..100], }; let _ = rob.recv(chunk); @@ -172,7 +172,7 @@ fn test_stream_rob_rolling() { id: stream::ID(42), flag_start: true, flag_end: false, - sequence: start.plus_u32(40), + sequence: start + 40, data: &data[40..], }; let _ = rob.recv(chunk); @@ -212,7 +212,7 @@ fn test_stream_rob_rolling_second_case() { id: stream::ID(42), flag_start: true, flag_end: false, - sequence: start.plus_u32(50), + sequence: start + 50, data: &data[50..100], }; let _ = rob.recv(chunk); @@ -227,7 +227,7 @@ fn test_stream_rob_rolling_second_case() { id: stream::ID(42), flag_start: true, flag_end: false, - sequence: start.plus_u32(40), + sequence: start + 40, data: &data[40..100], }; let _ = rob.recv(chunk); @@ -235,7 +235,7 @@ fn test_stream_rob_rolling_second_case() { id: stream::ID(42), flag_start: true, flag_end: false, - sequence: start.plus_u32(100), + sequence: start + 100, data: &data[100..], }; let _ = rob.recv(chunk); diff --git a/src/connection/stream/uud/mod.rs b/src/connection/stream/uud/mod.rs new file mode 100644 index 0000000..f14e1f7 --- /dev/null +++ b/src/connection/stream/uud/mod.rs @@ -0,0 +1,557 @@ +//! Implementation of the Unreliable, unordered, Datagram transmission model +//! +//! AKA: UDP-like, but the datagram can cross the packet-size (MTU) limit. +//! +//! Only fully received datagrams will be delivered to the user, and +//! half-received ones will be discarded after a timeout + +use crate::{ + connection::stream::{ + Chunk, Error, Sequence, SequenceEnd, SequenceStart, StreamData, + }, + enc::Random, +}; + +use ::std::collections::{BTreeMap, VecDeque}; + +#[cfg(test)] +mod tests; + +#[derive(Debug, PartialEq, Eq)] +enum Fragment { + Start((Sequence, Sequence)), + Middle((Sequence, Sequence)), + End((Sequence, Sequence)), + Full((Sequence, Sequence)), + Delivered((Sequence, Sequence)), +} + +impl Fragment { + // FIXME: sequence start/end? + fn get_seqs(&self) -> (Sequence, Sequence) { + match self { + Fragment::Start((f, t)) + | Fragment::Middle((f, t)) + | Fragment::End((f, t)) + | Fragment::Full((f, t)) + | Fragment::Delivered((f, t)) => (*f, *t), + } + } + fn is_start(&self) -> bool { + match self { + Fragment::Start(_) | Fragment::Full(_) | Fragment::Delivered(_) => { + true + } + Fragment::End(_) | Fragment::Middle(_) => false, + } + } + fn is_end(&self) -> bool { + match self { + Fragment::End(_) | Fragment::Full(_) | Fragment::Delivered(_) => { + true + } + Fragment::Start(_) | Fragment::Middle(_) => false, + } + } +} +/* +impl ::std::cmp::PartialEq for Fragment { + fn eq(&self, other: &Sequence) -> bool { + self.get_seq() == *other + } +} + +impl ::std::cmp::PartialOrd for Fragment { + fn partial_cmp(&self, other: &Sequence) -> Option<::std::cmp::Ordering> { + Some(self.get_seq().cmp(other)) + } +} + +impl ::std::cmp::PartialOrd for Fragment { + fn partial_cmp(&self, other: &Fragment) -> Option<::std::cmp::Ordering> { + Some(self.get_seq().cmp(&other.get_seq())) + } +} +impl Ord for Fragment { + fn cmp(&self, other: &Fragment) -> ::std::cmp::Ordering { + self.get_seq().cmp(&other.get_seq()) + } +} +*/ + +type Timer = u64; + +pub(crate) struct Uud { + pub(crate) window_start: SequenceStart, + window_end: SequenceEnd, + pivot: u32, + data: Vec, + track: VecDeque<(Fragment, Timer)>, +} + +impl Uud { + pub(crate) fn new(rand: &Random) -> Self { + let window_len = 1048576; // 1MB. should be enough for anybody. (lol) + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0 + (window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + track: VecDeque::with_capacity(4), + } + } + pub(crate) fn with_window_size(rand: &Random, size: u32) -> Self { + assert!( + size < Sequence::max().0.0, + "Max window size is {}", + Sequence::max().0.0 + ); + let window_len = size; + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0 + (window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + track: VecDeque::with_capacity(4), + } + } + pub(crate) fn window_size(&self) -> u32 { + self.data.len() as u32 + } + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result { + let chunk_to = chunk.sequence + chunk.data.len() as u32; + if !chunk + .sequence + .is_between(self.window_start, self.window_end) + { + return Err(Error::OutOfWindow); + } + + // make sure we consider only the bytes inside the sliding window + let maxlen = ::std::cmp::min( + chunk.sequence.remaining_window(self.window_end) as usize, + chunk.data.len(), + ); + if maxlen == 0 { + // empty window or empty chunk, but we don't care + return Err(Error::OutOfWindow); + } + let chunk_flag_end: bool; + if maxlen != chunk.data.len() { + // we are not considering the full chunk, so + // make sure the end flag is not set + chunk_flag_end = false; + + // FIXME: what happens if we "truncate" this chunk now, + // then we have more space in the window + // then we receive the same packet again? + } else { + chunk_flag_end = chunk.flag_end; + } + // translate Sequences to offsets in self.data + let data = &chunk.data[..maxlen]; + let chunk_to = chunk.sequence + data.len() as u32; + let mut last_usable = self.window_end.0; + let mut ret = StreamData::NotReady; + let mut copy_data_idx_from = 0; + // FIXME: and on receiving first fragment after the second, or empty + // track? + for (idx, (fragment, _)) in self.track.iter_mut().enumerate().rev() { + let (from, to) = fragment.get_seqs(); + let to_next = to + 1; + match to_next.cmp_in_window(self.window_start, chunk.sequence) { + ::core::cmp::Ordering::Equal => { + // `chunk` is immediately after `fragment` + if !chunk_to.is_between( + SequenceStart(to_next), + SequenceEnd(last_usable), + ) { + return Err(Error::Reconstructing); + } + match fragment { + Fragment::Start((_, f_end)) => { + if chunk.flag_start { + // we can't start a datagram twice. + // ignore the data + return Err(Error::WrongFlags); + } + if chunk_flag_end { + *fragment = Fragment::Full(( + from, + to + (data.len() as u32), + )); + ret = StreamData::Ready; + } else { + *f_end += data.len() as u32; + } + copy_data_idx_from = + chunk.sequence.diff_from(self.window_start.0) + as usize; + } + Fragment::Middle((_, f_end)) => { + if chunk.flag_start { + // we can't start a datagram twice. + // ignore the data + return Err(Error::WrongFlags); + } + if chunk_flag_end { + *fragment = Fragment::End(( + from, + to + (data.len() as u32), + )); + } else { + *f_end += data.len() as u32; + } + copy_data_idx_from = + chunk.sequence.diff_from(self.window_start.0) + as usize; + } + Fragment::End(_) + | Fragment::Full(_) + | Fragment::Delivered(_) => { + if !chunk.flag_start { + return Err(Error::WrongFlags); + } + let toinsert = if chunk_flag_end { + ret = StreamData::Ready; + Fragment::Full((chunk.sequence, chunk_to)) + } else { + Fragment::Start((chunk.sequence, chunk_to)) + }; + self.track.insert(idx + 1, (toinsert, 0)); + copy_data_idx_from = + chunk.sequence.diff_from(self.window_start.0) + as usize; + } + } + break; + } + ::core::cmp::Ordering::Less => { + // there is a data hole between `chunk` and `fragment` + + if !chunk_to.is_between( + SequenceStart(to_next), + SequenceEnd(last_usable), + ) { + return Err(Error::Reconstructing); + } + let toinsert = if chunk.flag_start { + if chunk_flag_end { + ret = StreamData::Ready; + Fragment::Full((chunk.sequence, chunk_to)) + } else { + Fragment::Start((chunk.sequence, chunk_to)) + } + } else { + if chunk_flag_end { + Fragment::End((chunk.sequence, chunk_to)) + } else { + Fragment::Middle((chunk.sequence, chunk_to)) + } + }; + self.track.insert(idx + 1, (toinsert, 0)); + copy_data_idx_from = + chunk.sequence.diff_from(self.window_start.0) as usize; + break; + } + ::core::cmp::Ordering::Greater => { + // to_next > chunk.sequence + // `fragment` is too new, need to look at older ones + + if from.cmp_in_window(self.window_start, chunk.sequence) + != ::core::cmp::Ordering::Greater + { + // to_next > chunk.sequence >= from + // overlapping not yet allowed + return Err(Error::Reconstructing); + } + if idx == 0 { + // check if we can add before everything + if chunk_to == from { + if fragment.is_start() { + if chunk_flag_end { + // add, don't merge + } else { + //fragment.start, but !chunk.end + return Err(Error::WrongFlags); + } + } else { + if chunk_flag_end { + //chunk.end but !fragment.start + return Err(Error::WrongFlags); + } else { + if chunk.flag_start { + if fragment.is_end() { + *fragment = Fragment::Full(( + chunk.sequence, + to, + )); + ret = StreamData::Ready; + } else { + *fragment = Fragment::Start(( + chunk.sequence, + to, + )); + } + } else { + if fragment.is_end() { + *fragment = Fragment::End(( + chunk.sequence, + to, + )); + } else { + *fragment = Fragment::Middle(( + chunk.sequence, + to, + )); + } + } + } + } + copy_data_idx_from = + chunk.sequence.diff_from(self.window_start.0) + as usize; + break; + } + // chunk before fragment + let toinsert = if chunk.flag_start { + if chunk_flag_end { + ret = StreamData::Ready; + Fragment::Full((chunk.sequence, chunk_to)) + } else { + Fragment::Start((chunk.sequence, chunk_to)) + } + } else { + if chunk_flag_end { + Fragment::End((chunk.sequence, chunk_to)) + } else { + Fragment::Middle((chunk.sequence, chunk_to)) + } + }; + self.track.insert(0, (toinsert, 0)); + copy_data_idx_from = + chunk.sequence.diff_from(self.window_start.0) + as usize; + break; + } + last_usable = from - 1; + } + } + } + let data_idx_from = + (copy_data_idx_from + self.pivot as usize) % self.data.len(); + let data_idx_to = (data_idx_from + data.len()) % self.data.len(); + if data_idx_from < data_idx_to { + self.data[data_idx_from..data_idx_to].copy_from_slice(&data); + } else { + let data_pivot = self.data.len() - data_idx_from; + let (first, second) = data.split_at(data_pivot); + self.data[data_idx_from..].copy_from_slice(&first); + self.data[..data_idx_to].copy_from_slice(&data); + } + Ok(ret) + } +} + +/// Copy of ROB for reference +#[derive(Debug, Clone)] +pub(crate) struct UnreliableUnorderedDatagram { + pub(crate) window_start: SequenceStart, + window_end: SequenceEnd, + pivot: u32, + data: Vec, + missing: Vec<(Sequence, Sequence)>, +} + +impl UnreliableUnorderedDatagram { + pub(crate) fn new(rand: &Random) -> Self { + let window_len = 1048576; // 1MB. should be enough for anybody. (lol) + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0 + (window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + missing: [(window_start.0, window_end.0)].to_vec(), + } + } + pub(crate) fn with_window_size(rand: &Random, size: u32) -> Self { + assert!( + size < Sequence::max().0.0, + "Max window size is {}", + Sequence::max().0.0 + ); + let window_len = size; // 1MB. should be enough for anybody. (lol) + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0 + (window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + missing: [(window_start.0, window_end.0)].to_vec(), + } + } + pub(crate) fn window_size(&self) -> u32 { + self.data.len() as u32 + } + pub(crate) fn get(&mut self) -> Vec { + if self.missing.len() == 0 { + let (first, second) = self.data.split_at(self.pivot as usize); + let mut ret = Vec::with_capacity(self.data.len()); + ret.extend_from_slice(first); + ret.extend_from_slice(second); + self.window_start += ret.len() as u32; + self.window_end = SequenceEnd(Sequence( + ::core::num::Wrapping::(ret.len() as u32), + )); + self.data.clear(); + return ret; + } + let data_len = self.window_start.offset(self.missing[0].0); + let last_missing_idx = self.missing.len() - 1; + let mut last_missing = &mut self.missing[last_missing_idx]; + last_missing.1 += data_len as u32; + self.window_start += data_len as u32; + self.window_end += data_len as u32; + + let mut ret = Vec::with_capacity(data_len); + let (first, second) = self.data[..].split_at(self.pivot as usize); + let first_len = ::core::cmp::min(data_len, first.len()); + let second_len = data_len - first_len; + + ret.extend_from_slice(&first[..first_len]); + ret.extend_from_slice(&second[..second_len]); + + self.pivot = + ((self.pivot as usize + data_len) % self.data.len()) as u32; + ret + } + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result { + if !chunk + .sequence + .is_between(self.window_start, self.window_end) + { + return Err(Error::OutOfWindow); + } + // make sure we consider only the bytes inside the sliding window + let maxlen = ::std::cmp::min( + chunk.sequence.remaining_window(self.window_end) as usize, + chunk.data.len(), + ); + if maxlen == 0 { + // empty window or empty chunk, but we don't care + return Err(Error::OutOfWindow); + } + // translate Sequences to offsets in self.data + let data = &chunk.data[..maxlen]; + let offset = self.window_start.offset(chunk.sequence); + let offset_end = offset + chunk.data.len() - 1; + + // Find the chunks we are missing that we can copy, + // and fix the missing tracker + let mut copy_ranges = Vec::new(); + let mut to_delete = Vec::new(); + let mut to_add = Vec::new(); + // note: the ranges are (INCLUSIVE, INCLUSIVE) + for (idx, el) in self.missing.iter_mut().enumerate() { + let missing_from = self.window_start.offset(el.0); + if missing_from > offset_end { + break; + } + let missing_to = self.window_start.offset(el.1); + if missing_to < offset { + continue; + } + if missing_from >= offset && missing_from <= offset_end { + if missing_to <= offset_end { + // [.....chunk.....] + // [..missing..] + to_delete.push(idx); + copy_ranges.push((missing_from, missing_to)); + } else { + // [....chunk....] + // [...missing...] + copy_ranges.push((missing_from, offset_end)); + el.0 += ((offset_end - missing_from) + 1) as u32; + } + } else if missing_from < offset { + if missing_to > offset_end { + // [..chunk..] + // [....missing....] + to_add.push(( + el.0 + (((offset_end - missing_from) + 1) as u32), + el.1, + )); + el.1 = el.0 + (((offset - missing_from) - 1) as u32); + copy_ranges.push((offset, offset_end)); + } else if offset <= missing_to { + // [....chunk....] + // [...missing...] + copy_ranges.push((offset, (missing_to - 0))); + el.1 = el.0 + (((offset_end - missing_from) - 1) as u32); + } + } + } + { + let mut deleted = 0; + for idx in to_delete.into_iter() { + self.missing.remove(idx + deleted); + deleted = deleted + 1; + } + } + self.missing.append(&mut to_add); + self.missing + .sort_by(|(from_a, _), (from_b, _)| from_a.0.0.cmp(&from_b.0.0)); + + // copy only the missing data + let (first, second) = self.data[..].split_at_mut(self.pivot as usize); + for (from, to) in copy_ranges.into_iter() { + let to = to + 1; + if from <= first.len() { + let first_from = from; + let first_to = ::core::cmp::min(first.len(), to); + let data_first_from = from - offset; + let data_first_to = first_to - offset; + first[first_from..first_to] + .copy_from_slice(&data[data_first_from..data_first_to]); + + let second_to = to - first_to; + let data_second_to = data_first_to + second_to; + second[..second_to] + .copy_from_slice(&data[data_first_to..data_second_to]); + } else { + let second_from = from - first.len(); + let second_to = to - first.len(); + let data_from = from - offset; + let data_to = to - offset; + second[second_from..second_to] + .copy_from_slice(&data[data_from..data_to]); + } + } + if self.missing.len() == 0 + || self.window_start.offset(self.missing[0].0) == 0 + { + Ok(StreamData::Ready) + } else { + Ok(StreamData::NotReady) + } + } +} diff --git a/src/connection/stream/uud/tests.rs b/src/connection/stream/uud/tests.rs new file mode 100644 index 0000000..634853e --- /dev/null +++ b/src/connection/stream/uud/tests.rs @@ -0,0 +1,249 @@ +use crate::{ + connection::stream::{self, Chunk, uud::*}, + enc::Random, +}; + +#[::tracing_test::traced_test] +#[test] +fn test_stream_uud_sequential() { + let rand = Random::new(); + let mut uud = UnreliableUnorderedDatagram::with_window_size(&rand, 1048576); + + let mut data = Vec::with_capacity(1024); + data.resize(data.capacity(), 0); + rand.fill(&mut data[..]); + + let start = uud.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..512], + }; + let got = uud.get(); + assert!(&got[..] == &[], "uud: got data?"); + let _ = uud.recv(chunk); + let got = uud.get(); + assert!( + &data[..512] == &got[..], + "UUD1: DIFF: {:?} {:?}", + &data[..512].len(), + &got[..].len() + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: true, + sequence: start + 512, + data: &data[512..], + }; + let _ = uud.recv(chunk); + let got = uud.get(); + assert!( + &data[512..] == &got[..], + "UUD2: DIFF: {:?} {:?}", + &data[512..].len(), + &got[..].len() + ); +} + +#[::tracing_test::traced_test] +#[test] +fn test_stream_uud_retransmit() { + let rand = Random::new(); + let max_window: usize = 100; + let mut uud = + UnreliableUnorderedDatagram::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = uud.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start + 50, + data: &data[50..60], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start + 40, + data: &data[40..60], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start + 80, + data: &data[80..], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start + 50, + data: &data[50..90], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start + (max_window as u32), + data: &data[max_window..], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: true, + sequence: start + 90, + data: &data[90..max_window], + }; + let _ = uud.recv(chunk); + let got = uud.get(); + assert!( + &data[..max_window] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..max_window], + &got[..], + ); +} +#[::tracing_test::traced_test] +#[test] +fn test_stream_uud_rolling() { + let rand = Random::new(); + let max_window: usize = 100; + let mut uud = + UnreliableUnorderedDatagram::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = uud.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start + 50, + data: &data[50..100], + }; + let _ = uud.recv(chunk); + let got = uud.get(); + assert!( + &data[..40] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..40], + &got[..], + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start + 40, + data: &data[40..], + }; + let _ = uud.recv(chunk); + let got = uud.get(); + assert!( + &data[40..] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[40..], + &got[..], + ); +} +#[::tracing_test::traced_test] +#[test] +fn test_stream_uud_rolling_second_case() { + let rand = Random::new(); + let max_window: usize = 100; + let mut uud = + UnreliableUnorderedDatagram::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = uud.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start + 50, + data: &data[50..100], + }; + let _ = uud.recv(chunk); + let got = uud.get(); + assert!( + &data[..40] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..40], + &got[..], + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start + 40, + data: &data[40..100], + }; + let _ = uud.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start + 100, + data: &data[100..], + }; + let _ = uud.recv(chunk); + let got = uud.get(); + assert!( + &data[40..] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[40..], + &got[..], + ); +} diff --git a/src/connection/stream/uudl/mod.rs b/src/connection/stream/uudl/mod.rs new file mode 100644 index 0000000..aeebd9a --- /dev/null +++ b/src/connection/stream/uudl/mod.rs @@ -0,0 +1,43 @@ +//! Implementation of the Unreliable, Unordered, Datagram Limited +//! transmission model +//! +//! AKA: UDP-like. "Limited" because the data must fit in a single packet +//! + +use crate::connection::stream::{ + Chunk, Error, Sequence, SequenceStart, StreamData, +}; + +use ::std::collections::{BTreeMap, VecDeque}; + +#[cfg(test)] +mod tests; + +/// UnReliable, UnOrdered, Datagram, Limited to the packet size +/// AKA: UDP-like +#[derive(Debug)] +pub(crate) struct UnreliableUnorderedDatagramLimited { + received: VecDeque<(SequenceStart, Vec)>, +} + +impl UnreliableUnorderedDatagramLimited { + pub(crate) fn new() -> Self { + Self { + received: VecDeque::with_capacity(4), + } + } + pub(crate) fn get(&mut self) -> (SequenceStart, Vec) { + match self.received.pop_front() { + Some(data) => data, + None => (SequenceStart(Sequence::min()), Vec::new()), + } + } + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result { + if !chunk.flag_start || !chunk.flag_end { + return Err(Error::WrongFlags); + } + self.received + .push_back((SequenceStart(chunk.sequence), chunk.data.to_vec())); + Ok(StreamData::Ready) + } +} diff --git a/src/connection/stream/uudl/tests.rs b/src/connection/stream/uudl/tests.rs new file mode 100644 index 0000000..7814902 --- /dev/null +++ b/src/connection/stream/uudl/tests.rs @@ -0,0 +1,56 @@ +use crate::{ + connection::stream::{self, uudl::*, Chunk}, + enc::Random, +}; + +#[::tracing_test::traced_test] +#[test] +fn test_stream_uudl_sequential() { + let rand = Random::new(); + let mut uudl = UnreliableUnorderedDatagramLimited::new(); + + let mut data = Vec::with_capacity(1024); + data.resize(data.capacity(), 0); + rand.fill(&mut data[..]); + + //let start = uudl.window_start.0; + let start = Sequence( + ::core::num::Wrapping(0) + ); + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: true, + sequence: start, + data: &data[..512], + }; + let got = uudl.get().1; + assert!(&got[..] == &[], "uudl: got data?"); + let _ = uudl.recv(chunk); + let got = uudl.get().1; + assert!( + &data[..512] == &got[..], + "UUDL1: DIFF: {:?} {:?}", + &data[..512].len(), + &got[..].len() + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: true, + sequence: start + 512, + data: &data[512..], + }; + let _ = uudl.recv(chunk); + let got = uudl.get().1; + assert!( + &data[512..] == &got[..], + "UUDL2: DIFF: {:?} {:?}", + &data[512..].len(), + &got[..].len() + ); + let got = uudl.get().1; + assert!(&got[..] == &[], "uudl: got data?"); +} + diff --git a/src/dnssec/mod.rs b/src/dnssec/mod.rs index d1128c1..9013ea6 100644 --- a/src/dnssec/mod.rs +++ b/src/dnssec/mod.rs @@ -87,14 +87,11 @@ impl Dnssec { )); } - let resolver = match TokioAsyncResolver::tokio(config, opts) { - Ok(resolver) => resolver, - Err(e) => return Err(Error::Setup(e.to_string())), - }; + let resolver = TokioAsyncResolver::tokio(config, opts); Ok(Self { resolver }) } - const TXT_RECORD_START: &str = "v=Fenrir1 "; + const TXT_RECORD_START: &'static str = "v=Fenrir1 "; /// Get the fenrir data for a domain pub async fn resolv(&self, domain: &Domain) -> ::std::io::Result { use ::trust_dns_client::rr::Name; diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 4958a38..240f889 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -162,7 +162,8 @@ impl KeyExchangeKind { ) -> Result<(ExchangePrivKey, ExchangePubKey), Error> { match self { KeyExchangeKind::X25519DiffieHellman => { - let raw_priv = ::x25519_dalek::StaticSecret::new(rnd); + let raw_priv = + ::x25519_dalek::StaticSecret::random_from_rng(rnd); let pub_key = ExchangePubKey::X25519( ::x25519_dalek::PublicKey::from(&raw_priv), ); diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index 7ba885b..8d47ff3 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -71,7 +71,7 @@ impl Hkdf { // Hack & tricks: // HKDF are pretty important, but this lib don't zero out the data. // we can't use #[derive(Zeroing)] either. -// So we craete a union with a Zeroing object, and drop both manually. +// So we create a union with a Zeroing object, and drop the zeroable buffer. // TODO: move this to Hkdf instead of Sha3 @@ -88,8 +88,7 @@ impl Drop for HkdfInner { fn drop(&mut self) { #[allow(unsafe_code)] unsafe { - drop(&mut self.hkdf); - drop(&mut self.zeroable); + ::core::mem::ManuallyDrop::drop(&mut self.zeroable); } } } diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 663c72d..6f919db 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -74,7 +74,7 @@ impl ::rand_core::RngCore for &Random { ) -> Result<(), ::rand_core::Error> { match self.rnd.fill(dest) { Ok(()) => Ok(()), - Err(e) => Err(::rand_core::Error::new(e)), + Err(e) => Err(::rand_core::Error::new(e.to_string())), } } } diff --git a/src/enc/tests.rs b/src/enc/tests.rs index ddd4125..b3fd73f 100644 --- a/src/enc/tests.rs +++ b/src/enc/tests.rs @@ -70,7 +70,7 @@ fn test_encrypt_decrypt() { let encrypt_to = encrypt_from + resp.encrypted_length(nonce_len, tag_len); let h_resp = - Handshake::new(handshake::Data::DirSync(dirsync::DirSync::Resp(resp))); + Handshake::new(Data::DirSync(dirsync::DirSync::Resp(resp))); let mut bytes = Vec::::with_capacity( h_resp.len(cipher.nonce_len(), cipher.tag_len()), @@ -119,7 +119,7 @@ fn test_encrypt_decrypt() { } }; // reparse - if let handshake::Data::DirSync(dirsync::DirSync::Resp(r_a)) = + if let Data::DirSync(dirsync::DirSync::Resp(r_a)) = &mut deserialized.data { let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 6102fde..dafcb5b 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -18,7 +18,7 @@ pub(crate) struct ThreadTracker { } pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration = - if cfg!(linux) || cfg!(macos) { + if cfg!(target_os = "linux") || cfg!(target_os = "macos") { ::std::time::Duration::from_millis(1) } else { // windows diff --git a/src/inner/worker.rs b/src/inner/worker.rs index f344bb6..f4101b8 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -52,7 +52,7 @@ pub struct ConnData { /// Connection tracking information pub conn: ConnTracker, /// received data, for each stream - pub data: Vec<(stream::ID, Vec)>, + pub data: Vec<(stream::ID, Vec)>, //FIXME: ChunkOwned } /// Connection event. Mostly used to give the data to the user diff --git a/src/tests.rs b/src/tests.rs index ff28ae6..fe7eea2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -17,7 +17,7 @@ async fn test_connection_dirsync() { } }; let cfg_client = { - let mut cfg = config::Config::default(); + let mut cfg = Config::default(); cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap()); cfg }; -- 2.47.2 From 9ec52a0151fbda75af7976b954a063f67647c192 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 28 Mar 2025 18:35:44 +0100 Subject: [PATCH 8/8] [stream] fragment, uud free/get Signed-off-by: Luca Fulchir --- flake.lock | 18 +- src/connection/stream/mod.rs | 36 +++- src/connection/stream/rob/mod.rs | 27 +-- src/connection/stream/uud/mod.rs | 348 +++++++++++++++++++------------ 4 files changed, 273 insertions(+), 156 deletions(-) diff --git a/flake.lock b/flake.lock index 01b37d0..8ee4474 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1741862977, - "narHash": "sha256-prZ0M8vE/ghRGGZcflvxCu40ObKaB+ikn74/xQoNrGQ=", + "lastModified": 1742751704, + "narHash": "sha256-rBfc+H1dDBUQ2mgVITMGBPI1PGuCznf9rcWX/XIULyE=", "owner": "nixos", "repo": "nixpkgs", - "rev": "cdd2ef009676ac92b715ff26630164bb88fec4e0", + "rev": "f0946fa5f1fb876a9dc2e1850d9d3a4e3f914092", "type": "github" }, "original": { @@ -36,11 +36,11 @@ }, "nixpkgs-unstable": { "locked": { - "lastModified": 1741851582, - "narHash": "sha256-cPfs8qMccim2RBgtKGF+x9IBCduRvd/N5F4nYpU0TVE=", + "lastModified": 1742889210, + "narHash": "sha256-hw63HnwnqU3ZQfsMclLhMvOezpM7RSB0dMAtD5/sOiw=", "owner": "nixos", "repo": "nixpkgs", - "rev": "6607cf789e541e7873d40d3a8f7815ea92204f32", + "rev": "698214a32beb4f4c8e3942372c694f40848b360d", "type": "github" }, "original": { @@ -65,11 +65,11 @@ ] }, "locked": { - "lastModified": 1742005800, - "narHash": "sha256-6wuOGWkyW6R4A6Th9NMi6WK2jjddvZt7V2+rLPk6L3o=", + "lastModified": 1742956365, + "narHash": "sha256-Slrqmt6kJ/M7Z/ce4ebQWsz2aeEodrX56CsupOEPoz0=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "028cd247a6375f83b94adc33d83676480fc9c294", + "rev": "a0e3395c63cdbc9c1ec17915f8328c077c79c4a1", "type": "github" }, "original": { diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index c63c3c3..7b92131 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -29,6 +29,38 @@ pub enum Kind { UUDL, } +/// Tracking for a contiguous set of data +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum Fragment { + /// Beginning, no end + Start((SequenceStart, SequenceEnd)), + /// Neither beginning nor end + Middle((SequenceStart, SequenceEnd)), + /// No beginning, but with end + End((SequenceStart, SequenceEnd)), + /// both beginning and end, waiting to be delivered to the user + Ready((SequenceStart, SequenceEnd)), + /// both beginning and end, already delivered to the user + Delivered((SequenceStart, SequenceEnd)), + /// both beginning and end, data might not be available anymore + Deallocated((SequenceStart, SequenceEnd)), +} + +impl Fragment { + // FIXME: sequence start/end? + /// extract the sequences from the fragment + pub fn get_seqs(&self) -> (SequenceStart, SequenceEnd) { + match self { + Fragment::Start((f, t)) + | Fragment::Middle((f, t)) + | Fragment::End((f, t)) + | Fragment::Ready((f, t)) + | Fragment::Delivered((f, t)) + | Fragment::Deallocated((f, t)) => (*f, *t), + } + } +} + /// Id of the stream #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct ID(pub u16); @@ -60,7 +92,7 @@ impl ChunkLen { } //TODO: make pub? -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) struct SequenceStart(pub(crate) Sequence); impl SequenceStart { pub(crate) fn offset(&self, seq: Sequence) -> usize { @@ -86,7 +118,7 @@ impl ::core::ops::AddAssign for SequenceStart { } // SequenceEnd is INCLUSIVE -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] pub(crate) struct SequenceEnd(pub(crate) Sequence); impl ::core::ops::Add for SequenceEnd { diff --git a/src/connection/stream/rob/mod.rs b/src/connection/stream/rob/mod.rs index 901dc69..6626b08 100644 --- a/src/connection/stream/rob/mod.rs +++ b/src/connection/stream/rob/mod.rs @@ -26,7 +26,7 @@ impl ReliableOrderedBytestream { pub(crate) fn new(rand: &Random) -> Self { let window_len = 1048576; // 1MB. should be enough for anybody. (lol) let window_start = SequenceStart(Sequence::new(rand)); - let window_end = SequenceEnd(window_start.0 +(window_len - 1)); + let window_end = SequenceEnd(window_start.0 + (window_len - 1)); let mut data = Vec::with_capacity(window_len as usize); data.resize(data.capacity(), 0); @@ -40,13 +40,13 @@ impl ReliableOrderedBytestream { } pub(crate) fn with_window_size(rand: &Random, size: u32) -> Self { assert!( - size < Sequence::max().0 .0, + size < Sequence::max().0.0, "Max window size is {}", - Sequence::max().0 .0 + Sequence::max().0.0 ); let window_len = size; let window_start = SequenceStart(Sequence::new(rand)); - let window_end = SequenceEnd(window_start.0 +(window_len - 1)); + let window_end = SequenceEnd(window_start.0 + (window_len - 1)); let mut data = Vec::with_capacity(window_len as usize); data.resize(data.capacity(), 0); @@ -67,10 +67,8 @@ impl ReliableOrderedBytestream { let mut ret = Vec::with_capacity(self.data.len()); ret.extend_from_slice(first); ret.extend_from_slice(second); - self.window_start = - self.window_start + (ret.len() as u32); - self.window_end = - self.window_end + (ret.len() as u32); + self.window_start = self.window_start + (ret.len() as u32); + self.window_end = self.window_end + (ret.len() as u32); self.data.clear(); return ret; } @@ -78,10 +76,8 @@ impl ReliableOrderedBytestream { let last_missing_idx = self.missing.len() - 1; let mut last_missing = &mut self.missing[last_missing_idx]; last_missing.1 = last_missing.1 + (data_len as u32); - self.window_start = - self.window_start + (data_len as u32); - self.window_end = - self.window_end + (data_len as u32); + self.window_start = self.window_start + (data_len as u32); + self.window_end = self.window_end + (data_len as u32); let mut ret = Vec::with_capacity(data_len); let (first, second) = self.data[..].split_at(self.pivot as usize); @@ -141,7 +137,7 @@ impl ReliableOrderedBytestream { // [....chunk....] // [...missing...] copy_ranges.push((missing_from, offset_end)); - el.0 +=((offset_end - missing_from) + 1) as u32; + el.0 += ((offset_end - missing_from) + 1) as u32; } } else if missing_from < offset { if missing_to > offset_end { @@ -157,8 +153,7 @@ impl ReliableOrderedBytestream { // [....chunk....] // [...missing...] copy_ranges.push((offset, (missing_to - 0))); - el.1 = - el.0 + (((offset_end - missing_from) - 1) as u32); + el.1 = el.0 + (((offset_end - missing_from) - 1) as u32); } } } @@ -171,7 +166,7 @@ impl ReliableOrderedBytestream { } self.missing.append(&mut to_add); self.missing - .sort_by(|(from_a, _), (from_b, _)| from_a.0 .0.cmp(&from_b.0 .0)); + .sort_by(|(from_a, _), (from_b, _)| from_a.0.0.cmp(&from_b.0.0)); // copy only the missing data let (first, second) = self.data[..].split_at_mut(self.pivot as usize); diff --git a/src/connection/stream/uud/mod.rs b/src/connection/stream/uud/mod.rs index f14e1f7..ebde42e 100644 --- a/src/connection/stream/uud/mod.rs +++ b/src/connection/stream/uud/mod.rs @@ -7,80 +7,70 @@ use crate::{ connection::stream::{ - Chunk, Error, Sequence, SequenceEnd, SequenceStart, StreamData, + Chunk, Error, Fragment, Sequence, SequenceEnd, SequenceStart, + StreamData, }, enc::Random, }; +use ::core::{ + cmp::{self, Ordering}, + marker::PhantomData, + num::Wrapping, + ops, +}; use ::std::collections::{BTreeMap, VecDeque}; #[cfg(test)] mod tests; -#[derive(Debug, PartialEq, Eq)] -enum Fragment { - Start((Sequence, Sequence)), - Middle((Sequence, Sequence)), - End((Sequence, Sequence)), - Full((Sequence, Sequence)), - Delivered((Sequence, Sequence)), -} - -impl Fragment { - // FIXME: sequence start/end? - fn get_seqs(&self) -> (Sequence, Sequence) { - match self { - Fragment::Start((f, t)) - | Fragment::Middle((f, t)) - | Fragment::End((f, t)) - | Fragment::Full((f, t)) - | Fragment::Delivered((f, t)) => (*f, *t), - } - } - fn is_start(&self) -> bool { - match self { - Fragment::Start(_) | Fragment::Full(_) | Fragment::Delivered(_) => { - true - } - Fragment::End(_) | Fragment::Middle(_) => false, - } - } - fn is_end(&self) -> bool { - match self { - Fragment::End(_) | Fragment::Full(_) | Fragment::Delivered(_) => { - true - } - Fragment::Start(_) | Fragment::Middle(_) => false, - } - } -} -/* -impl ::std::cmp::PartialEq for Fragment { - fn eq(&self, other: &Sequence) -> bool { - self.get_seq() == *other - } -} - -impl ::std::cmp::PartialOrd for Fragment { - fn partial_cmp(&self, other: &Sequence) -> Option<::std::cmp::Ordering> { - Some(self.get_seq().cmp(other)) - } -} - -impl ::std::cmp::PartialOrd for Fragment { - fn partial_cmp(&self, other: &Fragment) -> Option<::std::cmp::Ordering> { - Some(self.get_seq().cmp(&other.get_seq())) - } -} -impl Ord for Fragment { - fn cmp(&self, other: &Fragment) -> ::std::cmp::Ordering { - self.get_seq().cmp(&other.get_seq()) - } -} -*/ - type Timer = u64; +pub struct Data<'a> { + data_first: &'a mut [u8], + data_second: &'a mut [u8], + pub from: SequenceStart, + //pub(crate) stream: &'a Uud, + pub(crate) stream: ::std::ptr::NonNull, + _not_send_sync: PhantomData<*const ()>, +} + +impl<'a> Drop for Data<'a> { + fn drop(&mut self) { + // safe because we are !Send + #[allow(unsafe_code)] + unsafe { + let uud = self.stream.as_mut(); + uud.free( + self.from, + (self.data_first.len() + self.data_second.len()) as u32, + ); + } + } +} + +impl<'a> ops::Index for Data<'a> { + type Output = u8; + + fn index(&self, index: usize) -> &Self::Output { + let first_len = self.data_first.len(); + if index < first_len { + return &self.data_first[index]; + } + return &self.data_second[index - first_len]; + } +} + +impl<'a> ops::IndexMut for Data<'a> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + let first_len = self.data_first.len(); + if index < first_len { + return &mut self.data_first[index]; + } + return &mut self.data_second[index - first_len]; + } +} + pub(crate) struct Uud { pub(crate) window_start: SequenceStart, window_end: SequenceEnd, @@ -128,6 +118,78 @@ impl Uud { pub(crate) fn window_size(&self) -> u32 { self.data.len() as u32 } + pub(crate) fn get(&mut self) -> Option { + let self_ptr = ::std::ptr::NonNull::new(self).unwrap(); + for track in self.track.iter_mut() { + if let Fragment::Ready((start, end)) = track.0 { + let data_from = (self.window_start.offset(start.0) + + self.pivot as usize) + % self.data.len(); + let data_to = (self.window_start.offset(end.0) + + self.pivot as usize) + % self.data.len(); + + track.0 = Fragment::Delivered((start, end)); + let first: &mut [u8]; + let second: &mut [u8]; + if data_from < data_to { + let (tmp_first, tmp_second) = + self.data.split_at_mut(data_to); + first = &mut tmp_first[data_from..]; + second = &mut tmp_second[0..0]; + } else { + let (tmp_second, tmp_first) = + self.data.split_at_mut(self.pivot as usize); + first = &mut tmp_first[(data_from - self.pivot as usize)..]; + second = &mut tmp_second[..data_to]; + } + + return Some(Data { + from: start, + data_first: first, + data_second: second, + stream: self_ptr, + _not_send_sync: PhantomData::default(), + }); + } + } + None + } + pub(crate) fn free(&mut self, from: SequenceStart, len: u32) { + if !from.0.is_between(self.window_start, self.window_end) { + return; + } + let mut first_keep = 0; + let mut last_sequence = self.window_start.0; + let mut deallocated = false; + for (idx, track) in self.track.iter_mut().enumerate() { + if let Fragment::Delivered((start, to)) = track.0 { + if start == from && to.0 == from.0 + len { + track.0 = Fragment::Deallocated((start, to)); + deallocated = true; + if idx == first_keep { + first_keep = idx + 1; + last_sequence = to.0; + continue; + } + } + } + if idx == first_keep { + if let Fragment::Deallocated((_, to)) = track.0 { + first_keep = idx + 1; + last_sequence = to.0; + continue; + } + } + if deallocated { + break; + } + } + self.track.drain(..first_keep); + self.pivot = ((self.pivot as usize + + self.window_start.offset(last_sequence)) + % self.data.len()) as u32; + } pub(crate) fn recv(&mut self, chunk: Chunk) -> Result { let chunk_to = chunk.sequence + chunk.data.len() as u32; if !chunk @@ -164,16 +226,14 @@ impl Uud { let mut last_usable = self.window_end.0; let mut ret = StreamData::NotReady; let mut copy_data_idx_from = 0; - // FIXME: and on receiving first fragment after the second, or empty - // track? for (idx, (fragment, _)) in self.track.iter_mut().enumerate().rev() { let (from, to) = fragment.get_seqs(); let to_next = to + 1; - match to_next.cmp_in_window(self.window_start, chunk.sequence) { - ::core::cmp::Ordering::Equal => { + match to_next.0.cmp_in_window(self.window_start, chunk.sequence) { + Ordering::Equal => { // `chunk` is immediately after `fragment` if !chunk_to.is_between( - SequenceStart(to_next), + SequenceStart(to_next.0), SequenceEnd(last_usable), ) { return Err(Error::Reconstructing); @@ -186,7 +246,7 @@ impl Uud { return Err(Error::WrongFlags); } if chunk_flag_end { - *fragment = Fragment::Full(( + *fragment = Fragment::Ready(( from, to + (data.len() as u32), )); @@ -217,16 +277,23 @@ impl Uud { as usize; } Fragment::End(_) - | Fragment::Full(_) - | Fragment::Delivered(_) => { + | Fragment::Ready(_) + | Fragment::Delivered(_) + | Fragment::Deallocated(_) => { if !chunk.flag_start { return Err(Error::WrongFlags); } let toinsert = if chunk_flag_end { ret = StreamData::Ready; - Fragment::Full((chunk.sequence, chunk_to)) + Fragment::Ready(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } else { - Fragment::Start((chunk.sequence, chunk_to)) + Fragment::Start(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) }; self.track.insert(idx + 1, (toinsert, 0)); copy_data_idx_from = @@ -236,11 +303,11 @@ impl Uud { } break; } - ::core::cmp::Ordering::Less => { + Ordering::Less => { // there is a data hole between `chunk` and `fragment` if !chunk_to.is_between( - SequenceStart(to_next), + SequenceStart(to_next.0), SequenceEnd(last_usable), ) { return Err(Error::Reconstructing); @@ -248,15 +315,27 @@ impl Uud { let toinsert = if chunk.flag_start { if chunk_flag_end { ret = StreamData::Ready; - Fragment::Full((chunk.sequence, chunk_to)) + Fragment::Ready(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } else { - Fragment::Start((chunk.sequence, chunk_to)) + Fragment::Start(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } } else { if chunk_flag_end { - Fragment::End((chunk.sequence, chunk_to)) + Fragment::End(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } else { - Fragment::Middle((chunk.sequence, chunk_to)) + Fragment::Middle(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } }; self.track.insert(idx + 1, (toinsert, 0)); @@ -264,12 +343,12 @@ impl Uud { chunk.sequence.diff_from(self.window_start.0) as usize; break; } - ::core::cmp::Ordering::Greater => { + Ordering::Greater => { // to_next > chunk.sequence // `fragment` is too new, need to look at older ones - if from.cmp_in_window(self.window_start, chunk.sequence) - != ::core::cmp::Ordering::Greater + if from.0.cmp_in_window(self.window_start, chunk.sequence) + != Ordering::Greater { // to_next > chunk.sequence >= from // overlapping not yet allowed @@ -277,44 +356,44 @@ impl Uud { } if idx == 0 { // check if we can add before everything - if chunk_to == from { - if fragment.is_start() { - if chunk_flag_end { - // add, don't merge - } else { - //fragment.start, but !chunk.end - return Err(Error::WrongFlags); - } - } else { - if chunk_flag_end { - //chunk.end but !fragment.start - return Err(Error::WrongFlags); - } else { + if chunk_to == from.0 { + match fragment { + Fragment::Middle(_) => { if chunk.flag_start { - if fragment.is_end() { - *fragment = Fragment::Full(( - chunk.sequence, - to, - )); - ret = StreamData::Ready; - } else { - *fragment = Fragment::Start(( - chunk.sequence, - to, - )); - } + *fragment = Fragment::Start(( + SequenceStart(chunk.sequence), + to, + )); } else { - if fragment.is_end() { - *fragment = Fragment::End(( - chunk.sequence, - to, - )); - } else { - *fragment = Fragment::Middle(( - chunk.sequence, - to, - )); - } + *fragment = Fragment::Middle(( + SequenceStart(chunk.sequence), + to, + )); + } + } + Fragment::End(_) => { + if chunk.flag_start { + *fragment = Fragment::Ready(( + SequenceStart(chunk.sequence), + to, + )); + ret = StreamData::Ready; + } else { + *fragment = Fragment::End(( + SequenceStart(chunk.sequence), + to, + )); + } + } + Fragment::Start(_) + | Fragment::Ready(_) + | Fragment::Delivered(_) + | Fragment::Deallocated(_) => { + if chunk_flag_end { + // add, don't merge + } else { + // fragment.start, but !chunk.end + return Err(Error::WrongFlags); } } } @@ -327,15 +406,27 @@ impl Uud { let toinsert = if chunk.flag_start { if chunk_flag_end { ret = StreamData::Ready; - Fragment::Full((chunk.sequence, chunk_to)) + Fragment::Ready(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } else { - Fragment::Start((chunk.sequence, chunk_to)) + Fragment::Start(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } } else { if chunk_flag_end { - Fragment::End((chunk.sequence, chunk_to)) + Fragment::End(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } else { - Fragment::Middle((chunk.sequence, chunk_to)) + Fragment::Middle(( + SequenceStart(chunk.sequence), + SequenceEnd(chunk_to), + )) } }; self.track.insert(0, (toinsert, 0)); @@ -344,7 +435,7 @@ impl Uud { as usize; break; } - last_usable = from - 1; + last_usable = from.0 - 1; } } } @@ -357,7 +448,7 @@ impl Uud { let data_pivot = self.data.len() - data_idx_from; let (first, second) = data.split_at(data_pivot); self.data[data_idx_from..].copy_from_slice(&first); - self.data[..data_idx_to].copy_from_slice(&data); + self.data[..data_idx_to].copy_from_slice(&second); } Ok(ret) } @@ -419,9 +510,8 @@ impl UnreliableUnorderedDatagram { ret.extend_from_slice(first); ret.extend_from_slice(second); self.window_start += ret.len() as u32; - self.window_end = SequenceEnd(Sequence( - ::core::num::Wrapping::(ret.len() as u32), - )); + self.window_end = + SequenceEnd(Sequence(Wrapping::(ret.len() as u32))); self.data.clear(); return ret; } @@ -434,7 +524,7 @@ impl UnreliableUnorderedDatagram { let mut ret = Vec::with_capacity(data_len); let (first, second) = self.data[..].split_at(self.pivot as usize); - let first_len = ::core::cmp::min(data_len, first.len()); + let first_len = cmp::min(data_len, first.len()); let second_len = data_len - first_len; ret.extend_from_slice(&first[..first_len]); @@ -527,7 +617,7 @@ impl UnreliableUnorderedDatagram { let to = to + 1; if from <= first.len() { let first_from = from; - let first_to = ::core::cmp::min(first.len(), to); + let first_to = cmp::min(first.len(), to); let data_first_from = from - offset; let data_first_to = first_to - offset; first[first_from..first_to] -- 2.47.2