Per-thread work loop
This will let us have a lot less locking. We can do better in the future with eBPF and by pinning each connection to a specific CPU with multiple listen() points on the same address, but this is good enough for now. Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
28cbe2ae20
commit
c0d6cf1824
|
@ -35,6 +35,7 @@ bitmaps = { version = "3.2" }
|
|||
chacha20poly1305 = { version = "0.10" }
|
||||
futures = { version = "0.3" }
|
||||
hkdf = { version = "0.12" }
|
||||
hwloc2 = {version = "2.2" }
|
||||
libc = { version = "0.2" }
|
||||
num-traits = { version = "0.2" }
|
||||
num-derive = { version = "0.3" }
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
//! Connection handling and send/receive queues
|
||||
|
||||
pub mod handshake;
|
||||
mod packet;
|
||||
pub mod packet;
|
||||
pub mod socket;
|
||||
|
||||
use ::std::{sync::Arc, vec::Vec};
|
||||
|
|
|
@ -91,6 +91,8 @@ impl From<[u8; 8]> for ConnectionID {
|
|||
pub enum PacketData {
|
||||
/// A parsed handshake packet
|
||||
Handshake(super::Handshake),
|
||||
/// Raw packet. We only have the connection ID and the packet length
|
||||
Raw(usize),
|
||||
}
|
||||
|
||||
impl PacketData {
|
||||
|
@ -98,6 +100,7 @@ impl PacketData {
|
|||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
PacketData::Handshake(h) => h.len(),
|
||||
PacketData::Raw(len) => *len
|
||||
}
|
||||
}
|
||||
/// serialize data into bytes
|
||||
|
@ -111,10 +114,15 @@ impl PacketData {
|
|||
assert!(self.len() == out.len(), "PacketData: wrong buffer length");
|
||||
match self {
|
||||
PacketData::Handshake(h) => h.serialize(head_len, tag_len, out),
|
||||
PacketData::Raw(_) => {
|
||||
::tracing::error!("Tried to serialize a raw PacketData!");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const MIN_PACKET_BYTES: usize = 16;
|
||||
|
||||
/// Fenrir packet structure
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Packet {
|
||||
|
@ -125,6 +133,18 @@ pub struct Packet {
|
|||
}
|
||||
|
||||
impl Packet {
|
||||
/// New recevied packet, yet unparsed
|
||||
pub fn deserialize_id(raw: &[u8]) -> Result<Self,()> {
|
||||
// TODO: proper min_packet length. 16 is too conservative.
|
||||
if raw.len() < MIN_PACKET_BYTES {
|
||||
return Err(());
|
||||
}
|
||||
let raw_id: [u8; 8] = (raw[..8]).try_into().expect("unreachable");
|
||||
Ok(Packet {
|
||||
id: raw_id.into(),
|
||||
data: PacketData::Raw(raw.len()),
|
||||
})
|
||||
}
|
||||
/// get the total length of the packet
|
||||
pub fn len(&self) -> usize {
|
||||
ConnectionID::len() + self.data.len()
|
||||
|
|
159
src/lib.rs
159
src/lib.rs
|
@ -24,7 +24,7 @@ use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
|
|||
use ::std::{
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
sync::{Arc, Weak},
|
||||
vec::{self, Vec},
|
||||
};
|
||||
use ::tokio::{
|
||||
|
@ -38,7 +38,7 @@ use crate::{
|
|||
HandshakeServer,
|
||||
},
|
||||
socket::{SocketList, UdpClient, UdpServer},
|
||||
ConnList, Connection, IDSend,
|
||||
ConnList, Connection, IDSend, Packet,
|
||||
},
|
||||
enc::{
|
||||
asym,
|
||||
|
@ -79,9 +79,10 @@ type TokenChecker =
|
|||
|
||||
/// Track a raw Udp packet
|
||||
struct RawUdp {
|
||||
data: Vec<u8>,
|
||||
src: UdpClient,
|
||||
dst: UdpServer,
|
||||
data: Vec<u8>,
|
||||
packet: Packet,
|
||||
}
|
||||
|
||||
enum Work {
|
||||
|
@ -103,14 +104,15 @@ pub struct Fenrir {
|
|||
_inner: Arc<inner::Tracker>,
|
||||
/// where to ask for token check
|
||||
token_check: Arc<ArcSwapOption<TokenChecker>>,
|
||||
/// MPMC work queue. sender
|
||||
work_send: Arc<::async_channel::Sender<Work>>,
|
||||
/// MPMC work queue. receiver
|
||||
work_recv: Arc<::async_channel::Receiver<Work>>,
|
||||
// PERF: rand uses syscalls. should we do that async?
|
||||
rand: ::ring::rand::SystemRandom,
|
||||
/// list of Established connections
|
||||
connections: Arc<RwLock<ConnList>>,
|
||||
_myself: Weak<Self>,
|
||||
// TODO: find a way to both increase and decrease these two in a thread-safe
|
||||
// manner
|
||||
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
|
||||
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
|
||||
}
|
||||
|
||||
// TODO: graceful vs immediate stop
|
||||
|
@ -123,22 +125,23 @@ impl Drop for Fenrir {
|
|||
|
||||
impl Fenrir {
|
||||
/// Create a new Fenrir endpoint
|
||||
pub fn new(config: &Config) -> Result<Self, Error> {
|
||||
pub fn new(config: &Config) -> Result<Arc<Self>, Error> {
|
||||
let listen_num = config.listen.len();
|
||||
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
||||
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
||||
let endpoint = Fenrir {
|
||||
let endpoint = Arc::new_cyclic(|myself| Fenrir {
|
||||
cfg: config.clone(),
|
||||
sockets: SocketList::new(),
|
||||
dnssec: None,
|
||||
stop_working: sender,
|
||||
_inner: Arc::new(inner::Tracker::new()),
|
||||
token_check: Arc::new(ArcSwapOption::from(None)),
|
||||
work_send: Arc::new(work_send),
|
||||
work_recv: Arc::new(work_recv),
|
||||
rand: ::ring::rand::SystemRandom::new(),
|
||||
connections: Arc::new(RwLock::new(ConnList::new())),
|
||||
};
|
||||
_myself: myself.clone(),
|
||||
_thread_pool: Vec::new(),
|
||||
_thread_work: Arc::new(Vec::new()),
|
||||
});
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
|
@ -169,7 +172,6 @@ impl Fenrir {
|
|||
toempty_sockets.stop_all().await;
|
||||
self.dnssec = None;
|
||||
}
|
||||
|
||||
/// Add all UDP sockets found in config
|
||||
/// and start listening for packets
|
||||
async fn add_sockets(&self) -> ::std::io::Result<()> {
|
||||
|
@ -187,7 +189,7 @@ impl Fenrir {
|
|||
let arc_s = Arc::new(s);
|
||||
let join = ::tokio::spawn(Self::listen_udp(
|
||||
stop_working,
|
||||
self.work_send.clone(),
|
||||
self._thread_work.clone(),
|
||||
arc_s.clone(),
|
||||
));
|
||||
self.sockets.add_socket(arc_s, join);
|
||||
|
@ -203,12 +205,13 @@ impl Fenrir {
|
|||
/// Run a dedicated loop to read packets on the listening socket
|
||||
async fn listen_udp(
|
||||
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
|
||||
work_queue: Arc<::async_channel::Sender<Work>>,
|
||||
work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
|
||||
socket: Arc<UdpSocket>,
|
||||
) -> ::std::io::Result<()> {
|
||||
// jumbo frames are 9K max
|
||||
let sock_receiver = UdpServer(socket.local_addr()?);
|
||||
let mut buffer: [u8; 9000] = [0; 9000];
|
||||
let queues_num = work_queues.len() as u64;
|
||||
loop {
|
||||
let (bytes, sock_sender) = ::tokio::select! {
|
||||
_done = stop_working.recv() => {
|
||||
|
@ -219,10 +222,33 @@ impl Fenrir {
|
|||
}
|
||||
};
|
||||
let data: Vec<u8> = buffer[..bytes].to_vec();
|
||||
work_queue.send(Work::Recv(RawUdp {
|
||||
data,
|
||||
|
||||
// we very likely have multiple threads, pinned to different cpus.
|
||||
// use the ConnectionID to send the same connection
|
||||
// to the same thread.
|
||||
// Handshakes have connection ID 0, so we use the sender's UDP port
|
||||
|
||||
let packet = match Packet::deserialize_id(&data) {
|
||||
Ok(packet) => packet,
|
||||
Err(_) => continue, // packet way too short, ignore.
|
||||
};
|
||||
// Pick the worker queue for this packet: the same connection ID always
// maps to the same queue, while handshakes are spread by the sender's
// UDP port.
// NOTE(review): `queues_num == 0` panics on the modulo below -- confirm
// that the worker threads are always started before listening.
let thread_idx: usize = {
    use connection::packet::ConnectionID;
    let routing_key: u64 = match packet.id {
        ConnectionID::Handshake => u64::from(sock_sender.port()),
        ConnectionID::ID(id) => id.get(),
    };
    // `x % queues_num` is already a valid index in `0..queues_num`.
    // Subtracting 1 from it underflowed whenever the remainder was 0
    // and never selected the last queue.
    (routing_key % queues_num) as usize
};
// The future returned by `send()` does nothing until it is awaited.
let sent = work_queues[thread_idx]
    .send(Work::Recv(RawUdp {
        src: UdpClient(sock_sender),
        dst: sock_receiver,
        packet,
        data,
    }))
    .await;
if sent.is_err() {
    // the receiving side of this queue has been dropped
    ::tracing::error!(
        "Work queue {} is closed, packet dropped",
        thread_idx
    );
}
|
||||
}
|
||||
Ok(())
|
||||
|
@ -242,15 +268,97 @@ impl Fenrir {
|
|||
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
|
||||
}
|
||||
|
||||
/// Loop continuously and parse packets and other work
|
||||
pub async fn work_loop(&self) {
|
||||
/// Start one working thread for each physical cpu
|
||||
/// threads are pinned to each cpu core.
|
||||
/// Work will be divided and rerouted so that there is no need to lock
|
||||
pub async fn start_work_threads_pinned(
|
||||
&mut self,
|
||||
tokio_rt: Arc<::tokio::runtime::Runtime>,
|
||||
) -> ::std::result::Result<(), ()> {
|
||||
use ::std::sync::Mutex;
|
||||
let hw_topology = match ::hwloc2::Topology::new() {
|
||||
Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
|
||||
None => return Err(()),
|
||||
};
|
||||
let cores;
|
||||
{
|
||||
let topology_lock = hw_topology.lock().unwrap();
|
||||
let all_cores = match topology_lock
|
||||
.objects_with_type(&::hwloc2::ObjectType::Core)
|
||||
{
|
||||
Ok(all_cores) => all_cores,
|
||||
Err(_) => return Err(()),
|
||||
};
|
||||
cores = all_cores.len();
|
||||
if cores <= 0 || !topology_lock.support().cpu().set_thread() {
|
||||
::tracing::error!("No support for CPU pinning");
|
||||
return Err(());
|
||||
}
|
||||
}
|
||||
for core in 0..cores {
|
||||
::tracing::debug!("Spawning thread {}", core);
|
||||
let th_topology = hw_topology.clone();
|
||||
let th_tokio_rt = tokio_rt.clone();
|
||||
let th_myself = self._myself.upgrade().unwrap();
|
||||
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
||||
let join_handle = ::std::thread::spawn(move || {
|
||||
// bind to a specific core
|
||||
let th_pinning;
|
||||
{
|
||||
let mut th_topology_lock = th_topology.lock().unwrap();
|
||||
let th_cores = th_topology_lock
|
||||
.objects_with_type(&::hwloc2::ObjectType::Core)
|
||||
.unwrap();
|
||||
let cpuset = th_cores.get(core).unwrap().cpuset().unwrap();
|
||||
th_pinning = th_topology_lock.set_cpubind(
|
||||
cpuset,
|
||||
::hwloc2::CpuBindFlags::CPUBIND_THREAD,
|
||||
);
|
||||
}
|
||||
match th_pinning {
|
||||
Ok(_) => {}
|
||||
Err(_) => {
|
||||
::tracing::error!("Can't bind thread to cpu");
|
||||
return;
|
||||
}
|
||||
}
|
||||
// finally run the main listener. make sure things stay on this
|
||||
// thread
|
||||
let tk_local = ::tokio::task::LocalSet::new();
|
||||
let _ = tk_local.block_on(
|
||||
&th_tokio_rt,
|
||||
Self::work_loop_thread(th_myself, work_recv),
|
||||
);
|
||||
});
|
||||
loop {
|
||||
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
|
||||
Some(queues_lock) => queues_lock,
|
||||
None => {
|
||||
::tokio::time::sleep(
|
||||
::std::time::Duration::from_millis(50),
|
||||
)
|
||||
.await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
queues_lock.push(work_send);
|
||||
break;
|
||||
}
|
||||
self._thread_pool.push(join_handle);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
async fn work_loop_thread(
|
||||
self: Arc<Self>,
|
||||
work_recv: ::async_channel::Receiver<Work>,
|
||||
) {
|
||||
let mut stop_working = self.stop_working.subscribe();
|
||||
loop {
|
||||
let work = ::tokio::select! {
|
||||
_done = stop_working.recv() => {
|
||||
break;
|
||||
}
|
||||
maybe_work = self.work_recv.recv() => {
|
||||
maybe_work = work_recv.recv() => {
|
||||
match maybe_work {
|
||||
Ok(work) => work,
|
||||
Err(_) => break,
|
||||
|
@ -265,16 +373,9 @@ impl Fenrir {
|
|||
}
|
||||
}
|
||||
|
||||
const MIN_PACKET_BYTES: usize = 8;
|
||||
/// Read and do stuff with the raw udp packet
|
||||
async fn recv(&self, mut udp: RawUdp) {
|
||||
if udp.data.len() < Self::MIN_PACKET_BYTES {
|
||||
return;
|
||||
}
|
||||
use connection::ID;
|
||||
let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable");
|
||||
if ID::from(raw_id).is_handshake() {
|
||||
use connection::handshake::Handshake;
|
||||
if udp.packet.id.is_handshake() {
|
||||
let handshake = match Handshake::deserialize(&udp.data[8..]) {
|
||||
Ok(handshake) => handshake,
|
||||
Err(e) => {
|
||||
|
@ -390,7 +491,7 @@ impl Fenrir {
|
|||
let resp_handshake = Handshake::new(
|
||||
HandshakeData::DirSync(DirSync::Resp(resp)),
|
||||
);
|
||||
use connection::{Packet, PacketData, ID};
|
||||
use connection::{PacketData, ID};
|
||||
let packet = Packet {
|
||||
id: ID::new_handshake(),
|
||||
data: PacketData::Handshake(resp_handshake),
|
||||
|
|
Loading…
Reference in New Issue