More thread-pinning work.

Replace Arc<Connection> with Rc<Connection>, which is the better fit when
everything stays on the same thread.
Track the thread number so we can generate the correct connection IDs.

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-05-24 17:30:15 +02:00
parent 9b33ed8828
commit 810cc16ce6
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
5 changed files with 109 additions and 71 deletions

View File

@ -7,7 +7,7 @@ use crate::{
enc::sym::{HeadLen, TagLen}, enc::sym::{HeadLen, TagLen},
}; };
use ::num_traits::FromPrimitive; use ::num_traits::FromPrimitive;
use ::std::sync::Arc; use ::std::{rc::Rc, sync::Arc};
/// Handshake errors /// Handshake errors
#[derive(::thiserror::Error, Debug, Copy, Clone)] #[derive(::thiserror::Error, Debug, Copy, Clone)]
@ -36,7 +36,7 @@ pub(crate) struct HandshakeServer {
pub(crate) struct HandshakeClient { pub(crate) struct HandshakeClient {
pub id: crate::enc::asym::KeyID, pub id: crate::enc::asym::KeyID,
pub key: crate::enc::asym::PrivKey, pub key: crate::enc::asym::PrivKey,
pub connection: Arc<crate::connection::Connection>, pub connection: Rc<crate::connection::Connection>,
} }
/// Parsed handshake /// Parsed handshake

View File

@ -4,16 +4,19 @@ pub mod handshake;
pub mod packet; pub mod packet;
pub mod socket; pub mod socket;
use ::std::{sync::Arc, vec::Vec}; use ::std::{rc::Rc, sync::Arc, vec::Vec};
pub use crate::connection::{ pub use crate::connection::{
handshake::Handshake, handshake::Handshake,
packet::{ConnectionID as ID, Packet, PacketData}, packet::{ConnectionID as ID, Packet, PacketData},
}; };
use crate::enc::{ use crate::{
hkdf::HkdfSha3, enc::{
sym::{CipherKind, CipherRecv, CipherSend}, hkdf::HkdfSha3,
sym::{CipherKind, CipherRecv, CipherSend},
},
inner::ThreadTracker,
}; };
/// strong typedef for receiving connection id /// strong typedef for receiving connection id
@ -99,16 +102,18 @@ impl Connection {
// PERF: Arc<RwLock<ConnList>> locks a bit too much, need to find // PERF: Arc<RwLock<ConnList>> locks a bit too much, need to find
// faster ways to do this // faster ways to do this
pub(crate) struct ConnList { pub(crate) struct ConnList {
connections: Vec<Option<Arc<Connection>>>, thread_id: ThreadTracker,
connections: Vec<Option<Rc<Connection>>>,
/// Bitmap to track which connection ids are used or free /// Bitmap to track which connection ids are used or free
ids_used: Vec<::bitmaps::Bitmap<1024>>, ids_used: Vec<::bitmaps::Bitmap<1024>>,
} }
impl ConnList { impl ConnList {
pub(crate) fn new() -> Self { pub(crate) fn new(thread_id: ThreadTracker) -> Self {
let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new();
bitmap_id.set(0, true); // ID(0) == handshake bitmap_id.set(0, true); // ID(0) == handshake
Self { Self {
thread_id,
connections: Vec::with_capacity(128), connections: Vec::with_capacity(128),
ids_used: vec![bitmap_id], ids_used: vec![bitmap_id],
} }
@ -116,20 +121,20 @@ impl ConnList {
pub(crate) fn reserve_first( pub(crate) fn reserve_first(
&mut self, &mut self,
mut conn: Connection, mut conn: Connection,
) -> Arc<Connection> { ) -> Rc<Connection> {
// uhm... bad things are going on here: // uhm... bad things are going on here:
// * id must be initialized, but only because: // * id must be initialized, but only because:
// * rust does not understand that after the `!found` id is always // * rust does not understand that after the `!found` id is always
// initialized // initialized
// * `ID::new_u64` is really safe only with >0, but here it always is // * `ID::new_u64` is really safe only with >0, but here it always is
// ...we should probably rewrite it in better, safer rust // ...we should probably rewrite it in better, safer rust
let mut id: u64 = 0; let mut id_in_thread: u64 = 0;
let mut found = false; let mut found = false;
for (i, b) in self.ids_used.iter_mut().enumerate() { for (i, b) in self.ids_used.iter_mut().enumerate() {
match b.first_false_index() { match b.first_false_index() {
Some(idx) => { Some(idx) => {
b.set(idx, true); b.set(idx, true);
id = ((i as u64) * 1024) + (idx as u64); id_in_thread = ((i as u64) * 1024) + (idx as u64);
found = true; found = true;
break; break;
} }
@ -139,17 +144,19 @@ impl ConnList {
if !found { if !found {
let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new();
new_bitmap.set(0, true); new_bitmap.set(0, true);
id = (self.ids_used.len() as u64) * 1024; id_in_thread = (self.ids_used.len() as u64) * 1024;
self.ids_used.push(new_bitmap); self.ids_used.push(new_bitmap);
} }
let new_id = IDRecv(ID::new_u64(id)); let actual_id = (id_in_thread * (self.thread_id.total as u64))
+ (self.thread_id.id as u64);
let new_id = IDRecv(ID::new_u64(actual_id));
conn.id_recv = new_id; conn.id_recv = new_id;
let conn = Arc::new(conn); let conn = Rc::new(conn);
if (self.connections.len() as u64) < id { if (self.connections.len() as u64) < id_in_thread {
self.connections.push(Some(conn.clone())); self.connections.push(Some(conn.clone()));
} else { } else {
// very probably redundant // very probably redundant
self.connections[id as usize] = Some(conn.clone()); self.connections[id_in_thread as usize] = Some(conn.clone());
} }
conn conn
} }

View File

@ -19,11 +19,11 @@ use crate::{
Error, Error,
}; };
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
use ::std::{sync::Arc, vec::Vec}; use ::std::{rc::Rc, sync::Arc, vec::Vec};
/// Information needed to reply after the key exchange /// Information needed to reply after the key exchange
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AuthNeededInfo { pub(crate) struct AuthNeededInfo {
/// Parsed handshake packet /// Parsed handshake packet
pub handshake: Handshake, pub handshake: Handshake,
/// hkdf generated from the handshake /// hkdf generated from the handshake
@ -34,15 +34,15 @@ pub struct AuthNeededInfo {
/// Client information needed to fully establish the connection /// Client information needed to fully establish the connection
#[derive(Debug)] #[derive(Debug)]
pub struct ClientConnectInfo { pub(crate) struct ClientConnectInfo {
/// Parsed handshake packet /// Parsed handshake packet
pub handshake: Handshake, pub handshake: Handshake,
/// Connection /// Connection
pub connection: Arc<Connection>, pub connection: Rc<Connection>,
} }
/// Intermediate actions to be taken while parsing the handshake /// Intermediate actions to be taken while parsing the handshake
#[derive(Debug)] #[derive(Debug)]
pub enum HandshakeAction { pub(crate) enum HandshakeAction {
/// Parsing finished, all ok, nothing to do /// Parsing finished, all ok, nothing to do
None, None,
/// Packet parsed, now go perform authentication /// Packet parsed, now go perform authentication
@ -51,14 +51,28 @@ pub enum HandshakeAction {
ClientConnect(ClientConnectInfo), ClientConnect(ClientConnectInfo),
} }
/// Track the total number of threads and our index
///
/// `ConnList::reserve_first` combines the two fields to build a connection
/// ID as `id_in_thread * total + id`.
/// Both fields are `u16`: 65K cpus should be enough for anybody
#[derive(Debug, Clone, Copy)]
pub(crate) struct ThreadTracker {
/// Total number of threads; the code that spawns the workers sets this
/// to the number of cores
pub total: u16,
/// Index of this thread, built as `1 + core` when the worker is spawned
/// Note: starts from 1
pub id: u16,
}
/// Async free but thread safe tracking of handshakes and connections /// Async free but thread safe tracking of handshakes and connections
pub struct HandshakeTracker { /// Note that we have multiple Handshake trackers, pinned to different cores
key_exchanges: ArcSwapAny<Arc<Vec<(asym::Key, asym::KeyExchange)>>>, /// Each of them will handle a subset of all handshakes.
ciphers: ArcSwapAny<Arc<Vec<CipherKind>>>, /// Each handshake is routed to a different tracker with:
/// (udp_src_sender_port % total_threads) - 1
pub(crate) struct HandshakeTracker {
thread_id: ThreadTracker,
key_exchanges: Vec<(asym::Key, asym::KeyExchange)>,
ciphers: Vec<CipherKind>,
/// ephemeral keys used server side in key exchange /// ephemeral keys used server side in key exchange
keys_srv: ArcSwapAny<Arc<Vec<HandshakeServer>>>, keys_srv: Vec<HandshakeServer>,
/// ephemeral keys used client side in key exchange /// ephemeral keys used client side in key exchange
hshake_cli: ArcSwapAny<Arc<Vec<HandshakeClient>>>, hshake_cli: Vec<HandshakeClient>,
} }
#[allow(unsafe_code)] #[allow(unsafe_code)]
unsafe impl Send for HandshakeTracker {} unsafe impl Send for HandshakeTracker {}
@ -66,12 +80,13 @@ unsafe impl Send for HandshakeTracker {}
unsafe impl Sync for HandshakeTracker {} unsafe impl Sync for HandshakeTracker {}
impl HandshakeTracker { impl HandshakeTracker {
pub fn new() -> Self { pub(crate) fn new(thread_id: ThreadTracker) -> Self {
Self { Self {
ciphers: ArcSwapAny::new(Arc::new(Vec::new())), thread_id,
key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), ciphers: Vec::new(),
keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), key_exchanges: Vec::new(),
hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), keys_srv: Vec::new(),
hshake_cli: Vec::new(),
} }
} }
pub(crate) fn recv_handshake( pub(crate) fn recv_handshake(
@ -87,11 +102,8 @@ impl HandshakeTracker {
HandshakeData::DirSync(ref mut ds) => match ds { HandshakeData::DirSync(ref mut ds) => match ds {
DirSync::Req(ref mut req) => { DirSync::Req(ref mut req) => {
let ephemeral_key = { let ephemeral_key = {
// Keep this block short to avoid contention
// on self.keys_srv
let keys = self.keys_srv.load();
if let Some(h_k) = if let Some(h_k) =
keys.iter().find(|k| k.id == req.key_id) self.keys_srv.iter().find(|k| k.id == req.key_id)
{ {
use enc::asym::PrivKey; use enc::asym::PrivKey;
// Directory synchronized can only use keys // Directory synchronized can only use keys
@ -114,9 +126,8 @@ impl HandshakeTracker {
} }
let ephemeral_key = ephemeral_key.unwrap(); let ephemeral_key = ephemeral_key.unwrap();
{ {
let exchanges = self.key_exchanges.load();
if None if None
== exchanges.iter().find(|&x| { == self.key_exchanges.iter().find(|&x| {
*x == (ephemeral_key.kind(), req.exchange) *x == (ephemeral_key.kind(), req.exchange)
}) })
{ {
@ -126,8 +137,9 @@ impl HandshakeTracker {
} }
} }
{ {
let ciphers = self.ciphers.load(); if None
if None == ciphers.iter().find(|&x| *x == req.cipher) { == self.ciphers.iter().find(|&x| *x == req.cipher)
{
return Err(enc::Error::UnsupportedCipher.into()); return Err(enc::Error::UnsupportedCipher.into());
} }
} }
@ -164,10 +176,8 @@ impl HandshakeTracker {
} }
DirSync::Resp(resp) => { DirSync::Resp(resp) => {
let hshake = { let hshake = {
// Keep this block short to avoid contention match self
// on self.hshake_cli .hshake_cli
let hshake_cli_lock = self.hshake_cli.load();
match hshake_cli_lock
.iter() .iter()
.find(|h| h.id == resp.client_key_id) .find(|h| h.id == resp.client_key_id)
{ {

View File

@ -12,9 +12,9 @@ use crate::{
ConnList, Connection, IDSend, Packet, ID, ConnList, Connection, IDSend, Packet, ID,
}, },
enc::sym::Secret, enc::sym::Secret,
inner::{HandshakeAction, HandshakeTracker}, inner::{HandshakeAction, HandshakeTracker, ThreadTracker},
}; };
use ::std::{sync::Arc, vec::Vec}; use ::std::{rc::Rc, sync::Arc, vec::Vec};
/// This worker must be cpu-pinned /// This worker must be cpu-pinned
use ::tokio::{net::UdpSocket, sync::Mutex}; use ::tokio::{net::UdpSocket, sync::Mutex};
use std::net::SocketAddr; use std::net::SocketAddr;
@ -33,6 +33,7 @@ pub(crate) enum Work {
/// Actual worker implementation. /// Actual worker implementation.
pub(crate) struct Worker { pub(crate) struct Worker {
thread_id: ThreadTracker,
// PERF: rand uses syscalls. how to do that async? // PERF: rand uses syscalls. how to do that async?
rand: ::ring::rand::SystemRandom, rand: ::ring::rand::SystemRandom,
stop_working: ::tokio::sync::broadcast::Receiver<bool>, stop_working: ::tokio::sync::broadcast::Receiver<bool>,
@ -45,7 +46,27 @@ pub(crate) struct Worker {
} }
impl Worker { impl Worker {
/// Build a worker and then drive its work loop to completion.
///
/// Every argument is forwarded unchanged to [`Worker::new`].
///
/// # Errors
/// Returns the error produced by [`Worker::new`]; in that case the work
/// loop is never started. Otherwise `Ok(())` is returned once `work_loop`
/// has finished.
pub(crate) async fn new_and_loop(
    thread_id: ThreadTracker,
    stop_working: ::tokio::sync::broadcast::Receiver<bool>,
    token_check: Option<Arc<Mutex<TokenChecker>>>,
    socket_addrs: Vec<::std::net::SocketAddr>,
    queue: ::async_channel::Receiver<Work>,
) -> ::std::io::Result<()> {
    // TODO: get a channel to send back information, and send the error
    let built =
        Self::new(thread_id, stop_working, token_check, socket_addrs, queue)
            .await;
    // bail out early if the worker could not be constructed
    let mut running = built?;
    running.work_loop().await;
    Ok(())
}
pub(crate) async fn new( pub(crate) async fn new(
thread_id: ThreadTracker,
stop_working: ::tokio::sync::broadcast::Receiver<bool>, stop_working: ::tokio::sync::broadcast::Receiver<bool>,
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>, socket_addrs: Vec<::std::net::SocketAddr>,
@ -85,14 +106,15 @@ impl Worker {
}; };
Ok(Self { Ok(Self {
thread_id,
rand: ::ring::rand::SystemRandom::new(), rand: ::ring::rand::SystemRandom::new(),
stop_working, stop_working,
token_check, token_check,
sockets, sockets,
queue, queue,
thread_channels: Vec::new(), thread_channels: Vec::new(),
connections: ConnList::new(), connections: ConnList::new(thread_id),
handshakes: HandshakeTracker::new(), handshakes: HandshakeTracker::new(thread_id),
}) })
} }
pub(crate) async fn work_loop(&mut self) { pub(crate) async fn work_loop(&mut self) {
@ -167,6 +189,8 @@ impl Worker {
return; return;
} }
}; };
// FIXME: This part can take a while,
// we should just spawn it probably
let is_authenticated = { let is_authenticated = {
let tk_check = token_check.lock().await; let tk_check = token_check.lock().await;
tk_check( tk_check(
@ -273,7 +297,7 @@ impl Worker {
return; return;
} }
// FIXME: conn tracking and arc counting // FIXME: conn tracking and arc counting
let conn = Arc::get_mut(&mut cci.connection).unwrap(); let conn = Rc::get_mut(&mut cci.connection).unwrap();
conn.id_send = IDSend(resp_data.id); conn.id_send = IDSend(resp_data.id);
todo!(); todo!();
} }

View File

@ -34,7 +34,10 @@ use crate::{
socket::{SocketList, UdpClient, UdpServer}, socket::{SocketList, UdpClient, UdpServer},
Packet, Packet,
}, },
inner::worker::{RawUdp, Work, Worker}, inner::{
worker::{RawUdp, Work, Worker},
ThreadTracker,
},
}; };
pub use config::Config; pub use config::Config;
@ -72,12 +75,8 @@ pub struct Fenrir {
dnssec: Option<dnssec::Dnssec>, dnssec: Option<dnssec::Dnssec>,
/// Broadcast channel to tell workers to stop working /// Broadcast channel to tell workers to stop working
stop_working: ::tokio::sync::broadcast::Sender<bool>, stop_working: ::tokio::sync::broadcast::Sender<bool>,
/// Private keys used in the handshake
_inner: Arc<inner::HandshakeTracker>,
/// where to ask for token check /// where to ask for token check
token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>, token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>,
// PERF: rand uses syscalls. should we do that async?
rand: ::ring::rand::SystemRandom,
// TODO: find a way to both increase and decrease these two in a thread-safe // TODO: find a way to both increase and decrease these two in a thread-safe
// manner // manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>, _thread_pool: Vec<::std::thread::JoinHandle<()>>,
@ -103,9 +102,7 @@ impl Fenrir {
sockets: SocketList::new(), sockets: SocketList::new(),
dnssec: None, dnssec: None,
stop_working: sender, stop_working: sender,
_inner: Arc::new(inner::HandshakeTracker::new()),
token_check: None, token_check: None,
rand: ::ring::rand::SystemRandom::new(),
_thread_pool: Vec::new(), _thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()), _thread_work: Arc::new(Vec::new()),
}; };
@ -130,7 +127,6 @@ impl Fenrir {
/// synchronous version for Drop /// synchronous version for Drop
fn stop_sync(&mut self) { fn stop_sync(&mut self) {
let _ = self.stop_working.send(true); let _ = self.stop_working.send(true);
// FIXME: wait for thread pool to actually stop
let mut toempty_sockets = self.sockets.rm_all(); let mut toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let task = ::tokio::task::spawn(toempty_sockets.stop_all());
let _ = ::futures::executor::block_on(task); let _ = ::futures::executor::block_on(task);
@ -143,7 +139,6 @@ impl Fenrir {
/// Stop all workers, listeners /// Stop all workers, listeners
pub async fn stop(&mut self) { pub async fn stop(&mut self) {
let _ = self.stop_working.send(true); let _ = self.stop_working.send(true);
// FIXME: wait for thread pool to actually stop
let mut toempty_sockets = self.sockets.rm_all(); let mut toempty_sockets = self.sockets.rm_all();
toempty_sockets.stop_all().await; toempty_sockets.stop_all().await;
let mut old_thread_pool = Vec::new(); let mut old_thread_pool = Vec::new();
@ -285,19 +280,12 @@ impl Fenrir {
let th_topology = hw_topology.clone(); let th_topology = hw_topology.clone();
let th_tokio_rt = tokio_rt.clone(); let th_tokio_rt = tokio_rt.clone();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>(); let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let mut worker = match Worker::new( let th_stop_working = self.stop_working.subscribe();
self.stop_working.subscribe(), let th_token_check = self.token_check.clone();
self.token_check.clone(), let th_socket_addrs = self.cfg.listen.clone();
self.cfg.listen.clone(), let thread_id = ThreadTracker {
work_recv, total: cores as u16,
) id: 1 + (core as u16),
.await
{
Ok(worker) => worker,
Err(e) => {
::tracing::error!("can't start worker");
return Err(Error::IO(e));
}
}; };
let join_handle = ::std::thread::spawn(move || { let join_handle = ::std::thread::spawn(move || {
@ -324,7 +312,16 @@ impl Fenrir {
// finally run the main worker. // finally run the main worker.
// make sure things stay on this thread // make sure things stay on this thread
let tk_local = ::tokio::task::LocalSet::new(); let tk_local = ::tokio::task::LocalSet::new();
let _ = tk_local.block_on(&th_tokio_rt, worker.work_loop()); let _ = tk_local.block_on(
&th_tokio_rt,
Worker::new_and_loop(
thread_id,
th_stop_working,
th_token_check,
th_socket_addrs,
work_recv,
),
);
}); });
loop { loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) { let queues_lock = match Arc::get_mut(&mut self._thread_work) {