diff --git a/TODO b/TODO index 85a5331..33fa189 100644 --- a/TODO +++ b/TODO @@ -1,4 +1,6 @@ * Wrapping for everything that wraps (sigh) * track user connection (add u64 from user) -* split API in LocalThread and ThreadSafe - * split send/recv API in Centralized, Decentralized +* API plit + * split API in ThreadLocal, ThreadSafe + * split send/recv API in Centralized, Connection + * all re wrappers on ThreadLocal-Centralized diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 07af86e..e205e33 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -7,7 +7,7 @@ pub mod stream; use ::core::num::Wrapping; use ::std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, VecDeque}, vec::Vec, }; @@ -26,14 +26,21 @@ use crate::{ inner::{worker, ThreadTracker}, }; -/// Connaction errors +/// Connection errors #[derive(::thiserror::Error, Debug, Copy, Clone)] -pub(crate) enum Error { +pub enum Error { /// Can't decrypt packet #[error("Decrypt error: {0}")] Decrypt(#[from] crate::enc::Error), + /// Error in parsing a packet realated to the connection #[error("Chunk parsing: {0}")] Parse(#[from] stream::Error), + /// No such Connection + #[error("No suck connection")] + NoSuchConnection, + /// No such Stream + #[error("No suck Stream")] + NoSuchStream, } /// Fenrir Connection ID @@ -141,32 +148,53 @@ impl ProtocolVersion { } } +/// Connection tracking id. Set by the user +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +pub struct UserTracker(pub ::core::num::NonZeroU64); + /// Unique tracker of connections #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] -pub struct ConnTracker(Wrapping); -impl ConnTracker { +pub struct LibTracker(Wrapping); +impl LibTracker { pub(crate) fn new(start: u16) -> Self { Self(Wrapping(start as u64)) } pub(crate) fn advance(&mut self, amount: u16) -> Self { let old = self.0; self.0 = self.0 + Wrapping(amount as u64); - ConnTracker(old) + LibTracker(old) } } +/// Collection of connection tracking, but user-given and library generated +#[derive(Debug, Copy, Clone)] +pub struct ConnTracker { + /// Optional tracker set by the user + pub user: Option, + /// library generated tracker. Unique and non-repeating + pub(crate) lib: LibTracker, +} + +impl PartialEq for ConnTracker { + fn eq(&self, other: &Self) -> bool { + self.lib == other.lib + } +} +impl Eq for ConnTracker {} /// Connection to an Authentication Server -#[derive(Debug)] -pub struct AuthSrvConn(pub Conn); +#[derive(Debug, Copy, Clone)] +pub struct AuthSrvConn(pub ConnTracker); /// Connection to a service -#[derive(Debug)] -pub struct ServiceConn(pub Conn); +#[derive(Debug, Copy, Clone)] +pub struct ServiceConn(pub ConnTracker); +/* + * TODO: only on Thread{Local,Safe}::Connection oriented flows /// The connection, as seen from a user of libFenrir #[derive(Debug)] pub struct Conn { pub(crate) queue: ::async_channel::Sender, - pub(crate) fast: ConnTracker, + pub(crate) tracker: ConnTracker, } impl Conn { @@ -176,14 +204,15 @@ impl Conn { use crate::inner::worker::Work; let _ = self .queue - .send(Work::UserSend((self.tracker(), stream, data))) + .send(Work::UserSend((self.tracker.lib, stream, data))) .await; } /// Get the library tracking id pub fn tracker(&self) -> ConnTracker { - self.fast + self.tracker } } +*/ /// Role: track the connection direction /// @@ -208,15 +237,10 @@ enum TimerKind { } pub(crate) enum Enqueue { - NoSuchStream, TimerWait, Immediate(::tokio::time::Instant), } -/// Connection tracking id. Set by the user -#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)] -pub struct UserTracker(pub ::core::num::NonZeroU64); - /// A single connection and its data #[derive(Debug)] pub(crate) struct Connection { @@ -227,6 +251,7 @@ pub(crate) struct Connection { /// User-managed id to track this connection /// the user can set this to better track this connection pub(crate) user_tracker: Option, + pub(crate) lib_tracker: LibTracker, /// Sending address pub(crate) send_addr: UdpClient, /// The main hkdf used for all secrets in this connection @@ -242,6 +267,7 @@ pub(crate) struct Connection { last_stream_sent: stream::ID, /// receive queue for each Stream recv_queue: BTreeMap, + streams_ready: VecDeque, } impl Connection { @@ -267,6 +293,7 @@ impl Connection { id_recv: IDRecv(ID::Handshake), id_send: IDSend(ID::Handshake), user_tracker: None, + lib_tracker: LibTracker::new(0), // will be overwritten send_addr: UdpClient(SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), @@ -280,8 +307,24 @@ impl Connection { send_queue: BTreeMap::new(), last_stream_sent: stream::ID(0), recv_queue: BTreeMap::new(), + streams_ready: VecDeque::with_capacity(4), } } + pub(crate) fn get_data(&mut self) -> Option)>> { + if self.streams_ready.is_empty() { + return None; + } + let ret_len = self.streams_ready.len(); + let mut ret = Vec::with_capacity(ret_len); + while let Some(stream_id) = self.streams_ready.pop_front() { + let stream = match self.recv_queue.get_mut(&stream_id) { + Some(stream) => stream, + None => continue, + }; + ret.push((stream_id, stream.get())); + } + Some(ret) + } pub(crate) fn recv( &mut self, mut udp: crate::RawUdp, @@ -316,7 +359,12 @@ impl Connection { } }; match stream.recv(chunk) { - Ok(status) => data_ready = data_ready | status, + Ok(status) => { + if !self.streams_ready.contains(&stream_id) { + self.streams_ready.push_back(stream_id); + } + data_ready = data_ready | status; + } Err(e) => ::tracing::debug!("stream: {:?}: {:?}", stream_id, e), } } @@ -326,9 +374,9 @@ impl Connection { &mut self, stream: stream::ID, data: Vec, - ) -> Enqueue { + ) -> Result { let stream = match self.send_queue.get_mut(&stream) { - None => return Enqueue::NoSuchStream, + None => return Err(Error::NoSuchStream), Some(stream) => stream, }; stream.enqueue(data); @@ -348,7 +396,7 @@ impl Connection { TimerKind::SendData(old_timer) } }; - ret + Ok(ret) } pub(crate) fn write_pkt<'a>( &mut self, @@ -426,8 +474,8 @@ impl Connection { pub(crate) struct ConnList { thread_id: ThreadTracker, connections: Vec>, - user_tracker: BTreeMap, - last_tracked: ConnTracker, + user_tracker: BTreeMap, + last_tracked: LibTracker, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } @@ -444,32 +492,40 @@ impl ConnList { thread_id, connections: Vec::with_capacity(INITIAL_CAP), user_tracker: BTreeMap::new(), - last_tracked: ConnTracker(Wrapping(0)), + last_tracked: LibTracker(Wrapping(0)), ids_used: vec![bitmap_id], }; ret.connections.resize_with(INITIAL_CAP, || None); ret } - pub fn get_id_mut(&mut self, id: ID) -> Option<&mut Connection> { + pub fn get_id_mut(&mut self, id: ID) -> Result<&mut Connection, Error> { let conn_id = match id { - ID::Handshake => { - return None; - } ID::ID(conn_id) => conn_id, + ID::Handshake => { + return Err(Error::NoSuchConnection); + } }; let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; - (&mut self.connections[id_in_thread]).into() + if let Some(conn) = &mut self.connections[id_in_thread] { + Ok(conn) + } else { + return Err(Error::NoSuchConnection); + } } - pub fn get_mut(&mut self, tracker: ConnTracker) -> Option<&mut Connection> { + pub fn get_mut( + &mut self, + tracker: LibTracker, + ) -> Result<&mut Connection, Error> { let idx = if let Some(idx) = self.user_tracker.get(&tracker) { *idx } else { - return None; + return Err(Error::NoSuchConnection); }; - match &mut self.connections[idx] { - None => None, - Some(conn) => Some(conn), + if let Some(conn) = &mut self.connections[idx] { + Ok(conn) + } else { + return Err(Error::NoSuchConnection); } } pub fn len(&self) -> usize { @@ -481,7 +537,23 @@ impl ConnList { } /// Only *Reserve* a connection, /// without actually tracking it in self.connections + pub(crate) fn reserve_and_track<'a>( + &'a mut self, + mut conn: Connection, + ) -> (LibTracker, &'a mut Connection) { + let (id_conn, id_in_thread) = self.reserve_first_with_idx(); + conn.id_recv = id_conn; + let tracker = self.get_new_tracker(id_in_thread); + conn.lib_tracker = tracker; + self.connections[id_in_thread] = Some(conn); + (tracker, self.connections[id_in_thread].as_mut().unwrap()) + } + /// Only *Reserve* a connection, + /// without actually tracking it in self.connections pub(crate) fn reserve_first(&mut self) -> IDRecv { + self.reserve_first_with_idx().0 + } + fn reserve_first_with_idx(&mut self) -> (IDRecv, usize) { // uhm... bad things are going on here: // * id must be initialized, but only because: // * rust does not understand that after the `!found` id is always @@ -519,13 +591,13 @@ impl ConnList { let actual_id = ((id_in_thread as u64) * (self.thread_id.total as u64)) + (self.thread_id.id as u64); let new_id = IDRecv(ID::new_u64(actual_id)); - new_id + (new_id, id_in_thread) } /// NOTE: does NOT check if the connection has been previously reserved! pub(crate) fn track( &mut self, - conn: Connection, - ) -> Result { + mut conn: Connection, + ) -> Result { let conn_id = match conn.id_recv { IDRecv(ID::Handshake) => { return Err(()); @@ -534,17 +606,22 @@ impl ConnList { }; let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; + let tracker = self.get_new_tracker(id_in_thread); + conn.lib_tracker = tracker; self.connections[id_in_thread] = Some(conn); - let mut tracked; + Ok(tracker) + } + fn get_new_tracker(&mut self, id_in_thread: usize) -> LibTracker { + let mut tracker; loop { - tracked = self.last_tracked.advance(self.thread_id.total); - if self.user_tracker.get(&tracked).is_none() { + tracker = self.last_tracked.advance(self.thread_id.total); + if self.user_tracker.get(&tracker).is_none() { // like, never gonna happen, it's 64 bit - let _ = self.user_tracker.insert(tracked, id_in_thread); + let _ = self.user_tracker.insert(tracker, id_in_thread); break; } } - Ok(tracked) + tracker } pub(crate) fn remove(&mut self, id: IDRecv) { if let IDRecv(ID::ID(raw_id)) = id { diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 425cbd8..44ac3be 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -279,6 +279,11 @@ impl Stream { Tracker::ROB(tracker) => tracker.recv(chunk), } } + pub(crate) fn get(&mut self) -> Vec { + match &mut self.data { + Tracker::ROB(tracker) => tracker.get(), + } + } } /// Track what has been sent and what has been ACK'd from a stream diff --git a/src/inner/worker.rs b/src/inner/worker.rs index c71b6f2..f344bb6 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -12,7 +12,7 @@ use crate::{ packet::{self, Packet}, socket::{UdpClient, UdpServer}, stream, AuthSrvConn, ConnList, ConnTracker, Connection, IDSend, - ServiceConn, + LibTracker, ServiceConn, }, dnssec, enc::{ @@ -23,7 +23,7 @@ use crate::{ }, inner::ThreadTracker, }; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{collections::VecDeque, sync::Arc, vec::Vec}; /// This worker must be cpu-pinned use ::tokio::{ net::UdpSocket, @@ -46,14 +46,23 @@ pub(crate) struct ConnectInfo { // TODO: UserID, Token information } +/// return to the user the data received from a connection +#[derive(Debug, Clone)] +pub struct ConnData { + /// Connection tracking information + pub conn: ConnTracker, + /// received data, for each stream + pub data: Vec<(stream::ID, Vec)>, +} + /// Connection event. Mostly used to give the data to the user -#[derive(Debug, Eq, PartialEq, Clone)] +#[derive(Debug, Clone)] #[non_exhaustive] pub enum Event { /// Work loop has exited. nothing more to do End, /// Data from a connection - Data(Vec), + Data(ConnData), } pub(crate) enum Work { @@ -63,8 +72,8 @@ pub(crate) enum Work { Connect(ConnectInfo), DropHandshake(KeyID), Recv(RawUdp), - UserSend((ConnTracker, stream::ID, Vec)), - SendData((ConnTracker, ::tokio::time::Instant)), + UserSend((LibTracker, stream::ID, Vec)), + SendData((LibTracker, ::tokio::time::Instant)), } /// Actual worker implementation. @@ -83,6 +92,8 @@ pub struct Worker { queue_timeouts_send: mpsc::UnboundedSender, thread_channels: Vec<::async_channel::Sender>, connections: ConnList, + // connectsion untracker by the user. (users still needs to get(..) them) + untracked_connections: VecDeque, handshakes: handshake::Tracker, work_timers: super::Timers, } @@ -140,10 +151,82 @@ impl Worker { queue_timeouts_send, thread_channels: Vec::new(), connections: ConnList::new(thread_id), + untracked_connections: VecDeque::with_capacity(8), handshakes, work_timers: super::Timers::new(), }) } + /// return a handle to the worker that you can use to send data + /// The handle will enqueue work in the main worker and is thread-local safe + /// + /// While this does not require `&mut` on the `Worker`, everything + /// will be put in the work queue, + /// So you might have less immediate results in a few cases + pub fn handle(&self) -> Handle { + Handle { + queue: self.queue_sender.clone(), + } + } + /// change the UserTracker in the connection + /// + /// This is `unsafe` because you will be responsible for manually updating + /// any copy of the `ConnTracker` you might have cloned around + #[allow(unsafe_code)] + pub unsafe fn set_connection_tracker( + &mut self, + tracker: ConnTracker, + new_id: connection::UserTracker, + ) -> Result { + let conn = self.connections.get_mut(tracker.lib)?; + conn.user_tracker = Some(new_id); + Ok(ConnTracker { + lib: tracker.lib, + user: Some(new_id), + }) + } + /// Enqueue data to send + pub fn send( + &mut self, + tracker: LibTracker, + stream: stream::ID, + data: Vec, + ) -> Result<(), crate::Error> { + let conn = self.connections.get_mut(tracker)?; + conn.enqueue(stream, data)?; + Ok(()) + } + /// Returns new connections, if any + /// + /// You can provide an optional tracker, different from the library tracker. + /// + /// Differently from the library tracker, you can change this later on, + /// but you will be responsible to change it on every `ConnTracker` + /// you might have cloned elsewhere + pub fn try_get_connection( + &mut self, + tracker: Option, + ) -> Option { + let ret_tracker = ConnTracker { + lib: self.untracked_connections.pop_front()?, + user: None, + }; + match tracker { + Some(tracker) => { + #[allow(unsafe_code)] + match unsafe { + self.set_connection_tracker(ret_tracker, tracker) + } { + Ok(tracker) => Some(tracker), + Err(_) => { + // we had a connection, but it expired before the user + // remembered to get it. Just remove it from the queue. + None + } + } + } + None => Some(ret_tracker), + } + } /// Continuously loop and process work as needed pub async fn work_loop(&mut self) -> Result { @@ -441,27 +524,25 @@ impl Worker { } }; } - Work::Recv(pkt) => { - self.recv(pkt).await; - } + Work::Recv(pkt) => match self.recv(pkt).await { + Ok(event) => return Ok(event), + Err(_) => continue 'mainloop, + }, Work::UserSend((tracker, stream, data)) => { let conn = match self.connections.get_mut(tracker) { - None => continue, - Some(conn) => conn, + Ok(conn) => conn, + Err(_) => continue 'mainloop, }; use connection::Enqueue; - match conn.enqueue(stream, data) { - Enqueue::Immediate(instant) => { - let _ = self - .queue_sender - .send(Work::SendData((tracker, instant))) - .await; - } - Enqueue::TimerWait => {} - Enqueue::NoSuchStream => { - ::tracing::error!( - "Trying to send on unknown stream" - ); + if let Ok(enqueued) = conn.enqueue(stream, data) { + match enqueued { + Enqueue::Immediate(instant) => { + let _ = self + .queue_sender + .send(Work::SendData((tracker, instant))) + .await; + } + Enqueue::TimerWait => {} } } } @@ -484,8 +565,8 @@ impl Worker { let mut raw: Vec = Vec::with_capacity(1200); raw.resize(raw.capacity(), 0); let conn = match self.connections.get_mut(tracker) { - None => continue, - Some(conn) => conn, + Ok(conn) => conn, + Err(_) => continue, }; let pkt = match conn.write_pkt(&mut raw) { Ok(pkt) => pkt, @@ -506,7 +587,7 @@ impl Worker { Ok(Event::End) } /// Read and do stuff with the raw udp packet - async fn recv(&mut self, mut udp: RawUdp) { + async fn recv(&mut self, mut udp: RawUdp) -> Result { if udp.packet.id.is_handshake() { let handshake = match Handshake::deserialize( &udp.data[connection::ID::len()..], @@ -514,7 +595,7 @@ impl Worker { Ok(handshake) => handshake, Err(e) => { ::tracing::debug!("Handshake parsing: {}", e); - return; + return Err(()); } }; let action = match self.handshakes.recv_handshake( @@ -524,26 +605,34 @@ impl Worker { Ok(action) => action, Err(err) => { ::tracing::debug!("Handshake recv error {}", err); - return; + return Err(()); } }; self.recv_handshake(udp, action).await; + Err(()) } else { - self.recv_packet(udp); + self.recv_packet(udp) } } /// Receive a non-handshake packet - fn recv_packet(&mut self, udp: RawUdp) { + fn recv_packet(&mut self, udp: RawUdp) -> Result { let conn = match self.connections.get_id_mut(udp.packet.id) { - None => return, - Some(conn) => conn, + Ok(conn) => conn, + Err(_) => return Err(()), }; match conn.recv(udp) { - Ok(stream::StreamData::NotReady) => {} - Ok(stream::StreamData::Ready) => { - // + Ok(stream::StreamData::NotReady) => Err(()), + Ok(stream::StreamData::Ready) => Ok(Event::Data(ConnData { + conn: ConnTracker { + user: conn.user_tracker, + lib: conn.lib_tracker, + }, + data: conn.get_data().unwrap(), + })), + Err(e) => { + ::tracing::trace!("Conn Recv: {:?}", e.to_string()); + Err(()) } - Err(e) => ::tracing::trace!("Conn Recv: {:?}", e.to_string()), } } /// Receive an handshake packet @@ -625,6 +714,9 @@ impl Worker { // track connection let auth_id_recv = self.connections.reserve_first(); auth_conn.id_recv = auth_id_recv; + let (tracker, auth_conn) = + self.connections.reserve_and_track(auth_conn); + self.untracked_connections.push_back(tracker); let resp_data = dirsync::resp::Data { client_nonce: req_data.nonce, @@ -706,9 +798,9 @@ impl Worker { return; } }; - let authsrv_conn = AuthSrvConn(connection::Conn { - queue: self.queue_sender.clone(), - fast: track_auth_conn, + let authsrv_conn = AuthSrvConn(ConnTracker { + lib: track_auth_conn, + user: None, }); let mut service_conn = None; if cci.service_id != auth::SERVICEID_AUTH { @@ -748,9 +840,9 @@ impl Worker { return; } }; - service_conn = Some(ServiceConn(connection::Conn { - queue: self.queue_sender.clone(), - fast: track_serv_conn, + service_conn = Some(ServiceConn(ConnTracker { + lib: track_serv_conn, + user: None, })); } let _ = cci.answer.send(Ok(handshake::tracker::ConnectOk { @@ -786,3 +878,16 @@ impl Worker { let _res = src_sock.send_to(&data, client.0).await; } } + +/// Handle to send work asyncronously to the worker +#[derive(Debug, Clone)] +pub struct Handle { + queue: ::async_channel::Sender, +} + +impl Handle { + // TODO + // pub fn send(..) + // pub fn set_connection_id(..) + // try_get_new_connections() +} diff --git a/src/lib.rs b/src/lib.rs index 61af27a..347b646 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,15 +59,15 @@ pub enum Error { /// Handshake errors #[error("Handshake: {0:?}")] Handshake(#[from] handshake::Error), - /// Key error - #[error("key: {0:?}")] - Key(#[from] crate::enc::Error), /// Resolution problems. wrong or incomplete DNSSEC data #[error("DNSSEC resolution: {0}")] Resolution(String), /// Wrapper on encryption errors - #[error("Encrypt: {0}")] - Encrypt(enc::Error), + #[error("Crypto: {0}")] + Crypto(#[from] enc::Error), + /// Wrapper on connection errors + #[error("Connection: {0}")] + Connection(#[from] connection::Error), } pub(crate) enum StopWorking {