Per-thread work loop
This will let us have a lot less locking. We can do better in the future with ebpf and pinning connection to a specific CPU with multiple listen() points on the same address, but good enough for now Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
28cbe2ae20
commit
c0d6cf1824
|
@ -35,6 +35,7 @@ bitmaps = { version = "3.2" }
|
||||||
chacha20poly1305 = { version = "0.10" }
|
chacha20poly1305 = { version = "0.10" }
|
||||||
futures = { version = "0.3" }
|
futures = { version = "0.3" }
|
||||||
hkdf = { version = "0.12" }
|
hkdf = { version = "0.12" }
|
||||||
|
hwloc2 = {version = "2.2" }
|
||||||
libc = { version = "0.2" }
|
libc = { version = "0.2" }
|
||||||
num-traits = { version = "0.2" }
|
num-traits = { version = "0.2" }
|
||||||
num-derive = { version = "0.3" }
|
num-derive = { version = "0.3" }
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
//! Connection handling and send/receive queues
|
//! Connection handling and send/receive queues
|
||||||
|
|
||||||
pub mod handshake;
|
pub mod handshake;
|
||||||
mod packet;
|
pub mod packet;
|
||||||
pub mod socket;
|
pub mod socket;
|
||||||
|
|
||||||
use ::std::{sync::Arc, vec::Vec};
|
use ::std::{sync::Arc, vec::Vec};
|
||||||
|
|
|
@ -91,6 +91,8 @@ impl From<[u8; 8]> for ConnectionID {
|
||||||
pub enum PacketData {
|
pub enum PacketData {
|
||||||
/// A parsed handshake packet
|
/// A parsed handshake packet
|
||||||
Handshake(super::Handshake),
|
Handshake(super::Handshake),
|
||||||
|
/// Raw packet. we only have the connection ID and packet length
|
||||||
|
Raw(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PacketData {
|
impl PacketData {
|
||||||
|
@ -98,6 +100,7 @@ impl PacketData {
|
||||||
pub fn len(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
match self {
|
match self {
|
||||||
PacketData::Handshake(h) => h.len(),
|
PacketData::Handshake(h) => h.len(),
|
||||||
|
PacketData::Raw(len) => *len
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// serialize data into bytes
|
/// serialize data into bytes
|
||||||
|
@ -111,9 +114,14 @@ impl PacketData {
|
||||||
assert!(self.len() == out.len(), "PacketData: wrong buffer length");
|
assert!(self.len() == out.len(), "PacketData: wrong buffer length");
|
||||||
match self {
|
match self {
|
||||||
PacketData::Handshake(h) => h.serialize(head_len, tag_len, out),
|
PacketData::Handshake(h) => h.serialize(head_len, tag_len, out),
|
||||||
|
PacketData::Raw(_) => {
|
||||||
|
::tracing::error!("Tried to serialize a raw PacketData!");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const MIN_PACKET_BYTES: usize = 16;
|
||||||
|
|
||||||
/// Fenrir packet structure
|
/// Fenrir packet structure
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -125,6 +133,18 @@ pub struct Packet {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Packet {
|
impl Packet {
|
||||||
|
/// New recevied packet, yet unparsed
|
||||||
|
pub fn deserialize_id(raw: &[u8]) -> Result<Self,()> {
|
||||||
|
// TODO: proper min_packet length. 16 is too conservative.
|
||||||
|
if raw.len() < MIN_PACKET_BYTES {
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
let raw_id: [u8; 8] = (raw[..8]).try_into().expect("unreachable");
|
||||||
|
Ok(Packet {
|
||||||
|
id: raw_id.into(),
|
||||||
|
data: PacketData::Raw(raw.len()),
|
||||||
|
})
|
||||||
|
}
|
||||||
/// get the total length of the packet
|
/// get the total length of the packet
|
||||||
pub fn len(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
ConnectionID::len() + self.data.len()
|
ConnectionID::len() + self.data.len()
|
||||||
|
|
159
src/lib.rs
159
src/lib.rs
|
@ -24,7 +24,7 @@ use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
|
||||||
use ::std::{
|
use ::std::{
|
||||||
net::SocketAddr,
|
net::SocketAddr,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Arc,
|
sync::{Arc, Weak},
|
||||||
vec::{self, Vec},
|
vec::{self, Vec},
|
||||||
};
|
};
|
||||||
use ::tokio::{
|
use ::tokio::{
|
||||||
|
@ -38,7 +38,7 @@ use crate::{
|
||||||
HandshakeServer,
|
HandshakeServer,
|
||||||
},
|
},
|
||||||
socket::{SocketList, UdpClient, UdpServer},
|
socket::{SocketList, UdpClient, UdpServer},
|
||||||
ConnList, Connection, IDSend,
|
ConnList, Connection, IDSend, Packet,
|
||||||
},
|
},
|
||||||
enc::{
|
enc::{
|
||||||
asym,
|
asym,
|
||||||
|
@ -79,9 +79,10 @@ type TokenChecker =
|
||||||
|
|
||||||
/// Track a raw Udp packet
|
/// Track a raw Udp packet
|
||||||
struct RawUdp {
|
struct RawUdp {
|
||||||
data: Vec<u8>,
|
|
||||||
src: UdpClient,
|
src: UdpClient,
|
||||||
dst: UdpServer,
|
dst: UdpServer,
|
||||||
|
data: Vec<u8>,
|
||||||
|
packet: Packet,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Work {
|
enum Work {
|
||||||
|
@ -103,14 +104,15 @@ pub struct Fenrir {
|
||||||
_inner: Arc<inner::Tracker>,
|
_inner: Arc<inner::Tracker>,
|
||||||
/// where to ask for token check
|
/// where to ask for token check
|
||||||
token_check: Arc<ArcSwapOption<TokenChecker>>,
|
token_check: Arc<ArcSwapOption<TokenChecker>>,
|
||||||
/// MPMC work queue. sender
|
|
||||||
work_send: Arc<::async_channel::Sender<Work>>,
|
|
||||||
/// MPMC work queue. receiver
|
|
||||||
work_recv: Arc<::async_channel::Receiver<Work>>,
|
|
||||||
// PERF: rand uses syscalls. should we do that async?
|
// PERF: rand uses syscalls. should we do that async?
|
||||||
rand: ::ring::rand::SystemRandom,
|
rand: ::ring::rand::SystemRandom,
|
||||||
/// list of Established connections
|
/// list of Established connections
|
||||||
connections: Arc<RwLock<ConnList>>,
|
connections: Arc<RwLock<ConnList>>,
|
||||||
|
_myself: Weak<Self>,
|
||||||
|
// TODO: find a way to both increase and decrease these two in a thread-safe
|
||||||
|
// manner
|
||||||
|
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
|
||||||
|
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: graceful vs immediate stop
|
// TODO: graceful vs immediate stop
|
||||||
|
@ -123,22 +125,23 @@ impl Drop for Fenrir {
|
||||||
|
|
||||||
impl Fenrir {
|
impl Fenrir {
|
||||||
/// Create a new Fenrir endpoint
|
/// Create a new Fenrir endpoint
|
||||||
pub fn new(config: &Config) -> Result<Self, Error> {
|
pub fn new(config: &Config) -> Result<Arc<Self>, Error> {
|
||||||
let listen_num = config.listen.len();
|
let listen_num = config.listen.len();
|
||||||
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
||||||
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
||||||
let endpoint = Fenrir {
|
let endpoint = Arc::new_cyclic(|myself| Fenrir {
|
||||||
cfg: config.clone(),
|
cfg: config.clone(),
|
||||||
sockets: SocketList::new(),
|
sockets: SocketList::new(),
|
||||||
dnssec: None,
|
dnssec: None,
|
||||||
stop_working: sender,
|
stop_working: sender,
|
||||||
_inner: Arc::new(inner::Tracker::new()),
|
_inner: Arc::new(inner::Tracker::new()),
|
||||||
token_check: Arc::new(ArcSwapOption::from(None)),
|
token_check: Arc::new(ArcSwapOption::from(None)),
|
||||||
work_send: Arc::new(work_send),
|
|
||||||
work_recv: Arc::new(work_recv),
|
|
||||||
rand: ::ring::rand::SystemRandom::new(),
|
rand: ::ring::rand::SystemRandom::new(),
|
||||||
connections: Arc::new(RwLock::new(ConnList::new())),
|
connections: Arc::new(RwLock::new(ConnList::new())),
|
||||||
};
|
_myself: myself.clone(),
|
||||||
|
_thread_pool: Vec::new(),
|
||||||
|
_thread_work: Arc::new(Vec::new()),
|
||||||
|
});
|
||||||
Ok(endpoint)
|
Ok(endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,7 +172,6 @@ impl Fenrir {
|
||||||
toempty_sockets.stop_all().await;
|
toempty_sockets.stop_all().await;
|
||||||
self.dnssec = None;
|
self.dnssec = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add all UDP sockets found in config
|
/// Add all UDP sockets found in config
|
||||||
/// and start listening for packets
|
/// and start listening for packets
|
||||||
async fn add_sockets(&self) -> ::std::io::Result<()> {
|
async fn add_sockets(&self) -> ::std::io::Result<()> {
|
||||||
|
@ -187,7 +189,7 @@ impl Fenrir {
|
||||||
let arc_s = Arc::new(s);
|
let arc_s = Arc::new(s);
|
||||||
let join = ::tokio::spawn(Self::listen_udp(
|
let join = ::tokio::spawn(Self::listen_udp(
|
||||||
stop_working,
|
stop_working,
|
||||||
self.work_send.clone(),
|
self._thread_work.clone(),
|
||||||
arc_s.clone(),
|
arc_s.clone(),
|
||||||
));
|
));
|
||||||
self.sockets.add_socket(arc_s, join);
|
self.sockets.add_socket(arc_s, join);
|
||||||
|
@ -203,12 +205,13 @@ impl Fenrir {
|
||||||
/// Run a dedicated loop to read packets on the listening socket
|
/// Run a dedicated loop to read packets on the listening socket
|
||||||
async fn listen_udp(
|
async fn listen_udp(
|
||||||
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
|
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
|
||||||
work_queue: Arc<::async_channel::Sender<Work>>,
|
work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
|
||||||
socket: Arc<UdpSocket>,
|
socket: Arc<UdpSocket>,
|
||||||
) -> ::std::io::Result<()> {
|
) -> ::std::io::Result<()> {
|
||||||
// jumbo frames are 9K max
|
// jumbo frames are 9K max
|
||||||
let sock_receiver = UdpServer(socket.local_addr()?);
|
let sock_receiver = UdpServer(socket.local_addr()?);
|
||||||
let mut buffer: [u8; 9000] = [0; 9000];
|
let mut buffer: [u8; 9000] = [0; 9000];
|
||||||
|
let queues_num = work_queues.len() as u64;
|
||||||
loop {
|
loop {
|
||||||
let (bytes, sock_sender) = ::tokio::select! {
|
let (bytes, sock_sender) = ::tokio::select! {
|
||||||
_done = stop_working.recv() => {
|
_done = stop_working.recv() => {
|
||||||
|
@ -219,10 +222,33 @@ impl Fenrir {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let data: Vec<u8> = buffer[..bytes].to_vec();
|
let data: Vec<u8> = buffer[..bytes].to_vec();
|
||||||
work_queue.send(Work::Recv(RawUdp {
|
|
||||||
data,
|
// we very likely have multiple threads, pinned to different cpus.
|
||||||
|
// use the ConnectionID to send the same connection
|
||||||
|
// to the same thread.
|
||||||
|
// Handshakes have conenction ID 0, so we use the sender's UDP port
|
||||||
|
|
||||||
|
let packet = match Packet::deserialize_id(&data) {
|
||||||
|
Ok(packet) => packet,
|
||||||
|
Err(_) => continue, // packet way too short, ignore.
|
||||||
|
};
|
||||||
|
let thread_idx: usize = {
|
||||||
|
use connection::packet::ConnectionID;
|
||||||
|
match packet.id {
|
||||||
|
ConnectionID::Handshake => {
|
||||||
|
let send_port = sock_sender.port() as u64;
|
||||||
|
((send_port % queues_num) - 1) as usize
|
||||||
|
}
|
||||||
|
ConnectionID::ID(id) => {
|
||||||
|
((id.get() % queues_num) - 1) as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
work_queues[thread_idx].send(Work::Recv(RawUdp {
|
||||||
src: UdpClient(sock_sender),
|
src: UdpClient(sock_sender),
|
||||||
dst: sock_receiver,
|
dst: sock_receiver,
|
||||||
|
packet,
|
||||||
|
data,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -242,15 +268,97 @@ impl Fenrir {
|
||||||
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
|
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Loop continuously and parse packets and other work
|
/// Start one working thread for each physical cpu
|
||||||
pub async fn work_loop(&self) {
|
/// threads are pinned to each cpu core.
|
||||||
|
/// Work will be divided and rerouted so that there is no need to lock
|
||||||
|
pub async fn start_work_threads_pinned(
|
||||||
|
&mut self,
|
||||||
|
tokio_rt: Arc<::tokio::runtime::Runtime>,
|
||||||
|
) -> ::std::result::Result<(), ()> {
|
||||||
|
use ::std::sync::Mutex;
|
||||||
|
let hw_topology = match ::hwloc2::Topology::new() {
|
||||||
|
Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
|
||||||
|
None => return Err(()),
|
||||||
|
};
|
||||||
|
let cores;
|
||||||
|
{
|
||||||
|
let topology_lock = hw_topology.lock().unwrap();
|
||||||
|
let all_cores = match topology_lock
|
||||||
|
.objects_with_type(&::hwloc2::ObjectType::Core)
|
||||||
|
{
|
||||||
|
Ok(all_cores) => all_cores,
|
||||||
|
Err(_) => return Err(()),
|
||||||
|
};
|
||||||
|
cores = all_cores.len();
|
||||||
|
if cores <= 0 || !topology_lock.support().cpu().set_thread() {
|
||||||
|
::tracing::error!("No support for CPU pinning");
|
||||||
|
return Err(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for core in 0..cores {
|
||||||
|
::tracing::debug!("Spawning thread {}", core);
|
||||||
|
let th_topology = hw_topology.clone();
|
||||||
|
let th_tokio_rt = tokio_rt.clone();
|
||||||
|
let th_myself = self._myself.upgrade().unwrap();
|
||||||
|
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
||||||
|
let join_handle = ::std::thread::spawn(move || {
|
||||||
|
// bind to a specific core
|
||||||
|
let th_pinning;
|
||||||
|
{
|
||||||
|
let mut th_topology_lock = th_topology.lock().unwrap();
|
||||||
|
let th_cores = th_topology_lock
|
||||||
|
.objects_with_type(&::hwloc2::ObjectType::Core)
|
||||||
|
.unwrap();
|
||||||
|
let cpuset = th_cores.get(core).unwrap().cpuset().unwrap();
|
||||||
|
th_pinning = th_topology_lock.set_cpubind(
|
||||||
|
cpuset,
|
||||||
|
::hwloc2::CpuBindFlags::CPUBIND_THREAD,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
match th_pinning {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(_) => {
|
||||||
|
::tracing::error!("Can't bind thread to cpu");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// finally run the main listener. make sure things stay on this
|
||||||
|
// thread
|
||||||
|
let tk_local = ::tokio::task::LocalSet::new();
|
||||||
|
let _ = tk_local.block_on(
|
||||||
|
&th_tokio_rt,
|
||||||
|
Self::work_loop_thread(th_myself, work_recv),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
loop {
|
||||||
|
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
|
||||||
|
Some(queues_lock) => queues_lock,
|
||||||
|
None => {
|
||||||
|
::tokio::time::sleep(
|
||||||
|
::std::time::Duration::from_millis(50),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
queues_lock.push(work_send);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
self._thread_pool.push(join_handle);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
async fn work_loop_thread(
|
||||||
|
self: Arc<Self>,
|
||||||
|
work_recv: ::async_channel::Receiver<Work>,
|
||||||
|
) {
|
||||||
let mut stop_working = self.stop_working.subscribe();
|
let mut stop_working = self.stop_working.subscribe();
|
||||||
loop {
|
loop {
|
||||||
let work = ::tokio::select! {
|
let work = ::tokio::select! {
|
||||||
_done = stop_working.recv() => {
|
_done = stop_working.recv() => {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
maybe_work = self.work_recv.recv() => {
|
maybe_work = work_recv.recv() => {
|
||||||
match maybe_work {
|
match maybe_work {
|
||||||
Ok(work) => work,
|
Ok(work) => work,
|
||||||
Err(_) => break,
|
Err(_) => break,
|
||||||
|
@ -265,16 +373,9 @@ impl Fenrir {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const MIN_PACKET_BYTES: usize = 8;
|
|
||||||
/// Read and do stuff with the raw udp packet
|
/// Read and do stuff with the raw udp packet
|
||||||
async fn recv(&self, mut udp: RawUdp) {
|
async fn recv(&self, mut udp: RawUdp) {
|
||||||
if udp.data.len() < Self::MIN_PACKET_BYTES {
|
if udp.packet.id.is_handshake() {
|
||||||
return;
|
|
||||||
}
|
|
||||||
use connection::ID;
|
|
||||||
let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable");
|
|
||||||
if ID::from(raw_id).is_handshake() {
|
|
||||||
use connection::handshake::Handshake;
|
|
||||||
let handshake = match Handshake::deserialize(&udp.data[8..]) {
|
let handshake = match Handshake::deserialize(&udp.data[8..]) {
|
||||||
Ok(handshake) => handshake,
|
Ok(handshake) => handshake,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -390,7 +491,7 @@ impl Fenrir {
|
||||||
let resp_handshake = Handshake::new(
|
let resp_handshake = Handshake::new(
|
||||||
HandshakeData::DirSync(DirSync::Resp(resp)),
|
HandshakeData::DirSync(DirSync::Resp(resp)),
|
||||||
);
|
);
|
||||||
use connection::{Packet, PacketData, ID};
|
use connection::{PacketData, ID};
|
||||||
let packet = Packet {
|
let packet = Packet {
|
||||||
id: ID::new_handshake(),
|
id: ID::new_handshake(),
|
||||||
data: PacketData::Handshake(resp_handshake),
|
data: PacketData::Handshake(resp_handshake),
|
||||||
|
|
Loading…
Reference in New Issue