Per-thread work loop

This will let us use a lot less locking.
We can do better in the future with eBPF and by pinning each connection to
a specific CPU with multiple listen() points on the same address,
but this is good enough for now.

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-05-23 18:20:08 +02:00
parent 28cbe2ae20
commit c0d6cf1824
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
4 changed files with 152 additions and 30 deletions

View File

@ -35,6 +35,7 @@ bitmaps = { version = "3.2" }
chacha20poly1305 = { version = "0.10" } chacha20poly1305 = { version = "0.10" }
futures = { version = "0.3" } futures = { version = "0.3" }
hkdf = { version = "0.12" } hkdf = { version = "0.12" }
hwloc2 = {version = "2.2" }
libc = { version = "0.2" } libc = { version = "0.2" }
num-traits = { version = "0.2" } num-traits = { version = "0.2" }
num-derive = { version = "0.3" } num-derive = { version = "0.3" }

View File

@ -1,7 +1,7 @@
//! Connection handling and send/receive queues //! Connection handling and send/receive queues
pub mod handshake; pub mod handshake;
mod packet; pub mod packet;
pub mod socket; pub mod socket;
use ::std::{sync::Arc, vec::Vec}; use ::std::{sync::Arc, vec::Vec};

View File

@ -91,6 +91,8 @@ impl From<[u8; 8]> for ConnectionID {
pub enum PacketData { pub enum PacketData {
/// A parsed handshake packet /// A parsed handshake packet
Handshake(super::Handshake), Handshake(super::Handshake),
/// Raw, not yet parsed packet: only the connection ID and the packet length are known
Raw(usize),
} }
impl PacketData { impl PacketData {
@ -98,6 +100,7 @@ impl PacketData {
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match self { match self {
PacketData::Handshake(h) => h.len(), PacketData::Handshake(h) => h.len(),
PacketData::Raw(len) => *len
} }
} }
/// serialize data into bytes /// serialize data into bytes
@ -111,10 +114,15 @@ impl PacketData {
assert!(self.len() == out.len(), "PacketData: wrong buffer length"); assert!(self.len() == out.len(), "PacketData: wrong buffer length");
match self { match self {
PacketData::Handshake(h) => h.serialize(head_len, tag_len, out), PacketData::Handshake(h) => h.serialize(head_len, tag_len, out),
PacketData::Raw(_) => {
::tracing::error!("Tried to serialize a raw PacketData!");
}
} }
} }
} }
const MIN_PACKET_BYTES: usize = 16;
/// Fenrir packet structure /// Fenrir packet structure
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Packet { pub struct Packet {
@ -125,6 +133,18 @@ pub struct Packet {
} }
impl Packet { impl Packet {
/// Build a [`Packet`] from a freshly received, still unparsed datagram.
///
/// Only the leading 8-byte connection ID is decoded. The rest is left
/// untouched and recorded as [`PacketData::Raw`] holding the length of `raw`.
///
/// # Errors
/// Returns `Err(())` when `raw` is shorter than `MIN_PACKET_BYTES`.
pub fn deserialize_id(raw: &[u8]) -> Result<Self, ()> {
    // TODO: proper min_packet length. 16 is too conservative.
    let total_len = raw.len();
    if total_len < MIN_PACKET_BYTES {
        return Err(());
    }
    // The length check above guarantees at least 8 bytes are available
    // (as long as MIN_PACKET_BYTES stays >= 8).
    let mut id_bytes = [0u8; 8];
    id_bytes.copy_from_slice(&raw[..8]);
    // NOTE(review): `Raw` stores the full datagram length, connection ID
    // included, while `Packet::len()` adds `ConnectionID::len()` on top of
    // the data length. Confirm whether the ID bytes should be counted here.
    Ok(Packet {
        id: id_bytes.into(),
        data: PacketData::Raw(total_len),
    })
}
/// get the total length of the packet /// get the total length of the packet
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
ConnectionID::len() + self.data.len() ConnectionID::len() + self.data.len()

View File

@ -24,7 +24,7 @@ use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
use ::std::{ use ::std::{
net::SocketAddr, net::SocketAddr,
pin::Pin, pin::Pin,
sync::Arc, sync::{Arc, Weak},
vec::{self, Vec}, vec::{self, Vec},
}; };
use ::tokio::{ use ::tokio::{
@ -38,7 +38,7 @@ use crate::{
HandshakeServer, HandshakeServer,
}, },
socket::{SocketList, UdpClient, UdpServer}, socket::{SocketList, UdpClient, UdpServer},
ConnList, Connection, IDSend, ConnList, Connection, IDSend, Packet,
}, },
enc::{ enc::{
asym, asym,
@ -79,9 +79,10 @@ type TokenChecker =
/// Track a raw Udp packet /// Track a raw Udp packet
struct RawUdp { struct RawUdp {
data: Vec<u8>,
src: UdpClient, src: UdpClient,
dst: UdpServer, dst: UdpServer,
data: Vec<u8>,
packet: Packet,
} }
enum Work { enum Work {
@ -103,14 +104,15 @@ pub struct Fenrir {
_inner: Arc<inner::Tracker>, _inner: Arc<inner::Tracker>,
/// where to ask for token check /// where to ask for token check
token_check: Arc<ArcSwapOption<TokenChecker>>, token_check: Arc<ArcSwapOption<TokenChecker>>,
/// MPMC work queue. sender
work_send: Arc<::async_channel::Sender<Work>>,
/// MPMC work queue. receiver
work_recv: Arc<::async_channel::Receiver<Work>>,
// PERF: rand uses syscalls. should we do that async? // PERF: rand uses syscalls. should we do that async?
rand: ::ring::rand::SystemRandom, rand: ::ring::rand::SystemRandom,
/// list of Established connections /// list of Established connections
connections: Arc<RwLock<ConnList>>, connections: Arc<RwLock<ConnList>>,
_myself: Weak<Self>,
// TODO: find a way to both increase and decrease these two in a thread-safe
// manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
} }
// TODO: graceful vs immediate stop // TODO: graceful vs immediate stop
@ -123,22 +125,23 @@ impl Drop for Fenrir {
impl Fenrir { impl Fenrir {
/// Create a new Fenrir endpoint /// Create a new Fenrir endpoint
pub fn new(config: &Config) -> Result<Self, Error> { pub fn new(config: &Config) -> Result<Arc<Self>, Error> {
let listen_num = config.listen.len(); let listen_num = config.listen.len();
let (sender, _) = ::tokio::sync::broadcast::channel(1); let (sender, _) = ::tokio::sync::broadcast::channel(1);
let (work_send, work_recv) = ::async_channel::unbounded::<Work>(); let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let endpoint = Fenrir { let endpoint = Arc::new_cyclic(|myself| Fenrir {
cfg: config.clone(), cfg: config.clone(),
sockets: SocketList::new(), sockets: SocketList::new(),
dnssec: None, dnssec: None,
stop_working: sender, stop_working: sender,
_inner: Arc::new(inner::Tracker::new()), _inner: Arc::new(inner::Tracker::new()),
token_check: Arc::new(ArcSwapOption::from(None)), token_check: Arc::new(ArcSwapOption::from(None)),
work_send: Arc::new(work_send),
work_recv: Arc::new(work_recv),
rand: ::ring::rand::SystemRandom::new(), rand: ::ring::rand::SystemRandom::new(),
connections: Arc::new(RwLock::new(ConnList::new())), connections: Arc::new(RwLock::new(ConnList::new())),
}; _myself: myself.clone(),
_thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()),
});
Ok(endpoint) Ok(endpoint)
} }
@ -169,7 +172,6 @@ impl Fenrir {
toempty_sockets.stop_all().await; toempty_sockets.stop_all().await;
self.dnssec = None; self.dnssec = None;
} }
/// Add all UDP sockets found in config /// Add all UDP sockets found in config
/// and start listening for packets /// and start listening for packets
async fn add_sockets(&self) -> ::std::io::Result<()> { async fn add_sockets(&self) -> ::std::io::Result<()> {
@ -187,7 +189,7 @@ impl Fenrir {
let arc_s = Arc::new(s); let arc_s = Arc::new(s);
let join = ::tokio::spawn(Self::listen_udp( let join = ::tokio::spawn(Self::listen_udp(
stop_working, stop_working,
self.work_send.clone(), self._thread_work.clone(),
arc_s.clone(), arc_s.clone(),
)); ));
self.sockets.add_socket(arc_s, join); self.sockets.add_socket(arc_s, join);
@ -203,12 +205,13 @@ impl Fenrir {
/// Run a dedicated loop to read packets on the listening socket /// Run a dedicated loop to read packets on the listening socket
async fn listen_udp( async fn listen_udp(
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>, mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
work_queue: Arc<::async_channel::Sender<Work>>, work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
) -> ::std::io::Result<()> { ) -> ::std::io::Result<()> {
// jumbo frames are 9K max // jumbo frames are 9K max
let sock_receiver = UdpServer(socket.local_addr()?); let sock_receiver = UdpServer(socket.local_addr()?);
let mut buffer: [u8; 9000] = [0; 9000]; let mut buffer: [u8; 9000] = [0; 9000];
let queues_num = work_queues.len() as u64;
loop { loop {
let (bytes, sock_sender) = ::tokio::select! { let (bytes, sock_sender) = ::tokio::select! {
_done = stop_working.recv() => { _done = stop_working.recv() => {
@ -219,10 +222,33 @@ impl Fenrir {
} }
}; };
let data: Vec<u8> = buffer[..bytes].to_vec(); let data: Vec<u8> = buffer[..bytes].to_vec();
work_queue.send(Work::Recv(RawUdp {
data, // we very likely have multiple threads, pinned to different cpus.
// use the ConnectionID to send the same connection
// to the same thread.
// Handshakes have conenction ID 0, so we use the sender's UDP port
let packet = match Packet::deserialize_id(&data) {
Ok(packet) => packet,
Err(_) => continue, // packet way too short, ignore.
};
// Pick the worker queue for this packet.
// The index must stay in `0..queues_num`: a plain modulo already gives
// that range. Subtracting 1 would underflow whenever the remainder is 0
// and would never select the last queue.
let thread_idx: usize = {
    use connection::packet::ConnectionID;
    if queues_num == 0 {
        // No worker thread registered: nobody can handle the packet,
        // and `% 0` would panic.
        ::tracing::warn!("No work queue available, dropping packet");
        continue;
    }
    match packet.id {
        ConnectionID::Handshake => {
            // handshakes share one ID, spread them by source port
            let send_port = sock_sender.port() as u64;
            (send_port % queues_num) as usize
        }
        ConnectionID::ID(id) => (id.get() % queues_num) as usize,
    }
};
work_queues[thread_idx].send(Work::Recv(RawUdp {
src: UdpClient(sock_sender), src: UdpClient(sock_sender),
dst: sock_receiver, dst: sock_receiver,
packet,
data,
})); }));
} }
Ok(()) Ok(())
@ -242,15 +268,97 @@ impl Fenrir {
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
} }
/// Loop continuously and parse packets and other work /// Start one working thread for each physical cpu
pub async fn work_loop(&self) { /// threads are pinned to each cpu core.
/// Work will be divided and rerouted so that there is no need to lock
pub async fn start_work_threads_pinned(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> ::std::result::Result<(), ()> {
use ::std::sync::Mutex;
let hw_topology = match ::hwloc2::Topology::new() {
Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
None => return Err(()),
};
let cores;
{
let topology_lock = hw_topology.lock().unwrap();
let all_cores = match topology_lock
.objects_with_type(&::hwloc2::ObjectType::Core)
{
Ok(all_cores) => all_cores,
Err(_) => return Err(()),
};
cores = all_cores.len();
if cores <= 0 || !topology_lock.support().cpu().set_thread() {
::tracing::error!("No support for CPU pinning");
return Err(());
}
}
for core in 0..cores {
::tracing::debug!("Spawning thread {}", core);
let th_topology = hw_topology.clone();
let th_tokio_rt = tokio_rt.clone();
let th_myself = self._myself.upgrade().unwrap();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let join_handle = ::std::thread::spawn(move || {
// bind to a specific core
let th_pinning;
{
let mut th_topology_lock = th_topology.lock().unwrap();
let th_cores = th_topology_lock
.objects_with_type(&::hwloc2::ObjectType::Core)
.unwrap();
let cpuset = th_cores.get(core).unwrap().cpuset().unwrap();
th_pinning = th_topology_lock.set_cpubind(
cpuset,
::hwloc2::CpuBindFlags::CPUBIND_THREAD,
);
}
match th_pinning {
Ok(_) => {}
Err(_) => {
::tracing::error!("Can't bind thread to cpu");
return;
}
}
// finally run the main listener. make sure things stay on this
// thread
let tk_local = ::tokio::task::LocalSet::new();
let _ = tk_local.block_on(
&th_tokio_rt,
Self::work_loop_thread(th_myself, work_recv),
);
});
loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
Some(queues_lock) => queues_lock,
None => {
::tokio::time::sleep(
::std::time::Duration::from_millis(50),
)
.await;
continue;
}
};
queues_lock.push(work_send);
break;
}
self._thread_pool.push(join_handle);
}
Ok(())
}
async fn work_loop_thread(
self: Arc<Self>,
work_recv: ::async_channel::Receiver<Work>,
) {
let mut stop_working = self.stop_working.subscribe(); let mut stop_working = self.stop_working.subscribe();
loop { loop {
let work = ::tokio::select! { let work = ::tokio::select! {
_done = stop_working.recv() => { _done = stop_working.recv() => {
break; break;
} }
maybe_work = self.work_recv.recv() => { maybe_work = work_recv.recv() => {
match maybe_work { match maybe_work {
Ok(work) => work, Ok(work) => work,
Err(_) => break, Err(_) => break,
@ -265,16 +373,9 @@ impl Fenrir {
} }
} }
const MIN_PACKET_BYTES: usize = 8;
/// Read and do stuff with the raw udp packet /// Read and do stuff with the raw udp packet
async fn recv(&self, mut udp: RawUdp) { async fn recv(&self, mut udp: RawUdp) {
if udp.data.len() < Self::MIN_PACKET_BYTES { if udp.packet.id.is_handshake() {
return;
}
use connection::ID;
let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable");
if ID::from(raw_id).is_handshake() {
use connection::handshake::Handshake;
let handshake = match Handshake::deserialize(&udp.data[8..]) { let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake, Ok(handshake) => handshake,
Err(e) => { Err(e) => {
@ -390,7 +491,7 @@ impl Fenrir {
let resp_handshake = Handshake::new( let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)), HandshakeData::DirSync(DirSync::Resp(resp)),
); );
use connection::{Packet, PacketData, ID}; use connection::{PacketData, ID};
let packet = Packet { let packet = Packet {
id: ID::new_handshake(), id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake), data: PacketData::Handshake(resp_handshake),