From 810cc16ce6b016cd3e1ef063d2f05183dfd76055 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 24 May 2023 17:30:15 +0200 Subject: [PATCH] More thread-pinning work. No more Arc, Rc is better on the same thread. Track the thread number so we can generate the correct connection IDs Signed-off-by: Luca Fulchir --- src/connection/handshake/mod.rs | 4 +-- src/connection/mod.rs | 35 ++++++++++-------- src/inner/mod.rs | 64 +++++++++++++++++++-------------- src/inner/worker.rs | 34 +++++++++++++++--- src/lib.rs | 43 +++++++++++----------- 5 files changed, 109 insertions(+), 71 deletions(-) diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 99eab75..a231bce 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -7,7 +7,7 @@ use crate::{ enc::sym::{HeadLen, TagLen}, }; use ::num_traits::FromPrimitive; -use ::std::sync::Arc; +use ::std::{rc::Rc, sync::Arc}; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -36,7 +36,7 @@ pub(crate) struct HandshakeServer { pub(crate) struct HandshakeClient { pub id: crate::enc::asym::KeyID, pub key: crate::enc::asym::PrivKey, - pub connection: Arc, + pub connection: Rc, } /// Parsed handshake diff --git a/src/connection/mod.rs b/src/connection/mod.rs index fe75884..b79d910 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -4,16 +4,19 @@ pub mod handshake; pub mod packet; pub mod socket; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{rc::Rc, sync::Arc, vec::Vec}; pub use crate::connection::{ handshake::Handshake, packet::{ConnectionID as ID, Packet, PacketData}, }; -use crate::enc::{ - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend}, +use crate::{ + enc::{ + hkdf::HkdfSha3, + sym::{CipherKind, CipherRecv, CipherSend}, + }, + inner::ThreadTracker, }; /// strong typedef for receiving connection id @@ -99,16 +102,18 @@ impl Connection { // PERF: Arc> loks a bit too much, need to find // faster ways to do this pub(crate) struct ConnList { - connections: Vec>>, + thread_id: ThreadTracker, + connections: Vec>>, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } impl ConnList { - pub(crate) fn new() -> Self { + pub(crate) fn new(thread_id: ThreadTracker) -> Self { let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); bitmap_id.set(0, true); // ID(0) == handshake Self { + thread_id, connections: Vec::with_capacity(128), ids_used: vec![bitmap_id], } @@ -116,20 +121,20 @@ impl ConnList { pub(crate) fn reserve_first( &mut self, mut conn: Connection, - ) -> Arc { + ) -> Rc { // uhm... bad things are going on here: // * id must be initialized, but only because: // * rust does not understand that after the `!found` id is always // initialized // * `ID::new_u64` is really safe only with >0, but here it always is // ...we should probably rewrite it in better, safer rust - let mut id: u64 = 0; + let mut id_in_thread: u64 = 0; let mut found = false; for (i, b) in self.ids_used.iter_mut().enumerate() { match b.first_false_index() { Some(idx) => { b.set(idx, true); - id = ((i as u64) * 1024) + (idx as u64); + id_in_thread = ((i as u64) * 1024) + (idx as u64); found = true; break; } @@ -139,17 +144,19 @@ impl ConnList { if !found { let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); new_bitmap.set(0, true); - id = (self.ids_used.len() as u64) * 1024; + id_in_thread = (self.ids_used.len() as u64) * 1024; self.ids_used.push(new_bitmap); } - let new_id = IDRecv(ID::new_u64(id)); + let actual_id = (id_in_thread * (self.thread_id.total as u64)) + + (self.thread_id.id as u64); + let new_id = IDRecv(ID::new_u64(actual_id)); conn.id_recv = new_id; - let conn = Arc::new(conn); - if (self.connections.len() as u64) < id { + let conn = Rc::new(conn); + if (self.connections.len() as u64) < id_in_thread { self.connections.push(Some(conn.clone())); } else { // very probably redundant - self.connections[id as usize] = Some(conn.clone()); + self.connections[id_in_thread as usize] = Some(conn.clone()); } conn } diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 73490f8..c2569ec 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -19,11 +19,11 @@ use crate::{ Error, }; use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{rc::Rc, sync::Arc, vec::Vec}; /// Information needed to reply after the key exchange #[derive(Debug, Clone)] -pub struct AuthNeededInfo { +pub(crate) struct AuthNeededInfo { /// Parsed handshake packet pub handshake: Handshake, /// hkdf generated from the handshake @@ -34,15 +34,15 @@ pub struct AuthNeededInfo { /// Client information needed to fully establish the conenction #[derive(Debug)] -pub struct ClientConnectInfo { +pub(crate) struct ClientConnectInfo { /// Parsed handshake packet pub handshake: Handshake, /// Connection - pub connection: Arc, + pub connection: Rc, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] -pub enum HandshakeAction { +pub(crate) enum HandshakeAction { /// Parsing finished, all ok, nothing to do None, /// Packet parsed, now go perform authentication @@ -51,14 +51,28 @@ pub enum HandshakeAction { ClientConnect(ClientConnectInfo), } +/// Track the total number of threads and our index +/// 65K cpus should be enough for anybody +#[derive(Debug, Clone, Copy)] +pub(crate) struct ThreadTracker { + pub total: u16, + /// Note: starts from 1 + pub id: u16, +} + /// Async free but thread safe tracking of handhsakes and conenctions -pub struct HandshakeTracker { - key_exchanges: ArcSwapAny>>, - ciphers: ArcSwapAny>>, +/// Note that we have multiple Handshake trackers, pinned to different cores +/// Each of them will handle a subset of all handshakes. +/// Each handshake is routed to a different tracker with: +/// (udp_src_sender_port % total_threads) - 1 +pub(crate) struct HandshakeTracker { + thread_id: ThreadTracker, + key_exchanges: Vec<(asym::Key, asym::KeyExchange)>, + ciphers: Vec, /// ephemeral keys used server side in key exchange - keys_srv: ArcSwapAny>>, + keys_srv: Vec, /// ephemeral keys used client side in key exchange - hshake_cli: ArcSwapAny>>, + hshake_cli: Vec, } #[allow(unsafe_code)] unsafe impl Send for HandshakeTracker {} @@ -66,12 +80,13 @@ unsafe impl Send for HandshakeTracker {} unsafe impl Sync for HandshakeTracker {} impl HandshakeTracker { - pub fn new() -> Self { + pub(crate) fn new(thread_id: ThreadTracker) -> Self { Self { - ciphers: ArcSwapAny::new(Arc::new(Vec::new())), - key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), - keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), - hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), + thread_id, + ciphers: Vec::new(), + key_exchanges: Vec::new(), + keys_srv: Vec::new(), + hshake_cli: Vec::new(), } } pub(crate) fn recv_handshake( @@ -87,11 +102,8 @@ impl HandshakeTracker { HandshakeData::DirSync(ref mut ds) => match ds { DirSync::Req(ref mut req) => { let ephemeral_key = { - // Keep this block short to avoid contention - // on self.keys_srv - let keys = self.keys_srv.load(); if let Some(h_k) = - keys.iter().find(|k| k.id == req.key_id) + self.keys_srv.iter().find(|k| k.id == req.key_id) { use enc::asym::PrivKey; // Directory synchronized can only use keys @@ -114,9 +126,8 @@ impl HandshakeTracker { } let ephemeral_key = ephemeral_key.unwrap(); { - let exchanges = self.key_exchanges.load(); if None - == exchanges.iter().find(|&x| { + == self.key_exchanges.iter().find(|&x| { *x == (ephemeral_key.kind(), req.exchange) }) { @@ -126,8 +137,9 @@ impl HandshakeTracker { } } { - let ciphers = self.ciphers.load(); - if None == ciphers.iter().find(|&x| *x == req.cipher) { + if None + == self.ciphers.iter().find(|&x| *x == req.cipher) + { return Err(enc::Error::UnsupportedCipher.into()); } } @@ -164,10 +176,8 @@ impl HandshakeTracker { } DirSync::Resp(resp) => { let hshake = { - // Keep this block short to avoid contention - // on self.hshake_cli - let hshake_cli_lock = self.hshake_cli.load(); - match hshake_cli_lock + match self + .hshake_cli .iter() .find(|h| h.id == resp.client_key_id) { diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 32e2148..7e3bd2e 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -12,9 +12,9 @@ use crate::{ ConnList, Connection, IDSend, Packet, ID, }, enc::sym::Secret, - inner::{HandshakeAction, HandshakeTracker}, + inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{rc::Rc, sync::Arc, vec::Vec}; /// This worker must be cpu-pinned use ::tokio::{net::UdpSocket, sync::Mutex}; use std::net::SocketAddr; @@ -33,6 +33,7 @@ pub(crate) enum Work { /// Actual worker implementation. pub(crate) struct Worker { + thread_id: ThreadTracker, // PERF: rand uses syscalls. how to do that async? rand: ::ring::rand::SystemRandom, stop_working: ::tokio::sync::broadcast::Receiver, @@ -45,7 +46,27 @@ pub(crate) struct Worker { } impl Worker { + pub(crate) async fn new_and_loop( + thread_id: ThreadTracker, + stop_working: ::tokio::sync::broadcast::Receiver, + token_check: Option>>, + socket_addrs: Vec<::std::net::SocketAddr>, + queue: ::async_channel::Receiver, + ) -> ::std::io::Result<()> { + // TODO: get a channel to send back information, and send the error + let mut worker = Self::new( + thread_id, + stop_working, + token_check, + socket_addrs, + queue, + ) + .await?; + worker.work_loop().await; + Ok(()) + } pub(crate) async fn new( + thread_id: ThreadTracker, stop_working: ::tokio::sync::broadcast::Receiver, token_check: Option>>, socket_addrs: Vec<::std::net::SocketAddr>, @@ -85,14 +106,15 @@ impl Worker { }; Ok(Self { + thread_id, rand: ::ring::rand::SystemRandom::new(), stop_working, token_check, sockets, queue, thread_channels: Vec::new(), - connections: ConnList::new(), - handshakes: HandshakeTracker::new(), + connections: ConnList::new(thread_id), + handshakes: HandshakeTracker::new(thread_id), }) } pub(crate) async fn work_loop(&mut self) { @@ -167,6 +189,8 @@ impl Worker { return; } }; + // FIXME: This part can take a while, + // we should just spawn it probably let is_authenticated = { let tk_check = token_check.lock().await; tk_check( @@ -273,7 +297,7 @@ impl Worker { return; } // FIXME: conn tracking and arc counting - let conn = Arc::get_mut(&mut cci.connection).unwrap(); + let conn = Rc::get_mut(&mut cci.connection).unwrap(); conn.id_send = IDSend(resp_data.id); todo!(); } diff --git a/src/lib.rs b/src/lib.rs index 5c8a5d7..02df264 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,10 @@ use crate::{ socket::{SocketList, UdpClient, UdpServer}, Packet, }, - inner::worker::{RawUdp, Work, Worker}, + inner::{ + worker::{RawUdp, Work, Worker}, + ThreadTracker, + }, }; pub use config::Config; @@ -72,12 +75,8 @@ pub struct Fenrir { dnssec: Option, /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, - /// Private keys used in the handshake - _inner: Arc, /// where to ask for token check token_check: Option>>, - // PERF: rand uses syscalls. should we do that async? - rand: ::ring::rand::SystemRandom, // TODO: find a way to both increase and decrease these two in a thread-safe // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, @@ -103,9 +102,7 @@ impl Fenrir { sockets: SocketList::new(), dnssec: None, stop_working: sender, - _inner: Arc::new(inner::HandshakeTracker::new()), token_check: None, - rand: ::ring::rand::SystemRandom::new(), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), }; @@ -130,7 +127,6 @@ impl Fenrir { /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); - // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); @@ -143,7 +139,6 @@ impl Fenrir { /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); - // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; let mut old_thread_pool = Vec::new(); @@ -285,19 +280,12 @@ impl Fenrir { let th_topology = hw_topology.clone(); let th_tokio_rt = tokio_rt.clone(); let (work_send, work_recv) = ::async_channel::unbounded::(); - let mut worker = match Worker::new( - self.stop_working.subscribe(), - self.token_check.clone(), - self.cfg.listen.clone(), - work_recv, - ) - .await - { - Ok(worker) => worker, - Err(e) => { - ::tracing::error!("can't start worker"); - return Err(Error::IO(e)); - } + let th_stop_working = self.stop_working.subscribe(); + let th_token_check = self.token_check.clone(); + let th_socket_addrs = self.cfg.listen.clone(); + let thread_id = ThreadTracker { + total: cores as u16, + id: 1 + (core as u16), }; let join_handle = ::std::thread::spawn(move || { @@ -324,7 +312,16 @@ impl Fenrir { // finally run the main worker. // make sure things stay on this thread let tk_local = ::tokio::task::LocalSet::new(); - let _ = tk_local.block_on(&th_tokio_rt, worker.work_loop()); + let _ = tk_local.block_on( + &th_tokio_rt, + Worker::new_and_loop( + thread_id, + th_stop_working, + th_token_check, + th_socket_addrs, + work_recv, + ), + ); }); loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) {