From 9b33ed882877995e605ed13bb393bc7a0605a1b6 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 24 May 2023 15:45:37 +0200 Subject: [PATCH] Refactor, more pinned-thread work Signed-off-by: Luca Fulchir --- src/auth/mod.rs | 12 +- src/connection/socket.rs | 13 +- src/inner/mod.rs | 11 +- src/inner/worker.rs | 307 +++++++++++++++++++++++++++++++++++ src/lib.rs | 335 +++++++-------------------------------- 5 files changed, 385 insertions(+), 293 deletions(-) create mode 100644 src/inner/worker.rs diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 17f43b9..52b51e0 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,4 +1,4 @@ -//! Authentication reslated struct definitions +//! Authentication related struct definitions use ::ring::rand::SecureRandom; use ::zeroize::Zeroize; @@ -53,6 +53,16 @@ impl ::core::fmt::Debug for Token { } } +/// Type of the function used to check the validity of the tokens +/// Reimplement this to use whatever database you want +pub type TokenChecker = + fn( + user: UserID, + token: Token, + service_id: ServiceID, + domain: Domain, + ) -> ::futures::future::BoxFuture<'static, Result>; + /// domain representation /// Security notice: internal representation is utf8, but we will /// further limit to a "safe" subset of utf8 diff --git a/src/connection/socket.rs b/src/connection/socket.rs index b95e7c9..1cc570d 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -8,17 +8,8 @@ use ::std::{ }; use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; -// PERF: move to arcswap: -// We use multiple UDP sockets. -// We need a list of them all and to scan this list. -// But we want to be able to update this list without locking everyone -// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and -// `from_raw_fd` to update the list. -// This means that we will have multiple `UdpSocket` per actual socket -// so we have to handle `drop()` manually, and garbage-collect the ones we -// are no longer using in the background. sigh. -// Just go with a ArcSwapAny, Arc>>); +/// Pair to easily track the socket and its async listening handle +pub type SocketTracker = (Arc, Arc>>); /// async free socket list pub(crate) struct SocketList { diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 6007868..73490f8 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -2,7 +2,10 @@ //! This is meant to be **async-free** so that others might use it //! without the tokio runtime +pub(crate) mod worker; + use crate::{ + auth, connection::{ self, handshake::{self, Handshake, HandshakeClient, HandshakeServer}, @@ -49,7 +52,7 @@ pub enum HandshakeAction { } /// Async free but thread safe tracking of handhsakes and conenctions -pub struct Tracker { +pub struct HandshakeTracker { key_exchanges: ArcSwapAny>>, ciphers: ArcSwapAny>>, /// ephemeral keys used server side in key exchange @@ -58,11 +61,11 @@ pub struct Tracker { hshake_cli: ArcSwapAny>>, } #[allow(unsafe_code)] -unsafe impl Send for Tracker {} +unsafe impl Send for HandshakeTracker {} #[allow(unsafe_code)] -unsafe impl Sync for Tracker {} +unsafe impl Sync for HandshakeTracker {} -impl Tracker { +impl HandshakeTracker { pub fn new() -> Self { Self { ciphers: ArcSwapAny::new(Arc::new(Vec::new())), diff --git a/src/inner/worker.rs b/src/inner/worker.rs new file mode 100644 index 0000000..32e2148 --- /dev/null +++ b/src/inner/worker.rs @@ -0,0 +1,307 @@ +//! Worker thread implementation +use crate::{ + auth::TokenChecker, + connection::{ + self, + handshake::{ + self, + dirsync::{self, DirSync}, + Handshake, HandshakeClient, HandshakeData, + }, + socket::{UdpClient, UdpServer}, + ConnList, Connection, IDSend, Packet, ID, + }, + enc::sym::Secret, + inner::{HandshakeAction, HandshakeTracker}, +}; +use ::std::{sync::Arc, vec::Vec}; +/// This worker must be cpu-pinned +use ::tokio::{net::UdpSocket, sync::Mutex}; +use std::net::SocketAddr; + +/// Track a raw Udp packet +pub(crate) struct RawUdp { + pub src: UdpClient, + pub dst: UdpServer, + pub data: Vec, + pub packet: Packet, +} + +pub(crate) enum Work { + Recv(RawUdp), +} + +/// Actual worker implementation. +pub(crate) struct Worker { + // PERF: rand uses syscalls. how to do that async? + rand: ::ring::rand::SystemRandom, + stop_working: ::tokio::sync::broadcast::Receiver, + token_check: Option>>, + sockets: Vec, + queue: ::async_channel::Receiver, + thread_channels: Vec<::async_channel::Sender>, + connections: ConnList, + handshakes: HandshakeTracker, +} + +impl Worker { + pub(crate) async fn new( + stop_working: ::tokio::sync::broadcast::Receiver, + token_check: Option>>, + socket_addrs: Vec<::std::net::SocketAddr>, + queue: ::async_channel::Receiver, + ) -> ::std::io::Result { + // bind all sockets again so that we can easily + // send without sharing resources + // in the future we will want to have a thread-local listener too, + // but before that we need ebpf to pin a connection to a thread + // directly from the kernel + let socket_binding = + socket_addrs.into_iter().map(|s_addr| async move { + let socket = ::tokio::spawn(connection::socket::bind_udp( + s_addr.clone(), + )) + .await??; + Ok(socket) + }); + let sockets_bind_res = + ::futures::future::join_all(socket_binding).await; + let sockets: Result, ::std::io::Error> = + sockets_bind_res + .into_iter() + .map(|s_res| match s_res { + Ok(s) => Ok(s), + Err(e) => { + ::tracing::error!("Worker can't bind on socket: {}", e); + Err(e) + } + }) + .collect(); + let sockets = match sockets { + Ok(sockets) => sockets, + Err(e) => { + return Err(e); + } + }; + + Ok(Self { + rand: ::ring::rand::SystemRandom::new(), + stop_working, + token_check, + sockets, + queue, + thread_channels: Vec::new(), + connections: ConnList::new(), + handshakes: HandshakeTracker::new(), + }) + } + pub(crate) async fn work_loop(&mut self) { + loop { + let work = ::tokio::select! { + _done = self.stop_working.recv() => { + break; + } + maybe_work = self.queue.recv() => { + match maybe_work { + Ok(work) => work, + Err(_) => break, + } + } + }; + match work { + //TODO: reconf message to add channels + Work::Recv(pkt) => { + self.recv(pkt).await; + } + } + } + } + /// Read and do stuff with the raw udp packet + async fn recv(&mut self, mut udp: RawUdp) { + if udp.packet.id.is_handshake() { + let handshake = match Handshake::deserialize(&udp.data[8..]) { + Ok(handshake) => handshake, + Err(e) => { + ::tracing::warn!("Handshake parsing: {}", e); + return; + } + }; + let action = match self + .handshakes + .recv_handshake(handshake, &mut udp.data[8..]) + { + Ok(action) => action, + Err(err) => { + ::tracing::debug!("Handshake recv error {}", err); + return; + } + }; + match action { + HandshakeAction::AuthNeeded(authinfo) => { + let token_check = match self.token_check.as_ref() { + Some(token_check) => token_check, + None => { + ::tracing::error!( + "Authentication requested but \ + we have no token checker" + ); + return; + } + }; + let req; + if let HandshakeData::DirSync(DirSync::Req(r)) = + authinfo.handshake.data + { + req = r; + } else { + ::tracing::error!("AuthInfo on non DS::Req"); + return; + } + use dirsync::ReqInner; + let req_data = match req.data { + ReqInner::ClearText(req_data) => req_data, + _ => { + ::tracing::error!( + "token_check: expected ClearText" + ); + return; + } + }; + let is_authenticated = { + let tk_check = token_check.lock().await; + tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await + }; + let is_authenticated = match is_authenticated { + Ok(is_authenticated) => is_authenticated, + Err(_) => { + ::tracing::error!("error in token auth"); + // TODO: retry? + return; + } + }; + if !is_authenticated { + ::tracing::warn!( + "Wrong authentication for user {:?}", + req_data.auth.user + ); + // TODO: error response + return; + } + // Client has correctly authenticated + // TODO: contact the service, get the key and + // connection ID + let srv_conn_id = ID::new_rand(&self.rand); + let srv_secret = Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); + + let mut raw_conn = Connection::new( + authinfo.hkdf, + req.cipher, + connection::Role::Server, + &self.rand, + ); + raw_conn.id_send = IDSend(req_data.id); + // track connection + let auth_conn = self.connections.reserve_first(raw_conn); + + let resp_data = dirsync::RespData { + client_nonce: req_data.nonce, + id: auth_conn.id_recv.0, + service_id: srv_conn_id, + service_key: srv_secret, + }; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + + use dirsync::RespInner; + let resp = dirsync::Resp { + client_key_id: req_data.client_key_id, + data: RespInner::ClearText(resp_data), + }; + let offset_to_encrypt = resp.encrypted_offset(); + let encrypt_until = + offset_to_encrypt + resp.encrypted_length() + tag_len.0; + let resp_handshake = Handshake::new( + HandshakeData::DirSync(DirSync::Resp(resp)), + ); + use connection::{PacketData, ID}; + let packet = Packet { + id: ID::new_handshake(), + data: PacketData::Handshake(resp_handshake), + }; + let mut raw_out = Vec::::with_capacity(packet.len()); + packet.serialize(head_len, tag_len, &mut raw_out); + + if let Err(e) = auth_conn.cipher_send.encrypt( + aad, + &mut raw_out[offset_to_encrypt..encrypt_until], + ) { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + self.send_packet(raw_out, udp.src, udp.dst).await; + return; + } + HandshakeAction::ClientConnect(mut cci) => { + let ds_resp; + if let HandshakeData::DirSync(DirSync::Resp(resp)) = + cci.handshake.data + { + ds_resp = resp; + } else { + ::tracing::error!("ClientConnect on non DS::Resp"); + return; + } + // track connection + use handshake::dirsync; + let resp_data; + if let dirsync::RespInner::ClearText(r_data) = ds_resp.data + { + resp_data = r_data; + } else { + ::tracing::error!( + "ClientConnect on non DS::Resp::ClearText" + ); + return; + } + // FIXME: conn tracking and arc counting + let conn = Arc::get_mut(&mut cci.connection).unwrap(); + conn.id_send = IDSend(resp_data.id); + todo!(); + } + _ => {} + }; + } + // copy packet, spawn + todo!(); + } + async fn send_packet( + &self, + data: Vec, + client: UdpClient, + server: UdpServer, + ) { + let src_sock = match self + .sockets + .iter() + .find(|&s| s.local_addr().unwrap() == server.0) + { + Some(src_sock) => src_sock, + None => { + ::tracing::error!( + "Can't send packet: Server changed listening ip!" + ); + return; + } + }; + src_sock.send_to(&data, client.0); + } +} diff --git a/src/lib.rs b/src/lib.rs index 94e67e7..5c8a5d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,32 +20,21 @@ pub mod dnssec; pub mod enc; mod inner; -use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::std::{ net::SocketAddr, - pin::Pin, sync::{Arc, Weak}, - vec::{self, Vec}, -}; -use ::tokio::{ - macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle, + vec::Vec, }; +use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use crate::{ + auth::TokenChecker, connection::{ - handshake::{ - self, dirsync::DirSync, Handshake, HandshakeClient, HandshakeData, - HandshakeServer, - }, + handshake, socket::{SocketList, UdpClient, UdpServer}, - ConnList, Connection, IDSend, Packet, + Packet, }, - enc::{ - asym, - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, - }, - inner::HandshakeAction, + inner::worker::{RawUdp, Work, Worker}, }; pub use config::Config; @@ -55,6 +44,9 @@ pub enum Error { /// The library was not initialized (run .start()) #[error("not initialized")] NotInitialized, + /// Error in setting up worker threads + #[error("Setup err: {0}")] + Setup(String), /// General I/O error #[error("IO: {0:?}")] IO(#[from] ::std::io::Error), @@ -69,26 +61,6 @@ pub enum Error { Key(#[from] crate::enc::Error), } -type TokenChecker = - fn( - user: auth::UserID, - token: auth::Token, - service_id: auth::ServiceID, - domain: auth::Domain, - ) -> ::futures::future::BoxFuture<'static, Result>; - -/// Track a raw Udp packet -struct RawUdp { - src: UdpClient, - dst: UdpServer, - data: Vec, - packet: Packet, -} - -enum Work { - Recv(RawUdp), -} - /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { @@ -101,14 +73,11 @@ pub struct Fenrir { /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, /// Private keys used in the handshake - _inner: Arc, + _inner: Arc, /// where to ask for token check - token_check: Arc>, + token_check: Option>>, // PERF: rand uses syscalls. should we do that async? rand: ::ring::rand::SystemRandom, - /// list of Established connections - connections: Arc>, - _myself: Weak, // TODO: find a way to both increase and decrease these two in a thread-safe // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, @@ -125,28 +94,30 @@ impl Drop for Fenrir { impl Fenrir { /// Create a new Fenrir endpoint - pub fn new(config: &Config) -> Result, Error> { + pub fn new(config: &Config) -> Result { let listen_num = config.listen.len(); let (sender, _) = ::tokio::sync::broadcast::channel(1); let (work_send, work_recv) = ::async_channel::unbounded::(); - let endpoint = Arc::new_cyclic(|myself| Fenrir { + let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), dnssec: None, stop_working: sender, - _inner: Arc::new(inner::Tracker::new()), - token_check: Arc::new(ArcSwapOption::from(None)), + _inner: Arc::new(inner::HandshakeTracker::new()), + token_check: None, rand: ::ring::rand::SystemRandom::new(), - connections: Arc::new(RwLock::new(ConnList::new())), - _myself: myself.clone(), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), - }); + }; Ok(endpoint) } /// Start all workers, listeners - pub async fn start(&mut self) -> Result<(), Error> { + pub async fn start( + &mut self, + tokio_rt: Arc<::tokio::runtime::Runtime>, + ) -> Result<(), Error> { + self.start_work_threads_pinned(tokio_rt).await?; if let Err(e) = self.add_sockets().await { self.stop().await; return Err(e.into()); @@ -159,17 +130,25 @@ impl Fenrir { /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); + // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); + let mut old_thread_pool = Vec::new(); + ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); + old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); + // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; + let mut old_thread_pool = Vec::new(); + ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); + old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Add all UDP sockets found in config @@ -271,14 +250,18 @@ impl Fenrir { /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. /// Work will be divided and rerouted so that there is no need to lock - pub async fn start_work_threads_pinned( + async fn start_work_threads_pinned( &mut self, tokio_rt: Arc<::tokio::runtime::Runtime>, - ) -> ::std::result::Result<(), ()> { + ) -> ::std::result::Result<(), Error> { use ::std::sync::Mutex; let hw_topology = match ::hwloc2::Topology::new() { Some(hw_topology) => Arc::new(Mutex::new(hw_topology)), - None => return Err(()), + None => { + return Err(Error::Setup( + "Can't get hardware topology".to_owned(), + )) + } }; let cores; { @@ -287,20 +270,36 @@ impl Fenrir { .objects_with_type(&::hwloc2::ObjectType::Core) { Ok(all_cores) => all_cores, - Err(_) => return Err(()), + Err(_) => { + return Err(Error::Setup("can't list cores".to_owned())) + } }; cores = all_cores.len(); if cores <= 0 || !topology_lock.support().cpu().set_thread() { ::tracing::error!("No support for CPU pinning"); - return Err(()); + return Err(Error::Setup("No cpu pinning support".to_owned())); } } for core in 0..cores { ::tracing::debug!("Spawning thread {}", core); let th_topology = hw_topology.clone(); let th_tokio_rt = tokio_rt.clone(); - let th_myself = self._myself.upgrade().unwrap(); let (work_send, work_recv) = ::async_channel::unbounded::(); + let mut worker = match Worker::new( + self.stop_working.subscribe(), + self.token_check.clone(), + self.cfg.listen.clone(), + work_recv, + ) + .await + { + Ok(worker) => worker, + Err(e) => { + ::tracing::error!("can't start worker"); + return Err(Error::IO(e)); + } + }; + let join_handle = ::std::thread::spawn(move || { // bind to a specific core let th_pinning; @@ -322,13 +321,10 @@ impl Fenrir { return; } } - // finally run the main listener. make sure things stay on this - // thread + // finally run the main worker. + // make sure things stay on this thread let tk_local = ::tokio::task::LocalSet::new(); - let _ = tk_local.block_on( - &th_tokio_rt, - Self::work_loop_thread(th_myself, work_recv), - ); + let _ = tk_local.block_on(&th_tokio_rt, worker.work_loop()); }); loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { @@ -348,219 +344,4 @@ impl Fenrir { } Ok(()) } - async fn work_loop_thread( - self: Arc, - work_recv: ::async_channel::Receiver, - ) { - let mut stop_working = self.stop_working.subscribe(); - loop { - let work = ::tokio::select! { - _done = stop_working.recv() => { - break; - } - maybe_work = work_recv.recv() => { - match maybe_work { - Ok(work) => work, - Err(_) => break, - } - } - }; - match work { - Work::Recv(pkt) => { - self.recv(pkt).await; - } - } - } - } - - /// Read and do stuff with the raw udp packet - async fn recv(&self, mut udp: RawUdp) { - if udp.packet.id.is_handshake() { - let handshake = match Handshake::deserialize(&udp.data[8..]) { - Ok(handshake) => handshake, - Err(e) => { - ::tracing::warn!("Handshake parsing: {}", e); - return; - } - }; - let action = - match self._inner.recv_handshake(handshake, &mut udp.data[8..]) - { - Ok(action) => action, - Err(err) => { - ::tracing::debug!("Handshake recv error {}", err); - return; - } - }; - match action { - HandshakeAction::AuthNeeded(authinfo) => { - let tk_check = match self.token_check.load_full() { - Some(tk_check) => tk_check, - None => { - ::tracing::error!( - "Handshake received, but no tocken_checker" - ); - return; - } - }; - use handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; - let req; - if let HandshakeData::DirSync(DirSync::Req(r)) = - authinfo.handshake.data - { - req = r; - } else { - ::tracing::error!("AuthInfo on non DS::Req"); - return; - } - use dirsync::ReqInner; - let req_data = match req.data { - ReqInner::ClearText(req_data) => req_data, - _ => { - ::tracing::error!( - "token_check: expected ClearText" - ); - return; - } - }; - let is_authenticated = match tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await - { - Ok(is_authenticated) => is_authenticated, - Err(_) => { - ::tracing::error!("error in token auth"); - // TODO: retry? - return; - } - }; - if !is_authenticated { - ::tracing::warn!( - "Wrong authentication for user {:?}", - req_data.auth.user - ); - // TODO: error response - return; - } - // Client has correctly authenticated - // TODO: contact the service, get the key and - // connection ID - let srv_conn_id = connection::ID::new_rand(&self.rand); - let srv_secret = enc::sym::Secret::new_rand(&self.rand); - let head_len = req.cipher.nonce_len(); - let tag_len = req.cipher.tag_len(); - - let mut raw_conn = Connection::new( - authinfo.hkdf, - req.cipher, - connection::Role::Server, - &self.rand, - ); - raw_conn.id_send = IDSend(req_data.id); - // track connection - let auth_conn = { - let mut lock = self.connections.write().await; - lock.reserve_first(raw_conn) - }; - - let resp_data = dirsync::RespData { - client_nonce: req_data.nonce, - id: auth_conn.id_recv.0, - service_id: srv_conn_id, - service_key: srv_secret, - }; - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - - use dirsync::RespInner; - let resp = dirsync::Resp { - client_key_id: req_data.client_key_id, - data: RespInner::ClearText(resp_data), - }; - let offset_to_encrypt = resp.encrypted_offset(); - let encrypt_until = - offset_to_encrypt + resp.encrypted_length() + tag_len.0; - let resp_handshake = Handshake::new( - HandshakeData::DirSync(DirSync::Resp(resp)), - ); - use connection::{PacketData, ID}; - let packet = Packet { - id: ID::new_handshake(), - data: PacketData::Handshake(resp_handshake), - }; - let mut raw_out = Vec::::with_capacity(packet.len()); - packet.serialize(head_len, tag_len, &mut raw_out); - - if let Err(e) = auth_conn.cipher_send.encrypt( - aad, - &mut raw_out[offset_to_encrypt..encrypt_until], - ) { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } - self.send_packet(raw_out, udp.src, udp.dst).await; - return; - } - HandshakeAction::ClientConnect(mut cci) => { - let ds_resp; - if let HandshakeData::DirSync(DirSync::Resp(resp)) = - cci.handshake.data - { - ds_resp = resp; - } else { - ::tracing::error!("ClientConnect on non DS::Resp"); - return; - } - // track connection - use handshake::dirsync; - let resp_data; - if let dirsync::RespInner::ClearText(r_data) = ds_resp.data - { - resp_data = r_data; - } else { - ::tracing::error!( - "ClientConnect on non DS::Resp::ClearText" - ); - return; - } - // FIXME: conn tracking and arc counting - let conn = Arc::get_mut(&mut cci.connection).unwrap(); - conn.id_send = IDSend(resp_data.id); - todo!(); - } - _ => {} - }; - } - // copy packet, spawn - todo!(); - } - async fn send_packet( - &self, - data: Vec, - client: UdpClient, - server: UdpServer, - ) { - let src_sock; - { - let sockets = self.sockets.lock(); - src_sock = match sockets.find(server) { - Some(src_sock) => src_sock, - None => { - ::tracing::error!( - "Can't send packet: Server changed listening ip!" - ); - return; - } - }; - } - src_sock.send_to(&data, client.0); - } }