diff --git a/Cargo.toml b/Cargo.toml index 93dcec2..fe12dd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ crate_type = [ "lib", "cdylib", "staticlib" ] # please keep these in alphabetical order arc-swap = { version = "1.6" } +async-channel = { version = "1.8" } # base85 repo has no tags, fix on a commit. v1.1.1 points to older, wrong version base85 = { git = "https://gitlab.com/darkwyrm/base85", rev = "d98efbfd171dd9ba48e30a5c88f94db92fc7b3c6" } chacha20poly1305 = { version = "0.10" } diff --git a/src/lib.rs b/src/lib.rs index 9fcce95..8961fc9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,13 +54,16 @@ pub enum Error { // No async here struct FenrirInner { - // PERF: rand uses syscalls. should we do that async? - rand: ::ring::rand::SystemRandom, key_exchanges: ArcSwapAny>>, ciphers: ArcSwapAny>>, keys: ArcSwapAny>>, } +#[allow(unsafe_code)] +unsafe impl Send for FenrirInner {} +#[allow(unsafe_code)] +unsafe impl Sync for FenrirInner {} + /// Information needed to reply after the key exchange #[derive(Debug, Clone)] pub struct AuthNeededInfo { @@ -250,13 +253,22 @@ impl SocketList { } } +struct RawUdp { + data: Vec, + src: SocketAddr, + dst: SocketAddr, +} + +enum Work { + Recv(RawUdp), +} + /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { /// library Configuration cfg: Config, /// listening udp sockets - //sockets: Vec<(Arc, JoinHandle<::std::io::Result<()>>)>, sockets: SocketList, /// DNSSEC resolver, with failovers dnssec: Option, @@ -266,6 +278,12 @@ pub struct Fenrir { _inner: Arc, /// where to ask for token check token_check: Arc>, + /// MPMC work queue. sender + work_send: Arc<::async_channel::Sender>, + /// MPMC work queue. receiver + work_recv: Arc<::async_channel::Receiver>, + // PERF: rand uses syscalls. should we do that async? + rand: ::ring::rand::SystemRandom, } // TODO: graceful vs immediate stop @@ -281,18 +299,21 @@ impl Fenrir { pub fn new(config: &Config) -> Result { let listen_num = config.listen.len(); let (sender, _) = ::tokio::sync::broadcast::channel(1); + let (work_send, work_recv) = ::async_channel::unbounded::(); let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), dnssec: None, stop_working: sender, _inner: Arc::new(FenrirInner { - rand: ::ring::rand::SystemRandom::new(), ciphers: ArcSwapAny::new(Arc::new(Vec::new())), key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), keys: ArcSwapAny::new(Arc::new(Vec::new())), }), token_check: Arc::new(ArcSwapOption::from(None)), + work_send: Arc::new(work_send), + work_recv: Arc::new(work_recv), + rand: ::ring::rand::SystemRandom::new(), }; Ok(endpoint) } @@ -331,11 +352,6 @@ impl Fenrir { /// actually do the work of stopping resolvers and listeners async fn stop_sockets(sockets: SocketList) { sockets.stop_all().await; - /* - for s in sockets.into_iter() { - let _ = s.1.await; - } - */ } /// Enable some common socket options. This is just the unsafe part @@ -376,8 +392,7 @@ impl Fenrir { let arc_s = Arc::new(s); let join = ::tokio::spawn(Self::listen_udp( stop_working, - self._inner.clone(), - self.token_check.clone(), + self.work_send.clone(), arc_s.clone(), )); self.sockets.add_socket(arc_s, join); @@ -419,8 +434,7 @@ impl Fenrir { /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( mut stop_working: ::tokio::sync::broadcast::Receiver, - fenrir: Arc, - token_check: Arc>, + work_queue: Arc<::async_channel::Sender>, socket: Arc, ) -> ::std::io::Result<()> { // jumbo frames are 9K max @@ -435,14 +449,12 @@ impl Fenrir { result? } }; - Self::recv( - fenrir.clone(), - token_check.clone(), - &buffer[0..bytes], - sock_receiver, - sock_sender, - ) - .await; + let data: Vec = buffer[..bytes].to_vec(); + work_queue.send(Work::Recv(RawUdp { + data, + src: sock_sender, + dst: sock_receiver.clone(), + })); } Ok(()) } @@ -461,30 +473,47 @@ impl Fenrir { Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } + /// Loop continuously and parse packets and other work + pub async fn work_loop(&self) { + let mut stop_working = self.stop_working.subscribe(); + loop { + let work = ::tokio::select! { + _done = stop_working.recv() => { + break; + } + maybe_work = self.work_recv.recv() => { + match maybe_work { + Ok(work) => work, + Err(_) => break, + } + } + }; + match work { + Work::Recv(pkt) => { + self.recv(pkt).await; + } + } + } + } + const MIN_PACKET_BYTES: usize = 8; - /// Read and do stuff with the udp packet - async fn recv( - fenrir: Arc, - token_check: Arc>, - buffer: &[u8], - _sock_receiver: SocketAddr, - _sock_sender: SocketAddr, - ) { - if buffer.len() < Self::MIN_PACKET_BYTES { + /// Read and do stuff with the raw udp packet + async fn recv(&self, udp: RawUdp) { + if udp.data.len() < Self::MIN_PACKET_BYTES { return; } use connection::ID; - let raw_id: [u8; 8] = buffer.try_into().expect("unreachable"); + let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable"); if ID::from(raw_id).is_handshake() { use connection::handshake::Handshake; - let handshake = match Handshake::deserialize(&buffer[8..]) { + let handshake = match Handshake::deserialize(&udp.data[8..]) { Ok(handshake) => handshake, Err(e) => { ::tracing::warn!("Handshake parsing: {}", e); return; } }; - let action = match fenrir.recv_handshake(handshake) { + let action = match self._inner.recv_handshake(handshake) { Ok(action) => action, Err(err) => { ::tracing::debug!("Handshake recv error {}", err); @@ -493,7 +522,7 @@ impl Fenrir { }; match action { HandshakeAction::AuthNeeded(authinfo) => { - let tk_check = match token_check.load_full() { + let tk_check = match self.token_check.load_full() { Some(tk_check) => tk_check, None => { ::tracing::error!( @@ -548,11 +577,11 @@ impl Fenrir { // TODO: contact the service, get the key and // connection ID let srv_conn_id = - connection::ID::new_rand(&fenrir.rand); + connection::ID::new_rand(&self.rand); let auth_conn_id = - connection::ID::new_rand(&fenrir.rand); + connection::ID::new_rand(&self.rand); let srv_secret = - enc::sym::Secret::new_rand(&fenrir.rand); + enc::sym::Secret::new_rand(&self.rand); let resp_data = dirsync::RespData { client_nonce: req_data.nonce,