diff --git a/Cargo.toml b/Cargo.toml
index 906053a..8142aa8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -35,6 +35,7 @@ bitmaps = { version = "3.2" }
 chacha20poly1305 = { version = "0.10" }
 futures = { version = "0.3" }
 hkdf = { version = "0.12" }
+hwloc2 = { version = "2.2" }
 libc = { version = "0.2" }
 num-traits = { version = "0.2" }
 num-derive = { version = "0.3" }
diff --git a/src/connection/mod.rs b/src/connection/mod.rs
index 3c74c94..fe75884 100644
--- a/src/connection/mod.rs
+++ b/src/connection/mod.rs
@@ -1,7 +1,7 @@
 //! Connection handling and send/receive queues

 pub mod handshake;
-mod packet;
+pub mod packet;
 pub mod socket;

 use ::std::{sync::Arc, vec::Vec};
diff --git a/src/connection/packet.rs b/src/connection/packet.rs
index 9e5eed3..400f900 100644
--- a/src/connection/packet.rs
+++ b/src/connection/packet.rs
@@ -91,6 +91,8 @@ impl From<[u8; 8]> for ConnectionID {
 pub enum PacketData {
     /// A parsed handshake packet
     Handshake(super::Handshake),
+    /// Raw packet. We only have the connection ID and the packet length
+    Raw(usize),
 }

 impl PacketData {
@@ -98,6 +100,7 @@ impl PacketData {
     pub fn len(&self) -> usize {
         match self {
             PacketData::Handshake(h) => h.len(),
+            PacketData::Raw(len) => *len,
         }
     }
     /// serialize data into bytes
@@ -111,10 +114,15 @@ impl PacketData {
         assert!(self.len() == out.len(), "PacketData: wrong buffer length");
         match self {
             PacketData::Handshake(h) => h.serialize(head_len, tag_len, out),
+            PacketData::Raw(_) => {
+                ::tracing::error!("Tried to serialize a raw PacketData!");
+            }
         }
     }
 }

+const MIN_PACKET_BYTES: usize = 16;
+
 /// Fenrir packet structure
 #[derive(Debug, Clone)]
 pub struct Packet {
@@ -125,6 +133,18 @@ pub struct Packet {
 }

 impl Packet {
+    /// New received packet, not yet parsed
+    pub fn deserialize_id(raw: &[u8]) -> Result<Self, ()> {
+        // TODO: proper min_packet length. 16 is too conservative.
+        if raw.len() < MIN_PACKET_BYTES {
+            return Err(());
+        }
+        let raw_id: [u8; 8] = (raw[..8]).try_into().expect("unreachable");
+        Ok(Packet {
+            id: raw_id.into(),
+            data: PacketData::Raw(raw.len()),
+        })
+    }
     /// get the total length of the packet
     pub fn len(&self) -> usize {
         ConnectionID::len() + self.data.len()
diff --git a/src/lib.rs b/src/lib.rs
index a4faa07..94e67e7 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -24,7 +24,7 @@ use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
 use ::std::{
     net::SocketAddr,
     pin::Pin,
-    sync::Arc,
+    sync::{Arc, Weak},
     vec::{self, Vec},
 };
 use ::tokio::{
@@ -38,7 +38,7 @@ use crate::{
             HandshakeServer,
         },
         socket::{SocketList, UdpClient, UdpServer},
-        ConnList, Connection, IDSend,
+        ConnList, Connection, IDSend, Packet,
     },
     enc::{
         asym,
@@ -79,9 +79,10 @@ type TokenChecker =

 /// Track a raw Udp packet
 struct RawUdp {
-    data: Vec<u8>,
     src: UdpClient,
     dst: UdpServer,
+    data: Vec<u8>,
+    packet: Packet,
 }

 enum Work {
@@ -103,14 +104,15 @@ pub struct Fenrir {
     _inner: Arc<inner::Tracker>,
     /// where to ask for token check
     token_check: Arc<ArcSwapOption<TokenChecker>>,
-    /// MPMC work queue. sender
-    work_send: Arc<::async_channel::Sender<Work>>,
-    /// MPMC work queue. receiver
-    work_recv: Arc<::async_channel::Receiver<Work>>,
     // PERF: rand uses syscalls. should we do that async?
     rand: ::ring::rand::SystemRandom,
     /// list of Established connections
     connections: Arc<RwLock<ConnList>>,
+    _myself: Weak<Self>,
+    // TODO: find a way to both increase and decrease these two in a thread-safe
+    // manner
+    _thread_pool: Vec<::std::thread::JoinHandle<()>>,
+    _thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
 }

 // TODO: graceful vs immediate stop
@@ -123,22 +125,23 @@ impl Drop for Fenrir {

 impl Fenrir {
     /// Create a new Fenrir endpoint
-    pub fn new(config: &Config) -> Result<Self, Error> {
+    pub fn new(config: &Config) -> Result<Arc<Self>, Error> {
         let listen_num = config.listen.len();
         let (sender, _) = ::tokio::sync::broadcast::channel(1);
         let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
-        let endpoint = Fenrir {
+        let endpoint = Arc::new_cyclic(|myself| Fenrir {
             cfg: config.clone(),
             sockets: SocketList::new(),
             dnssec: None,
             stop_working: sender,
             _inner: Arc::new(inner::Tracker::new()),
             token_check: Arc::new(ArcSwapOption::from(None)),
-            work_send: Arc::new(work_send),
-            work_recv: Arc::new(work_recv),
             rand: ::ring::rand::SystemRandom::new(),
             connections: Arc::new(RwLock::new(ConnList::new())),
-        };
+            _myself: myself.clone(),
+            _thread_pool: Vec::new(),
+            _thread_work: Arc::new(Vec::new()),
+        });
         Ok(endpoint)
     }

@@ -169,7 +172,6 @@ impl Fenrir {
         toempty_sockets.stop_all().await;
         self.dnssec = None;
     }
-
     /// Add all UDP sockets found in config
     /// and start listening for packets
     async fn add_sockets(&self) -> ::std::io::Result<()> {
@@ -187,7 +189,7 @@ impl Fenrir {
             let arc_s = Arc::new(s);
             let join = ::tokio::spawn(Self::listen_udp(
                 stop_working,
-                self.work_send.clone(),
+                self._thread_work.clone(),
                 arc_s.clone(),
             ));
             self.sockets.add_socket(arc_s, join);
@@ -203,12 +205,13 @@ impl Fenrir {
     /// Run a dedicated loop to read packets on the listening socket
     async fn listen_udp(
         mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
-        work_queue: Arc<::async_channel::Sender<Work>>,
+        work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
         socket: Arc<UdpSocket>,
     ) -> ::std::io::Result<()> {
         // jumbo frames are 9K max
         let sock_receiver = UdpServer(socket.local_addr()?);
         let mut buffer: [u8; 9000] = [0; 9000];
+        let queues_num = work_queues.len() as u64;
         loop {
             let (bytes, sock_sender) = ::tokio::select! {
                 _done = stop_working.recv() => {
@@ -219,10 +222,33 @@ impl Fenrir {
                 }
             };
             let data: Vec<u8> = buffer[..bytes].to_vec();
-            work_queue.send(Work::Recv(RawUdp {
-                data,
+
+            // We very likely have multiple threads, pinned to different CPUs.
+            // Use the ConnectionID to send the same connection
+            // to the same thread.
+            // Handshakes have connection ID 0, so we use the sender's UDP port instead.
+
+            let packet = match Packet::deserialize_id(&data) {
+                Ok(packet) => packet,
+                Err(_) => continue, // packet way too short, ignore.
+            };
+            let thread_idx: usize = {
+                use connection::packet::ConnectionID;
+                match packet.id {
+                    ConnectionID::Handshake => {
+                        let send_port = sock_sender.port() as u64;
+                        (send_port % queues_num) as usize
+                    }
+                    ConnectionID::ID(id) => {
+                        (id.get() % queues_num) as usize
+                    }
+                }
+            };
+            work_queues[thread_idx].send(Work::Recv(RawUdp {
                 src: UdpClient(sock_sender),
                 dst: sock_receiver,
+                packet,
+                data,
             }));
         }
         Ok(())
@@ -242,15 +268,97 @@ impl Fenrir {
         Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
     }

-    /// Loop continuously and parse packets and other work
-    pub async fn work_loop(&self) {
+    /// Start one worker thread for each physical CPU core.
+    /// Threads are pinned to their core.
+    /// Work is divided and routed so that no locking is needed.
+    pub async fn start_work_threads_pinned(
+        &mut self,
+        tokio_rt: Arc<::tokio::runtime::Runtime>,
+    ) -> ::std::result::Result<(), ()> {
+        use ::std::sync::Mutex;
+        let hw_topology = match ::hwloc2::Topology::new() {
+            Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
+            None => return Err(()),
+        };
+        let cores;
+        {
+            let topology_lock = hw_topology.lock().unwrap();
+            let all_cores = match topology_lock
+                .objects_with_type(&::hwloc2::ObjectType::Core)
+            {
+                Ok(all_cores) => all_cores,
+                Err(_) => return Err(()),
+            };
+            cores = all_cores.len();
+            if cores == 0 || !topology_lock.support().cpu().set_thread() {
+                ::tracing::error!("No support for CPU pinning");
+                return Err(());
+            }
+        }
+        for core in 0..cores {
+            ::tracing::debug!("Spawning thread {}", core);
+            let th_topology = hw_topology.clone();
+            let th_tokio_rt = tokio_rt.clone();
+            let th_myself = self._myself.upgrade().unwrap();
+            let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
+            let join_handle = ::std::thread::spawn(move || {
+                // bind to a specific core
+                let th_pinning;
+                {
+                    let mut th_topology_lock = th_topology.lock().unwrap();
+                    let th_cores = th_topology_lock
+                        .objects_with_type(&::hwloc2::ObjectType::Core)
+                        .unwrap();
+                    let cpuset = th_cores.get(core).unwrap().cpuset().unwrap();
+                    th_pinning = th_topology_lock.set_cpubind(
+                        cpuset,
+                        ::hwloc2::CpuBindFlags::CPUBIND_THREAD,
+                    );
+                }
+                match th_pinning {
+                    Ok(_) => {}
+                    Err(_) => {
+                        ::tracing::error!("Can't bind thread to cpu");
+                        return;
+                    }
+                }
+                // finally run the main listener. make sure things stay on this
+                // thread
+                let tk_local = ::tokio::task::LocalSet::new();
+                let _ = tk_local.block_on(
+                    &th_tokio_rt,
+                    Self::work_loop_thread(th_myself, work_recv),
+                );
+            });
+            loop {
+                let queues_lock = match Arc::get_mut(&mut self._thread_work) {
+                    Some(queues_lock) => queues_lock,
+                    None => {
+                        ::tokio::time::sleep(
+                            ::std::time::Duration::from_millis(50),
+                        )
+                        .await;
+                        continue;
+                    }
+                };
+                queues_lock.push(work_send);
+                break;
+            }
+            self._thread_pool.push(join_handle);
+        }
+        Ok(())
+    }
+    async fn work_loop_thread(
+        self: Arc<Self>,
+        work_recv: ::async_channel::Receiver<Work>,
+    ) {
         let mut stop_working = self.stop_working.subscribe();
         loop {
             let work = ::tokio::select! {
                 _done = stop_working.recv() => {
                     break;
                 }
-                maybe_work = self.work_recv.recv() => {
+                maybe_work = work_recv.recv() => {
                     match maybe_work {
                         Ok(work) => work,
                         Err(_) => break,
@@ -265,16 +373,9 @@ impl Fenrir {
         }
     }

-    const MIN_PACKET_BYTES: usize = 8;
     /// Read and do stuff with the raw udp packet
     async fn recv(&self, mut udp: RawUdp) {
-        if udp.data.len() < Self::MIN_PACKET_BYTES {
-            return;
-        }
-        use connection::ID;
-        let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable");
-        if ID::from(raw_id).is_handshake() {
-            use connection::handshake::Handshake;
+        if udp.packet.id.is_handshake() {
             let handshake = match Handshake::deserialize(&udp.data[8..]) {
                 Ok(handshake) => handshake,
                 Err(e) => {
@@ -390,7 +491,7 @@ impl Fenrir {
                 let resp_handshake = Handshake::new(
                     HandshakeData::DirSync(DirSync::Resp(resp)),
                 );
-                use connection::{Packet, PacketData, ID};
+                use connection::{PacketData, ID};
                 let packet = Packet {
                     id: ID::new_handshake(),
                     data: PacketData::Handshake(resp_handshake),
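
Reviewer note, not part of the patch above: listen_udp now keys every incoming packet to one worker queue, using the connection ID modulo the number of queues (or the sender's UDP port for handshakes, which all carry connection ID 0). Every packet of a given connection therefore lands on the same pinned thread, which is what allows per-connection state to be handled without locks. Below is a minimal, self-contained sketch of that dispatch idea using only std primitives; ConnId, worker_index and the mpsc channels are hypothetical stand-ins for the crate's ConnectionID and async_channel queues, not code from this repository.

use std::sync::mpsc;
use std::thread;

// Hypothetical connection identifier: ID 0 is reserved for handshakes,
// mirroring the ConnectionID::Handshake / ConnectionID::ID split above.
#[derive(Clone, Copy, Debug)]
enum ConnId {
    Handshake { src_port: u16 },
    Id(u64), // real connection IDs start at 1
}

// The same key always maps to the same worker, so a connection's packets
// are never processed by two threads at once.
fn worker_index(id: ConnId, queues: usize) -> usize {
    let key = match id {
        ConnId::Handshake { src_port } => src_port as u64,
        ConnId::Id(id) => id,
    };
    (key % queues as u64) as usize
}

fn main() {
    let workers: usize = 4;
    let mut queues = Vec::new();
    let mut handles = Vec::new();
    for n in 0..workers {
        // One queue per worker thread, as with _thread_work above.
        let (tx, rx) = mpsc::channel::<ConnId>();
        queues.push(tx);
        handles.push(thread::spawn(move || {
            for pkt in rx {
                println!("worker {n} got {pkt:?}");
            }
        }));
    }
    // All three packets of connection 42 land on the same worker.
    for _ in 0..3 {
        let id = ConnId::Id(42);
        queues[worker_index(id, workers)].send(id).unwrap();
    }
    drop(queues); // close the channels so the workers exit
    for h in handles {
        h.join().unwrap();
    }
}

The same affinity reasoning is behind the Weak/Arc::new_cyclic change in Fenrir::new: each pinned worker needs its own handle to the endpoint, so the endpoint keeps a Weak reference to itself that start_work_threads_pinned upgrades and hands to every spawned thread.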