User conn tracking, enqueue data, timers

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-06-22 12:50:47 +02:00
parent b49ede334f
commit e3ae166ca9
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
5 changed files with 198 additions and 55 deletions

View File

@ -5,7 +5,11 @@ pub mod packet;
pub mod socket;
pub mod stream;
use ::std::{collections::HashMap, rc::Rc, vec::Vec};
use ::core::num::Wrapping;
use ::std::{
collections::{BTreeMap, HashMap},
vec::Vec,
};
pub use crate::connection::{handshake::Handshake, packet::Packet};
@ -125,12 +129,12 @@ impl ProtocolVersion {
}
}
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
pub(crate) struct UserConnTracker(usize);
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub(crate) struct UserConnTracker(Wrapping<usize>);
impl UserConnTracker {
fn advance(&mut self) -> Self {
let old = self.0;
self.0 = self.0 + 1;
self.0 = self.0 + Wrapping(1);
UserConnTracker(old)
}
}
@ -151,8 +155,13 @@ pub struct Conn {
impl Conn {
/// Queue some data to be sent in this connection
pub fn send(&mut self, stream: stream::ID, _data: Vec<u8>) {
todo!()
// TODO: send_and_wait, that wait for recipient ACK
pub async fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
use crate::inner::worker::Work;
let _ = self
.queue
.send(Work::UserSend((self.conn, stream, data)))
.await;
}
}
@ -160,15 +169,17 @@ impl Conn {
#[derive(Debug)]
pub(crate) struct Connection {
/// Receiving Conn ID
pub id_recv: IDRecv,
pub(crate) id_recv: IDRecv,
/// Sending Conn ID
pub id_send: IDSend,
pub(crate) id_send: IDSend,
/// The main hkdf used for all secrets in this connection
pub hkdf: Hkdf,
hkdf: Hkdf,
/// Cipher for decrypting data
pub cipher_recv: CipherRecv,
pub(crate) cipher_recv: CipherRecv,
/// Cipher for encrypting data
pub cipher_send: CipherSend,
pub(crate) cipher_send: CipherSend,
/// send queue for each Stream
send_queue: BTreeMap<stream::ID, Vec<Vec<u8>>>,
}
/// Role: track the connection direction
@ -210,14 +221,22 @@ impl Connection {
hkdf,
cipher_recv,
cipher_send,
send_queue: BTreeMap::new(),
}
}
pub(crate) fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
let stream = match self.send_queue.get_mut(&stream) {
None => return,
Some(stream) => stream,
};
stream.push(data);
}
}
pub(crate) struct ConnList {
thread_id: ThreadTracker,
connections: Vec<Option<Rc<Connection>>>,
user_tracker: HashMap<UserConnTracker, usize>,
connections: Vec<Option<Connection>>,
user_tracker: BTreeMap<UserConnTracker, usize>,
last_tracked: UserConnTracker,
/// Bitmap to track which connection ids are used or free
ids_used: Vec<::bitmaps::Bitmap<1024>>,
@ -234,13 +253,27 @@ impl ConnList {
let mut ret = Self {
thread_id,
connections: Vec::with_capacity(INITIAL_CAP),
user_tracker: HashMap::with_capacity(INITIAL_CAP),
last_tracked: UserConnTracker(0),
user_tracker: BTreeMap::new(),
last_tracked: UserConnTracker(Wrapping(0)),
ids_used: vec![bitmap_id],
};
ret.connections.resize_with(INITIAL_CAP, || None);
ret
}
pub fn get_mut(
&mut self,
tracker: UserConnTracker,
) -> Option<&mut Connection> {
let idx = if let Some(idx) = self.user_tracker.get(&tracker) {
*idx
} else {
return None;
};
match &mut self.connections[idx] {
None => None,
Some(conn) => Some(conn),
}
}
pub fn len(&self) -> usize {
let mut total: usize = 0;
for bitmap in self.ids_used.iter() {
@ -293,7 +326,7 @@ impl ConnList {
/// NOTE: does NOT check if the connection has been previously reserved!
pub(crate) fn track(
&mut self,
conn: Rc<Connection>,
conn: Connection,
) -> Result<UserConnTracker, ()> {
let conn_id = match conn.id_recv {
IDRecv(ID::Handshake) => {
@ -304,8 +337,14 @@ impl ConnList {
let id_in_thread: usize =
(conn_id.get() / (self.thread_id.total as u64)) as usize;
self.connections[id_in_thread] = Some(conn);
let tracked = self.last_tracked.advance();
let _ = self.user_tracker.insert(tracked, id_in_thread);
let mut tracked;
loop {
tracked = self.last_tracked.advance();
if self.user_tracker.get(&tracked).is_none() {
let _ = self.user_tracker.insert(tracked, id_in_thread);
break;
}
}
Ok(tracked)
}
pub(crate) fn remove(&mut self, id: IDRecv) {

View File

@ -19,7 +19,7 @@ pub enum Kind {
}
/// Id of the stream
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ID(pub u16);
impl ID {

View File

@ -4,6 +4,10 @@
pub(crate) mod worker;
use crate::inner::worker::Work;
use ::std::{collections::BTreeMap, vec::Vec};
use ::tokio::time::Instant;
/// Track the total number of threads and our index
/// 65K cpus should be enough for anybody
#[derive(Debug, Clone, Copy)]
@ -12,3 +16,83 @@ pub(crate) struct ThreadTracker {
/// Note: starts from 1
pub id: u16,
}
pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration =
if cfg!(linux) || cfg!(macos) {
::std::time::Duration::from_millis(1)
} else {
// windows
::std::time::Duration::from_millis(16)
};
pub(crate) async fn set_minimum_sleep_resolution() {
let nanosleep = ::std::time::Duration::from_nanos(1);
let mut tests: usize = 3;
while tests > 0 {
let pre_sleep = ::std::time::Instant::now();
::tokio::time::sleep(nanosleep).await;
let post_sleep = ::std::time::Instant::now();
let slept_for = post_sleep - pre_sleep;
#[allow(unsafe_code)]
unsafe {
if slept_for < SLEEP_RESOLUTION {
SLEEP_RESOLUTION = slept_for;
}
}
tests = tests - 1;
}
}
/// Sleeping has a higher resolution that we would like for packet pacing.
/// So we sleep for however log we need, then chunk up all the work here
/// we will end up chunking the work in SLEEP_RESOLUTION, then we will busy wait
/// for more precise timing
pub(crate) struct Timers {
times: BTreeMap<Instant, Work>,
}
impl Timers {
pub(crate) fn new() -> Self {
Self {
times: BTreeMap::new(),
}
}
pub(crate) fn get_next(&self) -> ::tokio::time::Sleep {
match self.times.keys().next() {
Some(entry) => ::tokio::time::sleep_until((*entry).into()),
None => {
::tokio::time::sleep(::std::time::Duration::from_secs(3600))
}
}
}
/// Get all the work from now up until now + SLEEP_RESOLUTION
pub(crate) fn get_work(&mut self) -> Vec<Work> {
let now: ::tokio::time::Instant = ::std::time::Instant::now().into();
let mut ret = Vec::with_capacity(4);
let mut count_rm = 0;
#[allow(unsafe_code)]
let next_instant = unsafe { now + SLEEP_RESOLUTION };
let mut iter = self.times.iter_mut().peekable();
loop {
match iter.peek() {
None => break,
Some(next) => {
if *next.0 > next_instant {
break;
}
}
}
let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0));
let mut entry = iter.next().unwrap();
::core::mem::swap(&mut work, &mut entry.1);
ret.push(work);
count_rm = count_rm + 1;
}
while count_rm > 0 {
self.times.pop_first();
count_rm = count_rm - 1;
}
ret
}
}

View File

@ -11,7 +11,8 @@ use crate::{
},
packet::{self, Packet},
socket::{UdpClient, UdpServer},
AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
stream, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
UserConnTracker,
},
dnssec,
enc::{
@ -51,6 +52,7 @@ pub(crate) enum Work {
Connect(ConnectInfo),
DropHandshake(KeyID),
Recv(RawUdp),
UserSend((UserConnTracker, stream::ID, Vec<u8>)),
}
/// Actual worker implementation.
@ -70,6 +72,7 @@ pub struct Worker {
thread_channels: Vec<::async_channel::Sender<Work>>,
connections: ConnList,
handshakes: handshake::Tracker,
work_timers: super::Timers,
}
#[allow(unsafe_code)]
@ -126,12 +129,15 @@ impl Worker {
thread_channels: Vec::new(),
connections: ConnList::new(thread_id),
handshakes,
work_timers: super::Timers::new(),
})
}
/// Continuously loop and process work as needed
pub async fn work_loop(&mut self) {
'mainloop: loop {
let next_timer = self.work_timers.get_next();
::tokio::pin!(next_timer);
let work = ::tokio::select! {
tell_stopped = self.stop_working.recv() => {
if let Ok(stop_ch) = tell_stopped {
@ -140,6 +146,13 @@ impl Worker {
}
break;
}
() = &mut next_timer => {
let work_list = self.work_timers.get_work();
for w in work_list.into_iter() {
let _ = self.queue_sender.send(w).await;
}
continue 'mainloop;
}
maybe_timeout = self.queue.recv() => {
match maybe_timeout {
Ok(work) => work,
@ -419,6 +432,13 @@ impl Worker {
Work::Recv(pkt) => {
self.recv(pkt).await;
}
Work::UserSend((tracker, stream, data)) => {
let conn = match self.connections.get_mut(tracker) {
None => return,
Some(conn) => conn,
};
conn.send(stream, data);
}
}
}
}
@ -596,21 +616,20 @@ impl Worker {
let id_recv = conn.id_recv;
let cipher = conn.cipher_recv.kind();
// track the connection to the authentication server
let track_auth_conn =
match self.connections.track(conn.into()) {
Ok(track_auth_conn) => track_auth_conn,
Err(e) => {
::tracing::error!(
"Could not track new auth srv connection"
);
self.connections.remove(id_recv);
// FIXME: proper connection closing
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
}
};
let track_auth_conn = match self.connections.track(conn) {
Ok(track_auth_conn) => track_auth_conn,
Err(_) => {
::tracing::error!(
"Could not track new auth srv connection"
);
self.connections.remove(id_recv);
// FIXME: proper connection closing
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
}
};
let authsrv_conn = AuthSrvConn(connection::Conn {
queue: self.queue_sender.clone(),
conn: track_auth_conn,
@ -635,26 +654,25 @@ impl Worker {
service_connection.id_recv = cci.service_connection_id;
service_connection.id_send =
IDSend(resp_data.service_connection_id);
let track_serv_conn = match self
.connections
.track(service_connection.into())
{
Ok(track_serv_conn) => track_serv_conn,
Err(e) => {
::tracing::error!(
"Could not track new service connection"
);
self.connections
.remove(cci.service_connection_id);
// FIXME: proper connection closing
// FIXME: drop auth srv connection if we just
// established it
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
}
};
let track_serv_conn =
match self.connections.track(service_connection) {
Ok(track_serv_conn) => track_serv_conn,
Err(_) => {
::tracing::error!(
"Could not track new service connection"
);
self.connections
.remove(cci.service_connection_id);
// FIXME: proper connection closing
// FIXME: drop auth srv connection if we just
// established it
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking
.into(),
));
return;
}
};
service_conn = Some(ServiceConn(connection::Conn {
queue: self.queue_sender.clone(),
conn: track_serv_conn,

View File

@ -176,6 +176,7 @@ impl Fenrir {
config: &Config,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<Self, Error> {
inner::set_minimum_sleep_resolution().await;
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random)
@ -214,6 +215,7 @@ impl Fenrir {
pub async fn with_workers(
config: &Config,
) -> Result<(Self, Vec<Worker>), Error> {
inner::set_minimum_sleep_resolution().await;
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random)