diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 58853c7..ea5dfcb 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -5,7 +5,11 @@ pub mod packet; pub mod socket; pub mod stream; -use ::std::{collections::HashMap, rc::Rc, vec::Vec}; +use ::core::num::Wrapping; +use ::std::{ + collections::{BTreeMap, HashMap}, + vec::Vec, +}; pub use crate::connection::{handshake::Handshake, packet::Packet}; @@ -125,12 +129,12 @@ impl ProtocolVersion { } } -#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)] -pub(crate) struct UserConnTracker(usize); +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)] +pub(crate) struct UserConnTracker(Wrapping); impl UserConnTracker { fn advance(&mut self) -> Self { let old = self.0; - self.0 = self.0 + 1; + self.0 = self.0 + Wrapping(1); UserConnTracker(old) } } @@ -151,8 +155,13 @@ pub struct Conn { impl Conn { /// Queue some data to be sent in this connection - pub fn send(&mut self, stream: stream::ID, _data: Vec) { - todo!() + // TODO: send_and_wait, that wait for recipient ACK + pub async fn send(&mut self, stream: stream::ID, data: Vec) { + use crate::inner::worker::Work; + let _ = self + .queue + .send(Work::UserSend((self.conn, stream, data))) + .await; } } @@ -160,15 +169,17 @@ impl Conn { #[derive(Debug)] pub(crate) struct Connection { /// Receiving Conn ID - pub id_recv: IDRecv, + pub(crate) id_recv: IDRecv, /// Sending Conn ID - pub id_send: IDSend, + pub(crate) id_send: IDSend, /// The main hkdf used for all secrets in this connection - pub hkdf: Hkdf, + hkdf: Hkdf, /// Cipher for decrypting data - pub cipher_recv: CipherRecv, + pub(crate) cipher_recv: CipherRecv, /// Cipher for encrypting data - pub cipher_send: CipherSend, + pub(crate) cipher_send: CipherSend, + /// send queue for each Stream + send_queue: BTreeMap>>, } /// Role: track the connection direction @@ -210,14 +221,22 @@ impl Connection { hkdf, cipher_recv, cipher_send, + send_queue: BTreeMap::new(), } } + pub(crate) fn send(&mut self, stream: stream::ID, data: Vec) { + let stream = match self.send_queue.get_mut(&stream) { + None => return, + Some(stream) => stream, + }; + stream.push(data); + } } pub(crate) struct ConnList { thread_id: ThreadTracker, - connections: Vec>>, - user_tracker: HashMap, + connections: Vec>, + user_tracker: BTreeMap, last_tracked: UserConnTracker, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, @@ -234,13 +253,27 @@ impl ConnList { let mut ret = Self { thread_id, connections: Vec::with_capacity(INITIAL_CAP), - user_tracker: HashMap::with_capacity(INITIAL_CAP), - last_tracked: UserConnTracker(0), + user_tracker: BTreeMap::new(), + last_tracked: UserConnTracker(Wrapping(0)), ids_used: vec![bitmap_id], }; ret.connections.resize_with(INITIAL_CAP, || None); ret } + pub fn get_mut( + &mut self, + tracker: UserConnTracker, + ) -> Option<&mut Connection> { + let idx = if let Some(idx) = self.user_tracker.get(&tracker) { + *idx + } else { + return None; + }; + match &mut self.connections[idx] { + None => None, + Some(conn) => Some(conn), + } + } pub fn len(&self) -> usize { let mut total: usize = 0; for bitmap in self.ids_used.iter() { @@ -293,7 +326,7 @@ impl ConnList { /// NOTE: does NOT check if the connection has been previously reserved! pub(crate) fn track( &mut self, - conn: Rc, + conn: Connection, ) -> Result { let conn_id = match conn.id_recv { IDRecv(ID::Handshake) => { @@ -304,8 +337,14 @@ impl ConnList { let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; self.connections[id_in_thread] = Some(conn); - let tracked = self.last_tracked.advance(); - let _ = self.user_tracker.insert(tracked, id_in_thread); + let mut tracked; + loop { + tracked = self.last_tracked.advance(); + if self.user_tracker.get(&tracked).is_none() { + let _ = self.user_tracker.insert(tracked, id_in_thread); + break; + } + } Ok(tracked) } pub(crate) fn remove(&mut self, id: IDRecv) { diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 1b56a2b..58dd76e 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -19,7 +19,7 @@ pub enum Kind { } /// Id of the stream -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct ID(pub u16); impl ID { diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 001ca16..e23d614 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -4,6 +4,10 @@ pub(crate) mod worker; +use crate::inner::worker::Work; +use ::std::{collections::BTreeMap, vec::Vec}; +use ::tokio::time::Instant; + /// Track the total number of threads and our index /// 65K cpus should be enough for anybody #[derive(Debug, Clone, Copy)] @@ -12,3 +16,83 @@ pub(crate) struct ThreadTracker { /// Note: starts from 1 pub id: u16, } + +pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration = + if cfg!(linux) || cfg!(macos) { + ::std::time::Duration::from_millis(1) + } else { + // windows + ::std::time::Duration::from_millis(16) + }; + +pub(crate) async fn set_minimum_sleep_resolution() { + let nanosleep = ::std::time::Duration::from_nanos(1); + let mut tests: usize = 3; + + while tests > 0 { + let pre_sleep = ::std::time::Instant::now(); + ::tokio::time::sleep(nanosleep).await; + let post_sleep = ::std::time::Instant::now(); + let slept_for = post_sleep - pre_sleep; + #[allow(unsafe_code)] + unsafe { + if slept_for < SLEEP_RESOLUTION { + SLEEP_RESOLUTION = slept_for; + } + } + tests = tests - 1; + } +} + +/// Sleeping has a higher resolution that we would like for packet pacing. +/// So we sleep for however log we need, then chunk up all the work here +/// we will end up chunking the work in SLEEP_RESOLUTION, then we will busy wait +/// for more precise timing +pub(crate) struct Timers { + times: BTreeMap, +} + +impl Timers { + pub(crate) fn new() -> Self { + Self { + times: BTreeMap::new(), + } + } + pub(crate) fn get_next(&self) -> ::tokio::time::Sleep { + match self.times.keys().next() { + Some(entry) => ::tokio::time::sleep_until((*entry).into()), + None => { + ::tokio::time::sleep(::std::time::Duration::from_secs(3600)) + } + } + } + /// Get all the work from now up until now + SLEEP_RESOLUTION + pub(crate) fn get_work(&mut self) -> Vec { + let now: ::tokio::time::Instant = ::std::time::Instant::now().into(); + let mut ret = Vec::with_capacity(4); + let mut count_rm = 0; + #[allow(unsafe_code)] + let next_instant = unsafe { now + SLEEP_RESOLUTION }; + let mut iter = self.times.iter_mut().peekable(); + loop { + match iter.peek() { + None => break, + Some(next) => { + if *next.0 > next_instant { + break; + } + } + } + let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0)); + let mut entry = iter.next().unwrap(); + ::core::mem::swap(&mut work, &mut entry.1); + ret.push(work); + count_rm = count_rm + 1; + } + while count_rm > 0 { + self.times.pop_first(); + count_rm = count_rm - 1; + } + ret + } +} diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 29ad7f7..0bb41d0 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,8 @@ use crate::{ }, packet::{self, Packet}, socket::{UdpClient, UdpServer}, - AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, + stream, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, + UserConnTracker, }, dnssec, enc::{ @@ -51,6 +52,7 @@ pub(crate) enum Work { Connect(ConnectInfo), DropHandshake(KeyID), Recv(RawUdp), + UserSend((UserConnTracker, stream::ID, Vec)), } /// Actual worker implementation. @@ -70,6 +72,7 @@ pub struct Worker { thread_channels: Vec<::async_channel::Sender>, connections: ConnList, handshakes: handshake::Tracker, + work_timers: super::Timers, } #[allow(unsafe_code)] @@ -126,12 +129,15 @@ impl Worker { thread_channels: Vec::new(), connections: ConnList::new(thread_id), handshakes, + work_timers: super::Timers::new(), }) } /// Continuously loop and process work as needed pub async fn work_loop(&mut self) { 'mainloop: loop { + let next_timer = self.work_timers.get_next(); + ::tokio::pin!(next_timer); let work = ::tokio::select! { tell_stopped = self.stop_working.recv() => { if let Ok(stop_ch) = tell_stopped { @@ -140,6 +146,13 @@ impl Worker { } break; } + () = &mut next_timer => { + let work_list = self.work_timers.get_work(); + for w in work_list.into_iter() { + let _ = self.queue_sender.send(w).await; + } + continue 'mainloop; + } maybe_timeout = self.queue.recv() => { match maybe_timeout { Ok(work) => work, @@ -419,6 +432,13 @@ impl Worker { Work::Recv(pkt) => { self.recv(pkt).await; } + Work::UserSend((tracker, stream, data)) => { + let conn = match self.connections.get_mut(tracker) { + None => return, + Some(conn) => conn, + }; + conn.send(stream, data); + } } } } @@ -596,21 +616,20 @@ impl Worker { let id_recv = conn.id_recv; let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server - let track_auth_conn = - match self.connections.track(conn.into()) { - Ok(track_auth_conn) => track_auth_conn, - Err(e) => { - ::tracing::error!( - "Could not track new auth srv connection" - ); - self.connections.remove(id_recv); - // FIXME: proper connection closing - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } - }; + let track_auth_conn = match self.connections.track(conn) { + Ok(track_auth_conn) => track_auth_conn, + Err(_) => { + ::tracing::error!( + "Could not track new auth srv connection" + ); + self.connections.remove(id_recv); + // FIXME: proper connection closing + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; let authsrv_conn = AuthSrvConn(connection::Conn { queue: self.queue_sender.clone(), conn: track_auth_conn, @@ -635,26 +654,25 @@ impl Worker { service_connection.id_recv = cci.service_connection_id; service_connection.id_send = IDSend(resp_data.service_connection_id); - let track_serv_conn = match self - .connections - .track(service_connection.into()) - { - Ok(track_serv_conn) => track_serv_conn, - Err(e) => { - ::tracing::error!( - "Could not track new service connection" - ); - self.connections - .remove(cci.service_connection_id); - // FIXME: proper connection closing - // FIXME: drop auth srv connection if we just - // established it - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } - }; + let track_serv_conn = + match self.connections.track(service_connection) { + Ok(track_serv_conn) => track_serv_conn, + Err(_) => { + ::tracing::error!( + "Could not track new service connection" + ); + self.connections + .remove(cci.service_connection_id); + // FIXME: proper connection closing + // FIXME: drop auth srv connection if we just + // established it + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking + .into(), + )); + return; + } + }; service_conn = Some(ServiceConn(connection::Conn { queue: self.queue_sender.clone(), conn: track_serv_conn, diff --git a/src/lib.rs b/src/lib.rs index a0ee7ae..8ac887a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -176,6 +176,7 @@ impl Fenrir { config: &Config, tokio_rt: Arc<::tokio::runtime::Runtime>, ) -> Result { + inner::set_minimum_sleep_resolution().await; let (sender, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; // bind sockets early so we can change "port 0" (aka: random) @@ -214,6 +215,7 @@ impl Fenrir { pub async fn with_workers( config: &Config, ) -> Result<(Self, Vec), Error> { + inner::set_minimum_sleep_resolution().await; let (stop_working, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; // bind sockets early so we can change "port 0" (aka: random)