User conn tracking, enqueue data, timers

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-06-22 12:50:47 +02:00
parent 2fe91d5dd3
commit 9c67210e3e
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
5 changed files with 198 additions and 55 deletions

View File

@ -5,7 +5,11 @@ pub mod packet;
pub mod socket; pub mod socket;
pub mod stream; pub mod stream;
use ::std::{collections::HashMap, rc::Rc, vec::Vec}; use ::core::num::Wrapping;
use ::std::{
collections::{BTreeMap, HashMap},
vec::Vec,
};
pub use crate::connection::{handshake::Handshake, packet::Packet}; pub use crate::connection::{handshake::Handshake, packet::Packet};
@ -125,12 +129,12 @@ impl ProtocolVersion {
} }
} }
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)] #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) struct UserConnTracker(usize); pub(crate) struct UserConnTracker(Wrapping<usize>);
impl UserConnTracker { impl UserConnTracker {
fn advance(&mut self) -> Self { fn advance(&mut self) -> Self {
let old = self.0; let old = self.0;
self.0 = self.0 + 1; self.0 = self.0 + Wrapping(1);
UserConnTracker(old) UserConnTracker(old)
} }
} }
@ -151,8 +155,13 @@ pub struct Conn {
impl Conn { impl Conn {
/// Queue some data to be sent in this connection /// Queue some data to be sent in this connection
pub fn send(&mut self, stream: stream::ID, _data: Vec<u8>) { // TODO: send_and_wait, that wait for recipient ACK
todo!() pub async fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
use crate::inner::worker::Work;
let _ = self
.queue
.send(Work::UserSend((self.conn, stream, data)))
.await;
} }
} }
@ -160,15 +169,17 @@ impl Conn {
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Connection { pub(crate) struct Connection {
/// Receiving Conn ID /// Receiving Conn ID
pub id_recv: IDRecv, pub(crate) id_recv: IDRecv,
/// Sending Conn ID /// Sending Conn ID
pub id_send: IDSend, pub(crate) id_send: IDSend,
/// The main hkdf used for all secrets in this connection /// The main hkdf used for all secrets in this connection
pub hkdf: Hkdf, hkdf: Hkdf,
/// Cipher for decrypting data /// Cipher for decrypting data
pub cipher_recv: CipherRecv, pub(crate) cipher_recv: CipherRecv,
/// Cipher for encrypting data /// Cipher for encrypting data
pub cipher_send: CipherSend, pub(crate) cipher_send: CipherSend,
/// send queue for each Stream
send_queue: BTreeMap<stream::ID, Vec<Vec<u8>>>,
} }
/// Role: track the connection direction /// Role: track the connection direction
@ -210,14 +221,22 @@ impl Connection {
hkdf, hkdf,
cipher_recv, cipher_recv,
cipher_send, cipher_send,
send_queue: BTreeMap::new(),
} }
} }
pub(crate) fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
let stream = match self.send_queue.get_mut(&stream) {
None => return,
Some(stream) => stream,
};
stream.push(data);
}
} }
pub(crate) struct ConnList { pub(crate) struct ConnList {
thread_id: ThreadTracker, thread_id: ThreadTracker,
connections: Vec<Option<Rc<Connection>>>, connections: Vec<Option<Connection>>,
user_tracker: HashMap<UserConnTracker, usize>, user_tracker: BTreeMap<UserConnTracker, usize>,
last_tracked: UserConnTracker, last_tracked: UserConnTracker,
/// Bitmap to track which connection ids are used or free /// Bitmap to track which connection ids are used or free
ids_used: Vec<::bitmaps::Bitmap<1024>>, ids_used: Vec<::bitmaps::Bitmap<1024>>,
@ -234,13 +253,27 @@ impl ConnList {
let mut ret = Self { let mut ret = Self {
thread_id, thread_id,
connections: Vec::with_capacity(INITIAL_CAP), connections: Vec::with_capacity(INITIAL_CAP),
user_tracker: HashMap::with_capacity(INITIAL_CAP), user_tracker: BTreeMap::new(),
last_tracked: UserConnTracker(0), last_tracked: UserConnTracker(Wrapping(0)),
ids_used: vec![bitmap_id], ids_used: vec![bitmap_id],
}; };
ret.connections.resize_with(INITIAL_CAP, || None); ret.connections.resize_with(INITIAL_CAP, || None);
ret ret
} }
pub fn get_mut(
&mut self,
tracker: UserConnTracker,
) -> Option<&mut Connection> {
let idx = if let Some(idx) = self.user_tracker.get(&tracker) {
*idx
} else {
return None;
};
match &mut self.connections[idx] {
None => None,
Some(conn) => Some(conn),
}
}
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
let mut total: usize = 0; let mut total: usize = 0;
for bitmap in self.ids_used.iter() { for bitmap in self.ids_used.iter() {
@ -293,7 +326,7 @@ impl ConnList {
/// NOTE: does NOT check if the connection has been previously reserved! /// NOTE: does NOT check if the connection has been previously reserved!
pub(crate) fn track( pub(crate) fn track(
&mut self, &mut self,
conn: Rc<Connection>, conn: Connection,
) -> Result<UserConnTracker, ()> { ) -> Result<UserConnTracker, ()> {
let conn_id = match conn.id_recv { let conn_id = match conn.id_recv {
IDRecv(ID::Handshake) => { IDRecv(ID::Handshake) => {
@ -304,8 +337,14 @@ impl ConnList {
let id_in_thread: usize = let id_in_thread: usize =
(conn_id.get() / (self.thread_id.total as u64)) as usize; (conn_id.get() / (self.thread_id.total as u64)) as usize;
self.connections[id_in_thread] = Some(conn); self.connections[id_in_thread] = Some(conn);
let tracked = self.last_tracked.advance(); let mut tracked;
let _ = self.user_tracker.insert(tracked, id_in_thread); loop {
tracked = self.last_tracked.advance();
if self.user_tracker.get(&tracked).is_none() {
let _ = self.user_tracker.insert(tracked, id_in_thread);
break;
}
}
Ok(tracked) Ok(tracked)
} }
pub(crate) fn remove(&mut self, id: IDRecv) { pub(crate) fn remove(&mut self, id: IDRecv) {

View File

@ -19,7 +19,7 @@ pub enum Kind {
} }
/// Id of the stream /// Id of the stream
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ID(pub u16); pub struct ID(pub u16);
impl ID { impl ID {

View File

@ -4,6 +4,10 @@
pub(crate) mod worker; pub(crate) mod worker;
use crate::inner::worker::Work;
use ::std::{collections::BTreeMap, vec::Vec};
use ::tokio::time::Instant;
/// Track the total number of threads and our index /// Track the total number of threads and our index
/// 65K cpus should be enough for anybody /// 65K cpus should be enough for anybody
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@ -12,3 +16,83 @@ pub(crate) struct ThreadTracker {
/// Note: starts from 1 /// Note: starts from 1
pub id: u16, pub id: u16,
} }
pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration =
if cfg!(linux) || cfg!(macos) {
::std::time::Duration::from_millis(1)
} else {
// windows
::std::time::Duration::from_millis(16)
};
pub(crate) async fn set_minimum_sleep_resolution() {
let nanosleep = ::std::time::Duration::from_nanos(1);
let mut tests: usize = 3;
while tests > 0 {
let pre_sleep = ::std::time::Instant::now();
::tokio::time::sleep(nanosleep).await;
let post_sleep = ::std::time::Instant::now();
let slept_for = post_sleep - pre_sleep;
#[allow(unsafe_code)]
unsafe {
if slept_for < SLEEP_RESOLUTION {
SLEEP_RESOLUTION = slept_for;
}
}
tests = tests - 1;
}
}
/// Sleeping has a higher resolution that we would like for packet pacing.
/// So we sleep for however log we need, then chunk up all the work here
/// we will end up chunking the work in SLEEP_RESOLUTION, then we will busy wait
/// for more precise timing
pub(crate) struct Timers {
times: BTreeMap<Instant, Work>,
}
impl Timers {
pub(crate) fn new() -> Self {
Self {
times: BTreeMap::new(),
}
}
pub(crate) fn get_next(&self) -> ::tokio::time::Sleep {
match self.times.keys().next() {
Some(entry) => ::tokio::time::sleep_until((*entry).into()),
None => {
::tokio::time::sleep(::std::time::Duration::from_secs(3600))
}
}
}
/// Get all the work from now up until now + SLEEP_RESOLUTION
pub(crate) fn get_work(&mut self) -> Vec<Work> {
let now: ::tokio::time::Instant = ::std::time::Instant::now().into();
let mut ret = Vec::with_capacity(4);
let mut count_rm = 0;
#[allow(unsafe_code)]
let next_instant = unsafe { now + SLEEP_RESOLUTION };
let mut iter = self.times.iter_mut().peekable();
loop {
match iter.peek() {
None => break,
Some(next) => {
if *next.0 > next_instant {
break;
}
}
}
let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0));
let mut entry = iter.next().unwrap();
::core::mem::swap(&mut work, &mut entry.1);
ret.push(work);
count_rm = count_rm + 1;
}
while count_rm > 0 {
self.times.pop_first();
count_rm = count_rm - 1;
}
ret
}
}

View File

@ -11,7 +11,8 @@ use crate::{
}, },
packet::{self, Packet}, packet::{self, Packet},
socket::{UdpClient, UdpServer}, socket::{UdpClient, UdpServer},
AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, stream, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
UserConnTracker,
}, },
dnssec, dnssec,
enc::{ enc::{
@ -51,6 +52,7 @@ pub(crate) enum Work {
Connect(ConnectInfo), Connect(ConnectInfo),
DropHandshake(KeyID), DropHandshake(KeyID),
Recv(RawUdp), Recv(RawUdp),
UserSend((UserConnTracker, stream::ID, Vec<u8>)),
} }
/// Actual worker implementation. /// Actual worker implementation.
@ -70,6 +72,7 @@ pub struct Worker {
thread_channels: Vec<::async_channel::Sender<Work>>, thread_channels: Vec<::async_channel::Sender<Work>>,
connections: ConnList, connections: ConnList,
handshakes: handshake::Tracker, handshakes: handshake::Tracker,
work_timers: super::Timers,
} }
#[allow(unsafe_code)] #[allow(unsafe_code)]
@ -126,12 +129,15 @@ impl Worker {
thread_channels: Vec::new(), thread_channels: Vec::new(),
connections: ConnList::new(thread_id), connections: ConnList::new(thread_id),
handshakes, handshakes,
work_timers: super::Timers::new(),
}) })
} }
/// Continuously loop and process work as needed /// Continuously loop and process work as needed
pub async fn work_loop(&mut self) { pub async fn work_loop(&mut self) {
'mainloop: loop { 'mainloop: loop {
let next_timer = self.work_timers.get_next();
::tokio::pin!(next_timer);
let work = ::tokio::select! { let work = ::tokio::select! {
tell_stopped = self.stop_working.recv() => { tell_stopped = self.stop_working.recv() => {
if let Ok(stop_ch) = tell_stopped { if let Ok(stop_ch) = tell_stopped {
@ -140,6 +146,13 @@ impl Worker {
} }
break; break;
} }
() = &mut next_timer => {
let work_list = self.work_timers.get_work();
for w in work_list.into_iter() {
let _ = self.queue_sender.send(w).await;
}
continue 'mainloop;
}
maybe_timeout = self.queue.recv() => { maybe_timeout = self.queue.recv() => {
match maybe_timeout { match maybe_timeout {
Ok(work) => work, Ok(work) => work,
@ -419,6 +432,13 @@ impl Worker {
Work::Recv(pkt) => { Work::Recv(pkt) => {
self.recv(pkt).await; self.recv(pkt).await;
} }
Work::UserSend((tracker, stream, data)) => {
let conn = match self.connections.get_mut(tracker) {
None => return,
Some(conn) => conn,
};
conn.send(stream, data);
}
} }
} }
} }
@ -596,21 +616,20 @@ impl Worker {
let id_recv = conn.id_recv; let id_recv = conn.id_recv;
let cipher = conn.cipher_recv.kind(); let cipher = conn.cipher_recv.kind();
// track the connection to the authentication server // track the connection to the authentication server
let track_auth_conn = let track_auth_conn = match self.connections.track(conn) {
match self.connections.track(conn.into()) { Ok(track_auth_conn) => track_auth_conn,
Ok(track_auth_conn) => track_auth_conn, Err(_) => {
Err(e) => { ::tracing::error!(
::tracing::error!( "Could not track new auth srv connection"
"Could not track new auth srv connection" );
); self.connections.remove(id_recv);
self.connections.remove(id_recv); // FIXME: proper connection closing
// FIXME: proper connection closing let _ = cci.answer.send(Err(
let _ = cci.answer.send(Err( handshake::Error::InternalTracking.into(),
handshake::Error::InternalTracking.into(), ));
)); return;
return; }
} };
};
let authsrv_conn = AuthSrvConn(connection::Conn { let authsrv_conn = AuthSrvConn(connection::Conn {
queue: self.queue_sender.clone(), queue: self.queue_sender.clone(),
conn: track_auth_conn, conn: track_auth_conn,
@ -635,26 +654,25 @@ impl Worker {
service_connection.id_recv = cci.service_connection_id; service_connection.id_recv = cci.service_connection_id;
service_connection.id_send = service_connection.id_send =
IDSend(resp_data.service_connection_id); IDSend(resp_data.service_connection_id);
let track_serv_conn = match self let track_serv_conn =
.connections match self.connections.track(service_connection) {
.track(service_connection.into()) Ok(track_serv_conn) => track_serv_conn,
{ Err(_) => {
Ok(track_serv_conn) => track_serv_conn, ::tracing::error!(
Err(e) => { "Could not track new service connection"
::tracing::error!( );
"Could not track new service connection" self.connections
); .remove(cci.service_connection_id);
self.connections // FIXME: proper connection closing
.remove(cci.service_connection_id); // FIXME: drop auth srv connection if we just
// FIXME: proper connection closing // established it
// FIXME: drop auth srv connection if we just let _ = cci.answer.send(Err(
// established it handshake::Error::InternalTracking
let _ = cci.answer.send(Err( .into(),
handshake::Error::InternalTracking.into(), ));
)); return;
return; }
} };
};
service_conn = Some(ServiceConn(connection::Conn { service_conn = Some(ServiceConn(connection::Conn {
queue: self.queue_sender.clone(), queue: self.queue_sender.clone(),
conn: track_serv_conn, conn: track_serv_conn,

View File

@ -176,6 +176,7 @@ impl Fenrir {
config: &Config, config: &Config,
tokio_rt: Arc<::tokio::runtime::Runtime>, tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
inner::set_minimum_sleep_resolution().await;
let (sender, _) = ::tokio::sync::broadcast::channel(1); let (sender, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?; let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random) // bind sockets early so we can change "port 0" (aka: random)
@ -214,6 +215,7 @@ impl Fenrir {
pub async fn with_workers( pub async fn with_workers(
config: &Config, config: &Config,
) -> Result<(Self, Vec<Worker>), Error> { ) -> Result<(Self, Vec<Worker>), Error> {
inner::set_minimum_sleep_resolution().await;
let (stop_working, _) = ::tokio::sync::broadcast::channel(1); let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?; let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random) // bind sockets early so we can change "port 0" (aka: random)