MPMC queue to distribute work on threads

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-02-24 22:00:56 +01:00
parent 0d33033c0b
commit 59394959bd
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
2 changed files with 67 additions and 37 deletions

View File

@ -27,6 +27,7 @@ crate_type = [ "lib", "cdylib", "staticlib" ]
# please keep these in alphabetical order # please keep these in alphabetical order
arc-swap = { version = "1.6" } arc-swap = { version = "1.6" }
async-channel = { version = "1.8" }
# base85 repo has no tags, fix on a commit. v1.1.1 points to older, wrong version # base85 repo has no tags, fix on a commit. v1.1.1 points to older, wrong version
base85 = { git = "https://gitlab.com/darkwyrm/base85", rev = "d98efbfd171dd9ba48e30a5c88f94db92fc7b3c6" } base85 = { git = "https://gitlab.com/darkwyrm/base85", rev = "d98efbfd171dd9ba48e30a5c88f94db92fc7b3c6" }
chacha20poly1305 = { version = "0.10" } chacha20poly1305 = { version = "0.10" }

View File

@ -54,13 +54,16 @@ pub enum Error {
// No async here // No async here
struct FenrirInner { struct FenrirInner {
// PERF: rand uses syscalls. should we do that async?
rand: ::ring::rand::SystemRandom,
key_exchanges: ArcSwapAny<Arc<Vec<(asym::Key, asym::KeyExchange)>>>, key_exchanges: ArcSwapAny<Arc<Vec<(asym::Key, asym::KeyExchange)>>>,
ciphers: ArcSwapAny<Arc<Vec<CipherKind>>>, ciphers: ArcSwapAny<Arc<Vec<CipherKind>>>,
keys: ArcSwapAny<Arc<Vec<HandshakeKey>>>, keys: ArcSwapAny<Arc<Vec<HandshakeKey>>>,
} }
#[allow(unsafe_code)]
unsafe impl Send for FenrirInner {}
#[allow(unsafe_code)]
unsafe impl Sync for FenrirInner {}
/// Information needed to reply after the key exchange /// Information needed to reply after the key exchange
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct AuthNeededInfo { pub struct AuthNeededInfo {
@ -250,13 +253,22 @@ impl SocketList {
} }
} }
struct RawUdp {
data: Vec<u8>,
src: SocketAddr,
dst: SocketAddr,
}
enum Work {
Recv(RawUdp),
}
/// Instance of a fenrir endpoint /// Instance of a fenrir endpoint
#[allow(missing_copy_implementations, missing_debug_implementations)] #[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir { pub struct Fenrir {
/// library Configuration /// library Configuration
cfg: Config, cfg: Config,
/// listening udp sockets /// listening udp sockets
//sockets: Vec<(Arc<UdpSocket>, JoinHandle<::std::io::Result<()>>)>,
sockets: SocketList, sockets: SocketList,
/// DNSSEC resolver, with failovers /// DNSSEC resolver, with failovers
dnssec: Option<dnssec::Dnssec>, dnssec: Option<dnssec::Dnssec>,
@ -266,6 +278,12 @@ pub struct Fenrir {
_inner: Arc<FenrirInner>, _inner: Arc<FenrirInner>,
/// where to ask for token check /// where to ask for token check
token_check: Arc<ArcSwapOption<TokenChecker>>, token_check: Arc<ArcSwapOption<TokenChecker>>,
/// MPMC work queue. sender
work_send: Arc<::async_channel::Sender<Work>>,
/// MPMC work queue. receiver
work_recv: Arc<::async_channel::Receiver<Work>>,
// PERF: rand uses syscalls. should we do that async?
rand: ::ring::rand::SystemRandom,
} }
// TODO: graceful vs immediate stop // TODO: graceful vs immediate stop
@ -281,18 +299,21 @@ impl Fenrir {
pub fn new(config: &Config) -> Result<Self, Error> { pub fn new(config: &Config) -> Result<Self, Error> {
let listen_num = config.listen.len(); let listen_num = config.listen.len();
let (sender, _) = ::tokio::sync::broadcast::channel(1); let (sender, _) = ::tokio::sync::broadcast::channel(1);
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let endpoint = Fenrir { let endpoint = Fenrir {
cfg: config.clone(), cfg: config.clone(),
sockets: SocketList::new(), sockets: SocketList::new(),
dnssec: None, dnssec: None,
stop_working: sender, stop_working: sender,
_inner: Arc::new(FenrirInner { _inner: Arc::new(FenrirInner {
rand: ::ring::rand::SystemRandom::new(),
ciphers: ArcSwapAny::new(Arc::new(Vec::new())), ciphers: ArcSwapAny::new(Arc::new(Vec::new())),
key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())),
keys: ArcSwapAny::new(Arc::new(Vec::new())), keys: ArcSwapAny::new(Arc::new(Vec::new())),
}), }),
token_check: Arc::new(ArcSwapOption::from(None)), token_check: Arc::new(ArcSwapOption::from(None)),
work_send: Arc::new(work_send),
work_recv: Arc::new(work_recv),
rand: ::ring::rand::SystemRandom::new(),
}; };
Ok(endpoint) Ok(endpoint)
} }
@ -331,11 +352,6 @@ impl Fenrir {
/// actually do the work of stopping resolvers and listeners /// actually do the work of stopping resolvers and listeners
async fn stop_sockets(sockets: SocketList) { async fn stop_sockets(sockets: SocketList) {
sockets.stop_all().await; sockets.stop_all().await;
/*
for s in sockets.into_iter() {
let _ = s.1.await;
}
*/
} }
/// Enable some common socket options. This is just the unsafe part /// Enable some common socket options. This is just the unsafe part
@ -376,8 +392,7 @@ impl Fenrir {
let arc_s = Arc::new(s); let arc_s = Arc::new(s);
let join = ::tokio::spawn(Self::listen_udp( let join = ::tokio::spawn(Self::listen_udp(
stop_working, stop_working,
self._inner.clone(), self.work_send.clone(),
self.token_check.clone(),
arc_s.clone(), arc_s.clone(),
)); ));
self.sockets.add_socket(arc_s, join); self.sockets.add_socket(arc_s, join);
@ -419,8 +434,7 @@ impl Fenrir {
/// Run a dedicated loop to read packets on the listening socket /// Run a dedicated loop to read packets on the listening socket
async fn listen_udp( async fn listen_udp(
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>, mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
fenrir: Arc<FenrirInner>, work_queue: Arc<::async_channel::Sender<Work>>,
token_check: Arc<ArcSwapOption<TokenChecker>>,
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
) -> ::std::io::Result<()> { ) -> ::std::io::Result<()> {
// jumbo frames are 9K max // jumbo frames are 9K max
@ -435,14 +449,12 @@ impl Fenrir {
result? result?
} }
}; };
Self::recv( let data: Vec<u8> = buffer[..bytes].to_vec();
fenrir.clone(), work_queue.send(Work::Recv(RawUdp {
token_check.clone(), data,
&buffer[0..bytes], src: sock_sender,
sock_receiver, dst: sock_receiver.clone(),
sock_sender, }));
)
.await;
} }
Ok(()) Ok(())
} }
@ -461,30 +473,47 @@ impl Fenrir {
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
} }
/// Loop continuously and parse packets and other work
pub async fn work_loop(&self) {
let mut stop_working = self.stop_working.subscribe();
loop {
let work = ::tokio::select! {
_done = stop_working.recv() => {
break;
}
maybe_work = self.work_recv.recv() => {
match maybe_work {
Ok(work) => work,
Err(_) => break,
}
}
};
match work {
Work::Recv(pkt) => {
self.recv(pkt).await;
}
}
}
}
const MIN_PACKET_BYTES: usize = 8; const MIN_PACKET_BYTES: usize = 8;
/// Read and do stuff with the udp packet /// Read and do stuff with the raw udp packet
async fn recv( async fn recv(&self, udp: RawUdp) {
fenrir: Arc<FenrirInner>, if udp.data.len() < Self::MIN_PACKET_BYTES {
token_check: Arc<ArcSwapOption<TokenChecker>>,
buffer: &[u8],
_sock_receiver: SocketAddr,
_sock_sender: SocketAddr,
) {
if buffer.len() < Self::MIN_PACKET_BYTES {
return; return;
} }
use connection::ID; use connection::ID;
let raw_id: [u8; 8] = buffer.try_into().expect("unreachable"); let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable");
if ID::from(raw_id).is_handshake() { if ID::from(raw_id).is_handshake() {
use connection::handshake::Handshake; use connection::handshake::Handshake;
let handshake = match Handshake::deserialize(&buffer[8..]) { let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake, Ok(handshake) => handshake,
Err(e) => { Err(e) => {
::tracing::warn!("Handshake parsing: {}", e); ::tracing::warn!("Handshake parsing: {}", e);
return; return;
} }
}; };
let action = match fenrir.recv_handshake(handshake) { let action = match self._inner.recv_handshake(handshake) {
Ok(action) => action, Ok(action) => action,
Err(err) => { Err(err) => {
::tracing::debug!("Handshake recv error {}", err); ::tracing::debug!("Handshake recv error {}", err);
@ -493,7 +522,7 @@ impl Fenrir {
}; };
match action { match action {
HandshakeAction::AuthNeeded(authinfo) => { HandshakeAction::AuthNeeded(authinfo) => {
let tk_check = match token_check.load_full() { let tk_check = match self.token_check.load_full() {
Some(tk_check) => tk_check, Some(tk_check) => tk_check,
None => { None => {
::tracing::error!( ::tracing::error!(
@ -548,11 +577,11 @@ impl Fenrir {
// TODO: contact the service, get the key and // TODO: contact the service, get the key and
// connection ID // connection ID
let srv_conn_id = let srv_conn_id =
connection::ID::new_rand(&fenrir.rand); connection::ID::new_rand(&self.rand);
let auth_conn_id = let auth_conn_id =
connection::ID::new_rand(&fenrir.rand); connection::ID::new_rand(&self.rand);
let srv_secret = let srv_secret =
enc::sym::Secret::new_rand(&fenrir.rand); enc::sym::Secret::new_rand(&self.rand);
let resp_data = dirsync::RespData { let resp_data = dirsync::RespData {
client_nonce: req_data.nonce, client_nonce: req_data.nonce,