From 866edc2d7d012d0c2d8c16799537056562aff48f Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sat, 17 Jun 2023 11:33:47 +0200 Subject: [PATCH] TONS of bugfixing. Add tests. Client now connects Signed-off-by: Luca Fulchir --- TODO | 1 + src/auth/mod.rs | 16 +- src/config/mod.rs | 26 ++- src/connection/handshake/dirsync.rs | 44 +++-- src/connection/handshake/mod.rs | 6 + src/connection/handshake/tracker.rs | 95 ++++++++-- src/connection/mod.rs | 23 ++- src/connection/socket.rs | 142 ++++++++++----- src/enc/asym.rs | 2 +- src/enc/mod.rs | 2 + src/enc/sym.rs | 39 ++-- src/enc/tests.rs | 135 ++++++++++++++ src/inner/worker.rs | 272 ++++++++++++++-------------- src/lib.rs | 232 +++++++++++++----------- src/tests.rs | 59 ++++-- 15 files changed, 739 insertions(+), 355 deletions(-) create mode 100644 TODO create mode 100644 src/enc/tests.rs diff --git a/TODO b/TODO new file mode 100644 index 0000000..9531367 --- /dev/null +++ b/TODO @@ -0,0 +1 @@ +* Wrapping for everything that wraps (sigh) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 84be8cb..085816b 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -3,6 +3,8 @@ use crate::enc::Random; use ::zeroize::Zeroize; +/// Anonymous user id +pub const USERID_ANONYMOUS: UserID = UserID([0; UserID::len()]); /// User identifier. 16 bytes for easy uuid conversion #[derive(Debug, Copy, Clone, PartialEq)] pub struct UserID(pub [u8; 16]); @@ -25,8 +27,8 @@ impl UserID { } } /// Anonymous user id - pub fn new_anonymous() -> Self { - UserID([0; 16]) + pub const fn new_anonymous() -> Self { + USERID_ANONYMOUS } /// length of the User ID in bytes pub const fn len() -> usize { @@ -98,6 +100,16 @@ impl TryFrom<&[u8]> for Domain { Ok(Domain(domain_string)) } } +impl From for Domain { + fn from(raw: String) -> Self { + Self(raw) + } +} +impl From<&str> for Domain { + fn from(raw: &str) -> Self { + Self(raw.to_owned()) + } +} impl Domain { /// length of the User ID in bytes diff --git a/src/config/mod.rs b/src/config/mod.rs index c0fe949..f196ac9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -16,6 +16,23 @@ use ::std::{ vec, }; +/// Key used by a server during the handshake +#[derive(Clone, Debug)] +pub struct ServerKey { + pub id: KeyID, + pub priv_key: PrivKey, + pub pub_key: PubKey, +} + +/// Authentication Server information and keys +#[derive(Clone, Debug)] +pub struct AuthServer { + /// fqdn of the authentication server + pub fqdn: crate::auth::Domain, + /// list of key ids enabled for this domain + pub keys: Vec, +} + /// Main config for libFenrir #[derive(Clone, Debug)] pub struct Config { @@ -34,8 +51,12 @@ pub struct Config { pub hkdfs: Vec, /// Supported Ciphers pub ciphers: Vec, + /// list of authentication servers + /// clients will have this empty + pub servers: Vec, /// list of public/private keys - pub keys: Vec<(KeyID, PrivKey, PubKey)>, + /// clients should have this empty + pub server_keys: Vec, } impl Default for Config { @@ -56,7 +77,8 @@ impl Default for Config { key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(), hkdfs: [HkdfKind::Sha3].to_vec(), ciphers: [CipherKind::XChaCha20Poly1305].to_vec(), - keys: Vec::new(), + servers: Vec::new(), + server_keys: Vec::new(), } } } diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 8023ed6..0be2dd4 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -113,10 +113,14 @@ impl Req { + self.exchange_key.kind().pub_len() } /// return the total length of the cleartext data - pub fn encrypted_length(&self) -> usize { + pub fn encrypted_length( + &self, + head_len: HeadLen, + tag_len: TagLen, + ) -> usize { match &self.data { - ReqInner::ClearText(data) => data.len(), - _ => 0, + ReqInner::ClearText(data) => data.len() + head_len.0 + tag_len.0, + ReqInner::CipherText(length) => *length, } } /// actual length of the directory synchronized request @@ -177,11 +181,16 @@ impl super::HandshakeParsing for Req { Some(cipher) => cipher, None => return Err(Error::Parsing), }; - let (exchange_key, len) = match ExchangePubKey::deserialize(&raw[5..]) { - Ok(exchange_key) => exchange_key, - Err(e) => return Err(e.into()), - }; - let data = ReqInner::CipherText(raw.len() - (5 + len)); + const CURR_SIZE: usize = KeyID::len() + + KeyExchangeKind::len() + + HkdfKind::len() + + CipherKind::len(); + let (exchange_key, len) = + match ExchangePubKey::deserialize(&raw[CURR_SIZE..]) { + Ok(exchange_key) => exchange_key, + Err(e) => return Err(e.into()), + }; + let data = ReqInner::CipherText(raw.len() - (CURR_SIZE + len)); Ok(HandshakeData::DirSync(DirSync::Req(Self { key_id, exchange, @@ -436,7 +445,7 @@ impl super::HandshakeParsing for Resp { return Err(Error::NotEnoughData); } let client_key_id: KeyID = - KeyID(u16::from_le_bytes(raw[0..2].try_into().unwrap())); + KeyID(u16::from_le_bytes(raw[0..KeyID::len()].try_into().unwrap())); Ok(HandshakeData::DirSync(DirSync::Resp(Self { client_key_id, data: RespInner::CipherText(raw[KeyID::len()..].len()), @@ -453,10 +462,16 @@ impl Resp { + KeyID::len() } /// return the total length of the cleartext data - pub fn encrypted_length(&self) -> usize { + pub fn encrypted_length( + &self, + head_len: HeadLen, + tag_len: TagLen, + ) -> usize { match &self.data { - RespInner::ClearText(_data) => RespData::len(), - _ => 0, + RespInner::ClearText(_data) => { + RespData::len() + head_len.0 + tag_len.0 + } + RespInner::CipherText(len) => *len, } } /// Total length of the response handshake @@ -471,8 +486,9 @@ impl Resp { _tag_len: TagLen, out: &mut [u8], ) { - out[0..2].copy_from_slice(&self.client_key_id.0.to_le_bytes()); - let start_data = 2 + head_len.0; + out[0..KeyID::len()] + .copy_from_slice(&self.client_key_id.0.to_le_bytes()); + let start_data = KeyID::len() + head_len.0; let end_data = start_data + self.data.len(); self.data.serialize(&mut out[start_data..end_data]); } diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 6dd248e..b5204a1 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -37,6 +37,12 @@ pub enum Error { /// Too many client handshakes currently running #[error("Too many client handshakes")] TooManyClientHandshakes, + /// generic internal error + #[error("Internal tracking error")] + InternalTracking, + /// Handshake Timeout + #[error("Handshake timeout")] + Timeout, } /// List of possible handshakes diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index f304337..63ee31d 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -1,11 +1,11 @@ //! Handhsake handling use crate::{ - auth::ServiceID, + auth::{Domain, ServiceID}, connection::{ self, handshake::{self, Error, Handshake}, - Connection, IDRecv, + Connection, IDRecv, IDSend, }, enc::{ self, @@ -16,16 +16,23 @@ use crate::{ inner::ThreadTracker, }; +use ::tokio::sync::oneshot; + pub(crate) struct HandshakeServer { pub id: KeyID, pub key: PrivKey, + pub domains: Vec, } +pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>; + pub(crate) struct HandshakeClient { pub service_id: ServiceID, pub service_conn_id: IDRecv, pub connection: Connection, pub timeout: Option<::tokio::task::JoinHandle<()>>, + pub answer: oneshot::Sender, + pub srv_key_id: KeyID, } /// Tracks the keys used by the client and the handshake @@ -73,7 +80,10 @@ impl HandshakeClientList { service_id: ServiceID, service_conn_id: IDRecv, connection: Connection, - ) -> Result<(KeyID, &mut HandshakeClient), ()> { + answer: oneshot::Sender, + srv_key_id: KeyID, + ) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender> + { let maybe_free_key_idx = self.used.iter().enumerate().find_map(|(idx, bmap)| { match bmap.first_false_index() { @@ -85,7 +95,7 @@ impl HandshakeClientList { Some((idx, false_idx)) => { let free_key_idx = idx * 1024 + false_idx; if free_key_idx > KeyID::MAX as usize { - return Err(()); + return Err(answer); } self.used[idx].set(false_idx, true); free_key_idx @@ -107,6 +117,8 @@ impl HandshakeClientList { service_conn_id, connection, timeout: None, + answer, + srv_key_id, }); Ok(( KeyID(free_key_idx as u16), @@ -136,6 +148,10 @@ pub(crate) struct ClientConnectInfo { pub handshake: Handshake, /// Connection pub connection: Connection, + /// where to wake up the waiting client + pub answer: oneshot::Sender, + /// server public key id that we used on the handshake + pub srv_key_id: KeyID, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] @@ -177,10 +193,42 @@ impl HandshakeTracker { hshake_cli: HandshakeClientList::new(), } } - pub(crate) fn add_server(&mut self, id: KeyID, key: PrivKey) { - self.keys_srv.push(HandshakeServer { id, key }); + pub(crate) fn add_server_key( + &mut self, + id: KeyID, + key: PrivKey, + ) -> Result<(), ()> { + if self.keys_srv.iter().find(|&k| k.id == id).is_some() { + return Err(()); + } + self.keys_srv.push(HandshakeServer { + id, + key, + domains: Vec::new(), + }); self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0)); + Ok(()) } + pub(crate) fn add_server_domain( + &mut self, + domain: &Domain, + key_ids: &[KeyID], + ) -> Result<(), ()> { + // check that all the key ids are present + for id in key_ids.iter() { + if self.keys_srv.iter().find(|k| k.id == *id).is_none() { + return Err(()); + } + } + // add the domain to those keys + for id in key_ids.iter() { + if let Some(srv) = self.keys_srv.iter_mut().find(|k| k.id == *id) { + srv.domains.push(domain.clone()); + } + } + Ok(()) + } + pub(crate) fn add_client( &mut self, priv_key: PrivKey, @@ -188,20 +236,32 @@ impl HandshakeTracker { service_id: ServiceID, service_conn_id: IDRecv, connection: Connection, - ) -> Result<(KeyID, &mut HandshakeClient), ()> { + answer: oneshot::Sender, + srv_key_id: KeyID, + ) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender> + { self.hshake_cli.add( priv_key, pub_key, service_id, service_conn_id, connection, + answer, + srv_key_id, ) } + pub(crate) fn remove_client( + &mut self, + key_id: KeyID, + ) -> Option { + self.hshake_cli.remove(key_id) + } pub(crate) fn timeout_client( &mut self, key_id: KeyID, ) -> Option<[IDRecv; 2]> { if let Some(hshake) = self.hshake_cli.remove(key_id) { + let _ = hshake.answer.send(Err(Error::Timeout.into())); Some([hshake.connection.id_recv, hshake.service_conn_id]) } else { None @@ -257,9 +317,16 @@ impl HandshakeTracker { let cipher_recv = CipherRecv::new(req.cipher, secret_recv); use crate::enc::sym::AAD; let aad = AAD(&mut []); // no aad for now + + let encrypt_from = req.encrypted_offset(); + let encrypt_to = encrypt_from + + req.encrypted_length( + cipher_recv.nonce_len(), + cipher_recv.tag_len(), + ); match cipher_recv.decrypt( aad, - &mut handshake_raw[req.encrypted_offset()..], + &mut handshake_raw[encrypt_from..encrypt_to], ) { Ok(cleartext) => { req.data.deserialize_as_cleartext(cleartext)?; @@ -292,9 +359,13 @@ impl HandshakeTracker { use crate::enc::sym::AAD; // no aad for now let aad = AAD(&mut []); - let mut raw_data = &mut handshake_raw[resp - .encrypted_offset() - ..(resp.encrypted_offset() + resp.encrypted_length())]; + let data_from = resp.encrypted_offset(); + let data_to = data_from + + resp.encrypted_length( + cipher_recv.nonce_len(), + cipher_recv.tag_len(), + ); + let mut raw_data = &mut handshake_raw[data_from..data_to]; match cipher_recv.decrypt(aad, &mut raw_data) { Ok(cleartext) => { resp.data.deserialize_as_cleartext(&cleartext)?; @@ -314,6 +385,8 @@ impl HandshakeTracker { service_connection_id: hshake.service_conn_id, handshake, connection: hshake.connection, + answer: hshake.answer, + srv_key_id: hshake.srv_key_id, }, )); } diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 07b7c18..45a50e9 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -113,7 +113,11 @@ pub(crate) struct ConnList { impl ConnList { pub(crate) fn new(thread_id: ThreadTracker) -> Self { - let bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + if thread_id.id == 0 { + // make sure we don't count the Handshake ID + bitmap_id.set(0, true); + } const INITIAL_CAP: usize = 128; let mut ret = Self { thread_id, @@ -199,13 +203,6 @@ impl ConnList { } } -use ::std::collections::HashMap; - -enum MapEntry { - Present(IDSend), - Reserved, -} - /// return wether we already have a connection, we are waiting for one, or you /// can start one #[derive(Debug, Clone, Copy)] @@ -218,6 +215,12 @@ pub(crate) enum Reservation { Reserved, } +enum MapEntry { + Present(IDSend), + Reserved, +} +use ::std::collections::HashMap; + /// Link the public key of the authentication server to a connection id /// so that we can reuse that connection to ask for more authentications /// @@ -229,16 +232,16 @@ pub(crate) enum Reservation { /// * wait for the connection to finish /// * remove all those reservations, exept the one key that actually succeded /// While searching, we return a connection ID if just one key is a match +// TODO: can we shard this per-core by hashing the pubkey? or domain? or...??? +// This needs a mutex and it will be our goeal to avoid any synchronization pub(crate) struct AuthServerConnections { conn_map: HashMap, - next_reservation: u64, } impl AuthServerConnections { pub(crate) fn new() -> Self { Self { conn_map: HashMap::with_capacity(32), - next_reservation: 0, } } /// add an ID to the reserved spot, diff --git a/src/connection/socket.rs b/src/connection/socket.rs index 717ecb3..abfc106 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -1,40 +1,10 @@ //! Socket related types and functions -use ::std::{net::SocketAddr, sync::Arc, vec::Vec}; +use ::std::net::SocketAddr; use ::tokio::{net::UdpSocket, task::JoinHandle}; /// Pair to easily track the socket and its async listening handle -pub type SocketTracker = - (Arc, Arc>>); - -/// async free socket list -pub(crate) struct SocketList { - pub list: Vec, -} -impl SocketList { - pub(crate) fn new() -> Self { - Self { list: Vec::new() } - } - pub(crate) fn rm_all(&mut self) -> Self { - let mut old_list = Vec::new(); - ::core::mem::swap(&mut self.list, &mut old_list); - Self { list: old_list } - } - pub(crate) async fn add_socket( - &mut self, - socket: Arc, - handle: JoinHandle<::std::io::Result<()>>, - ) { - let arc_handle = Arc::new(handle); - self.list.push((socket, arc_handle)); - } - /// This method assumes no other `add_sockets` are being run - pub(crate) async fn stop_all(self) { - for (_socket, mut handle) in self.list.into_iter() { - let _ = Arc::get_mut(&mut handle).unwrap().await; - } - } -} +pub type SocketTracker = (SocketAddr, JoinHandle<::std::io::Result<()>>); /// Strong typedef for a client socket address #[derive(Debug, Copy, Clone)] @@ -53,7 +23,7 @@ fn enable_sock_opt( unsafe { #[allow(trivial_casts)] let val = &value as *const _ as *const ::libc::c_void; - let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; + let size = ::core::mem::size_of_val(&value) as ::libc::socklen_t; // always clear the error bit before doing a new syscall let _ = ::std::io::Error::last_os_error(); let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); @@ -64,23 +34,107 @@ fn enable_sock_opt( Ok(()) } /// Add an async udp listener -pub async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { - let socket = UdpSocket::bind(sock).await?; +pub async fn bind_udp(addr: SocketAddr) -> ::std::io::Result { + // I know, kind of a mess. but I really wanted SO_REUSE{ADDR,PORT} and + // no-fragmenting stuff. + // I also did not want to load another library for this. + // feel free to simplify, + // especially if we can avoid libc and other libraries + // we currently use libc because it's a dependency of many other deps - use ::std::os::fd::AsRawFd; - let fd = socket.as_raw_fd(); - // can be useful later on for reloads - enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; - enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; + let fd: ::std::os::fd::RawFd = { + let domain = if addr.is_ipv6() { + ::libc::AF_INET6 + } else { + ::libc::AF_INET + }; + #[allow(unsafe_code)] + let tmp = unsafe { ::libc::socket(domain, ::libc::SOCK_DGRAM, 0) }; + let lasterr = ::std::io::Error::last_os_error(); + if tmp == -1 { + return Err(lasterr); + } + tmp.into() + }; + + if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1) { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } + if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1) { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } // We will do path MTU discovery by ourselves, // always set the "don't fragment" bit - if sock.is_ipv6() { - enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; + let res = if addr.is_ipv6() { + enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1) } else { // FIXME: linux only - enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)?; + enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO) + }; + if let Err(e) = res { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } + // manually convert rust SockAddr to C sockaddr + #[allow(unsafe_code, trivial_casts, trivial_numeric_casts)] + { + let bind_ret = match addr { + SocketAddr::V4(s4) => { + let ip4: u32 = (*s4.ip()).into(); + let bind_addr = ::libc::sockaddr_in { + sin_family: ::libc::AF_INET as u16, + sin_port: s4.port().to_be(), + sin_addr: ::libc::in_addr { s_addr: ip4 }, + sin_zero: [0; 8], + }; + unsafe { + let c_addr = + &bind_addr as *const _ as *const ::libc::sockaddr; + ::libc::bind(fd, c_addr, 16) + } + } + SocketAddr::V6(s6) => { + let ip6: [u8; 16] = (*s6.ip()).octets(); + let bind_addr = ::libc::sockaddr_in6 { + sin6_family: ::libc::AF_INET6 as u16, + sin6_port: s6.port().to_be(), + sin6_flowinfo: 0, + sin6_addr: ::libc::in6_addr { s6_addr: ip6 }, + sin6_scope_id: 0, + }; + unsafe { + let c_addr = + &bind_addr as *const _ as *const ::libc::sockaddr; + ::libc::bind(fd, c_addr, 24) + } + } + }; + let lasterr = ::std::io::Error::last_os_error(); + if bind_ret != 0 { + unsafe { + ::libc::close(fd); + } + return Err(lasterr); + } } - Ok(socket) + use ::std::os::fd::FromRawFd; + #[allow(unsafe_code)] + let std_sock = unsafe { ::std::net::UdpSocket::from_raw_fd(fd) }; + std_sock.set_nonblocking(true)?; + ::tracing::debug!("Listening udp sock: {}", std_sock.local_addr().unwrap()); + + Ok(UdpSocket::from_std(std_sock)?) } diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 32c3aa2..47a3b5a 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -152,7 +152,7 @@ pub enum KeyExchangeKind { } impl KeyExchangeKind { /// The serialize length of the field - pub fn len() -> usize { + pub const fn len() -> usize { 1 } /// Build a new keypair for key exchange diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 9a01b98..3aeea7b 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -4,6 +4,8 @@ pub mod asym; mod errors; pub mod hkdf; pub mod sym; +#[cfg(test)] +mod tests; pub use errors::Error; diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 5728808..d4204e0 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -25,12 +25,11 @@ pub enum CipherKind { impl CipherKind { /// length of the serialized id for the cipher kind field - pub fn len() -> usize { + pub const fn len() -> usize { 1 } /// required length of the nonce pub fn nonce_len(&self) -> HeadLen { - // TODO: how the hell do I take this from ::chacha20poly1305? HeadLen(Nonce::len()) } /// required length of the key @@ -92,10 +91,7 @@ impl Cipher { } fn nonce_len(&self) -> HeadLen { match self { - Cipher::XChaCha20Poly1305(_) => { - // TODO: how the hell do I take this from ::chacha20poly1305? - HeadLen(::ring::aead::CHACHA20_POLY1305.nonce_len()) - } + Cipher::XChaCha20Poly1305(_) => HeadLen(Nonce::len()), } } fn tag_len(&self) -> TagLen { @@ -117,10 +113,13 @@ impl Cipher { aead::generic_array::GenericArray, AeadInPlace, }; let final_len: usize = { - // FIXME: check min data length - let (nonce_bytes, data_and_tag) = raw_data.split_at_mut(13); + if raw_data.len() <= self.overhead() { + return Err(Error::NotEnoughData(raw_data.len())); + } + let (nonce_bytes, data_and_tag) = + raw_data.split_at_mut(Nonce::len()); let (data_notag, tag_bytes) = data_and_tag.split_at_mut( - data_and_tag.len() + 1 + data_and_tag.len() - ::ring::aead::CHACHA20_POLY1305.tag_len(), ); let nonce = GenericArray::from_slice(nonce_bytes); @@ -172,10 +171,7 @@ impl Cipher { &mut data[Nonce::len()..data_len_notag], ) { Ok(tag) => { - data[data_len_notag..] - // add tag - //data.get_tag_slice() - .copy_from_slice(tag.as_slice()); + data[data_len_notag..].copy_from_slice(tag.as_slice()); Ok(()) } Err(_) => Err(Error::Encrypt), @@ -205,6 +201,10 @@ impl CipherRecv { pub fn nonce_len(&self) -> HeadLen { self.0.nonce_len() } + /// Get the length of the nonce for this cipher + pub fn tag_len(&self) -> TagLen { + self.0.tag_len() + } /// Decrypt a paket. Nonce and Tag are taken from the packet, /// while you need to provide AAD (Additional Authenticated Data) pub fn decrypt<'a>( @@ -285,7 +285,7 @@ struct NonceNum { #[repr(C)] pub union Nonce { num: NonceNum, - raw: [u8; 12], + raw: [u8; Self::len()], } impl ::core::fmt::Debug for Nonce { @@ -303,13 +303,17 @@ impl ::core::fmt::Debug for Nonce { impl Nonce { /// Generate a new random Nonce pub fn new(rand: &Random) -> Self { - let mut raw = [0; 12]; + let mut raw = [0; Self::len()]; rand.fill(&mut raw); Self { raw } } /// Length of this nonce in bytes pub const fn len() -> usize { - return 12; + // FIXME: was:12. xchacha20poly1305 requires 24. + // but we should change keys much earlier than that, and our + // nonces are not random, but sequential. + // we should change keys every 2^30 bytes to be sure (stream max window) + return 24; } /// Get reference to the nonce bytes pub fn as_bytes(&self) -> &[u8] { @@ -319,7 +323,7 @@ impl Nonce { } } /// Create Nonce from array - pub fn from_slice(raw: [u8; 12]) -> Self { + pub fn from_slice(raw: [u8; Self::len()]) -> Self { Self { raw } } /// Go to the next nonce @@ -336,6 +340,7 @@ impl Nonce { } /// Synchronize the mutex acess with a nonce for multithread safety +// TODO: remove mutex, not needed anymore #[derive(Debug)] pub struct NonceSync { nonce: ::std::sync::Mutex, diff --git a/src/enc/tests.rs b/src/enc/tests.rs new file mode 100644 index 0000000..ead07ee --- /dev/null +++ b/src/enc/tests.rs @@ -0,0 +1,135 @@ +use crate::{ + auth, + connection::{handshake::*, ID}, + enc::{self, asym::KeyID}, +}; + +#[test] +fn test_simple_encrypt_decrypt() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + let secret = enc::Secret::new_rand(&rand); + let secret2 = secret.clone(); + + let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand); + let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2); + + let mut data = Vec::new(); + let tot_len = cipher_recv.nonce_len().0 + 1234 + cipher_recv.tag_len().0; + data.resize(tot_len, 0); + rand.fill(&mut data); + data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]); + let last = data.len() - cipher_recv.tag_len().0; + data[last..].copy_from_slice(&[0; 16]); + let orig = data.clone(); + let raw_aad: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; + let aad = enc::sym::AAD(&raw_aad[..]); + let aad2 = enc::sym::AAD(&raw_aad[..]); + if cipher_send.encrypt(aad, &mut data).is_err() { + assert!(false, "Encrypt failed"); + } + if cipher_recv.decrypt(aad2, &mut data).is_err() { + assert!(false, "Decrypt failed"); + } + data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]); + let last = data.len() - cipher_recv.tag_len().0; + data[last..].copy_from_slice(&[0; 16]); + assert!(orig == data, "DIFFERENT!\n{:?}\n{:?}\n", orig, data); +} + +#[test] +fn test_encrypt_decrypt() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + let secret = enc::Secret::new_rand(&rand); + let secret2 = secret.clone(); + + let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand); + let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2); + let nonce_len = cipher_recv.nonce_len(); + let tag_len = cipher_recv.tag_len(); + + let service_key = enc::Secret::new_rand(&rand); + + let data = dirsync::RespInner::ClearText(dirsync::RespData { + client_nonce: dirsync::Nonce::new(&rand), + id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()), + service_connection_id: ID::ID( + ::core::num::NonZeroU64::new(434343).unwrap(), + ), + service_key, + }); + + let resp = dirsync::Resp { + client_key_id: KeyID(4444), + data, + }; + let encrypt_from = resp.encrypted_offset(); + let encrypt_to = encrypt_from + resp.encrypted_length(nonce_len, tag_len); + + let h_resp = + Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Resp(resp))); + + let mut bytes = Vec::::with_capacity( + h_resp.len(cipher.nonce_len(), cipher.tag_len()), + ); + bytes.resize(h_resp.len(cipher.nonce_len(), cipher.tag_len()), 0); + h_resp.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); + + let raw_aad: [u8; 7] = [0, 1, 2, 3, 4, 5, 6]; + let aad = enc::sym::AAD(&raw_aad[..]); + let aad2 = enc::sym::AAD(&raw_aad[..]); + + let pre_encrypt = bytes.clone(); + // encrypt + if cipher_send + .encrypt(aad, &mut bytes[encrypt_from..encrypt_to]) + .is_err() + { + assert!(false, "Encrypt failed"); + } + if cipher_recv + .decrypt(aad2, &mut bytes[encrypt_from..encrypt_to]) + .is_err() + { + assert!(false, "Decrypt failed"); + } + // make sure Nonce and Tag are 0 + bytes[encrypt_from..(encrypt_from + nonce_len.0)].copy_from_slice(&[0; 24]); + let tag_from = encrypt_to - tag_len.0; + bytes[tag_from..(tag_from + tag_len.0)].copy_from_slice(&[0; 16]); + assert!( + pre_encrypt == bytes, + "{}|{}=\n{:?}\n{:?}", + encrypt_from, + encrypt_to, + pre_encrypt, + bytes + ); + + // decrypt + + let mut deserialized = match Handshake::deserialize(&bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + // reparse + if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) = + &mut deserialized.data + { + let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; + if let Err(e) = r_a.data.deserialize_as_cleartext( + &bytes[enc_start..(bytes.len() - cipher.tag_len().0)], + ) { + assert!(false, "DirSync Resp Inner serialize: {}", e.to_string()); + } + }; + + assert!( + deserialized == h_resp, + "DirSync Resp (de)serialization not working", + ); +} diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 360b004..f352fc0 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -37,7 +37,7 @@ pub(crate) struct RawUdp { } pub(crate) struct ConnectInfo { - pub answer: oneshot::Sender>, + pub answer: oneshot::Sender, pub resolved: dnssec::Record, pub service_id: ServiceID, pub domain: Domain, @@ -57,14 +57,15 @@ pub(crate) enum WorkAnswer { } /// Actual worker implementation. -pub(crate) struct Worker { +#[allow(missing_debug_implementations)] +pub struct Worker { cfg: Config, thread_id: ThreadTracker, // PERF: rand uses syscalls. how to do that async? rand: Random, stop_working: crate::StopWorkingRecvCh, token_check: Option>>, - sockets: Vec, + sockets: Vec>, queue: ::async_channel::Receiver, queue_timeouts_recv: mpsc::UnboundedReceiver, queue_timeouts_send: mpsc::UnboundedSender, @@ -73,64 +74,18 @@ pub(crate) struct Worker { handshakes: HandshakeTracker, } +#[allow(unsafe_code)] +unsafe impl Send for Worker {} + impl Worker { - pub(crate) async fn new_and_loop( - cfg: Config, - thread_id: ThreadTracker, - stop_working: crate::StopWorkingRecvCh, - token_check: Option>>, - socket_addrs: Vec<::std::net::SocketAddr>, - queue: ::async_channel::Receiver, - ) -> ::std::io::Result<()> { - // TODO: get a channel to send back information, and send the error - let mut worker = Self::new( - cfg, - thread_id, - stop_working, - token_check, - socket_addrs, - queue, - ) - .await?; - worker.work_loop().await; - Ok(()) - } pub(crate) async fn new( mut cfg: Config, thread_id: ThreadTracker, stop_working: crate::StopWorkingRecvCh, token_check: Option>>, - socket_addrs: Vec<::std::net::SocketAddr>, + sockets: Vec>, queue: ::async_channel::Receiver, ) -> ::std::io::Result { - // bind all sockets again so that we can easily - // send without sharing resources - // in the future we will want to have a thread-local listener too, - // but before that we need ebpf to pin a connection to a thread - // directly from the kernel - let mut sock_set = ::tokio::task::JoinSet::new(); - socket_addrs.into_iter().for_each(|s_addr| { - sock_set.spawn(async move { - let socket = - connection::socket::bind_udp(s_addr.clone()).await?; - Ok(socket) - }); - }); - // make sure we either add all of them, or none - let mut sockets = Vec::with_capacity(cfg.listen.len()); - while let Some(join_res) = sock_set.join_next().await { - match join_res { - Ok(s_res) => match s_res { - Ok(sock) => sockets.push(sock), - Err(e) => { - ::tracing::error!("Can't rebind socket"); - return Err(e); - } - }, - Err(e) => return Err(e.into()), - } - } - let (queue_timeouts_send, queue_timeouts_recv) = mpsc::unbounded_channel(); let mut handshakes = HandshakeTracker::new( @@ -138,11 +93,24 @@ impl Worker { cfg.ciphers.clone(), cfg.key_exchanges.clone(), ); - let mut keys = Vec::new(); + let mut server_keys = Vec::new(); // make sure the keys are no longer in the config - ::core::mem::swap(&mut keys, &mut cfg.keys); - for k in keys.into_iter() { - handshakes.add_server(k.0, k.1); + ::core::mem::swap(&mut server_keys, &mut cfg.server_keys); + for k in server_keys.into_iter() { + if handshakes.add_server_key(k.id, k.priv_key).is_err() { + return Err(::std::io::Error::new( + ::std::io::ErrorKind::AlreadyExists, + "You can't use the same KeyID for multiple keys", + )); + } + } + for srv in cfg.servers.iter() { + if handshakes.add_server_domain(&srv.fqdn, &srv.keys).is_err() { + return Err(::std::io::Error::new( + ::std::io::ErrorKind::NotFound, + "Specified a KeyID that we don't have", + )); + } } Ok(Self { @@ -160,12 +128,15 @@ impl Worker { handshakes, }) } - pub(crate) async fn work_loop(&mut self) { + /// Continuously loop and process work as needed + pub async fn work_loop(&mut self) { 'mainloop: loop { let work = ::tokio::select! { tell_stopped = self.stop_working.recv() => { - let _ = tell_stopped.unwrap().send( + if let Ok(stop_ch) = tell_stopped { + let _ = stop_ch.send( crate::StopWorking::WorkerStopped).await; + } break; } maybe_timeout = self.queue.recv() => { @@ -302,6 +273,7 @@ impl Worker { } }; let hkdf; + if let PubKey::Exchange(srv_pub) = key.1 { let secret = match priv_key.key_exchange(exchange, srv_pub) { @@ -341,11 +313,13 @@ impl Worker { conn_info.service_id, service_conn_id, conn, + conn_info.answer, + key.0, ) { Ok((client_key_id, hshake)) => (client_key_id, hshake), - Err(_) => { + Err(answer) => { ::tracing::warn!("Too many client handshakes"); - let _ = conn_info.answer.send(Err( + let _ = answer.send(Err( handshake::Error::TooManyClientHandshakes .into(), )); @@ -363,7 +337,7 @@ impl Worker { let req_data = dirsync::ReqData { nonce: dirsync::Nonce::new(&self.rand), client_key_id, - id: auth_recv_id.0, + id: auth_recv_id.0, //FIXME: is zero auth: auth_info, }; let req = dirsync::Req { @@ -374,28 +348,50 @@ impl Worker { exchange_key: pub_key, data: dirsync::ReqInner::ClearText(req_data), }; - let mut raw = Vec::::with_capacity(req.len()); - req.serialize( + let encrypt_start = ID::len() + req.encrypted_offset(); + let encrypt_end = encrypt_start + + req.encrypted_length( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + ); + let h_req = Handshake::new(HandshakeData::DirSync( + DirSync::Req(req), + )); + use connection::{PacketData, ID}; + let packet = Packet { + id: ID::Handshake, + data: PacketData::Handshake(h_req), + }; + + let tot_len = packet.len( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + ); + let mut raw = Vec::::with_capacity(tot_len); + raw.resize(tot_len, 0); + packet.serialize( cipher_selected.nonce_len(), cipher_selected.tag_len(), &mut raw[..], ); // encrypt - let encrypt_start = req.encrypted_offset(); - let encrypt_end = encrypt_start + req.encrypted_length(); if let Err(e) = hshake.connection.cipher_send.encrypt( sym::AAD(&[]), &mut raw[encrypt_start..encrypt_end], ) { ::tracing::error!("Can't encrypt DirSync Request"); - let _ = conn_info.answer.send(Err(e.into())); + if let Some(client) = + self.handshakes.remove_client(client_key_id) + { + let _ = client.answer.send(Err(e.into())); + }; continue 'mainloop; } // send always from the first socket // FIXME: select based on routing table let sender = self.sockets[0].local_addr().unwrap(); - let dest = UdpServer(addr.as_sockaddr().unwrap()); + let dest = UdpClient(addr.as_sockaddr().unwrap()); // start the timeout right before sending the packet hshake.timeout = Some(::tokio::task::spawn_local( @@ -406,7 +402,7 @@ impl Worker { )); // send packet - self.send_packet(raw, UdpClient(sender), dest).await; + self.send_packet(raw, dest, UdpServer(sender)).await; continue 'mainloop; } @@ -435,17 +431,19 @@ impl Worker { /// Read and do stuff with the raw udp packet async fn recv(&mut self, mut udp: RawUdp) { if udp.packet.id.is_handshake() { - let handshake = match Handshake::deserialize(&udp.data[8..]) { + let handshake = match Handshake::deserialize( + &udp.data[connection::ID::len()..], + ) { Ok(handshake) => handshake, Err(e) => { - ::tracing::warn!("Handshake parsing: {}", e); + ::tracing::debug!("Handshake parsing: {}", e); return; } }; - let action = match self - .handshakes - .recv_handshake(handshake, &mut udp.data[8..]) - { + let action = match self.handshakes.recv_handshake( + handshake, + &mut udp.data[connection::ID::len()..], + ) { Ok(action) => action, Err(err) => { ::tracing::debug!("Handshake recv error {}", err); @@ -454,16 +452,6 @@ impl Worker { }; match action { HandshakeAction::AuthNeeded(authinfo) => { - let token_check = match self.token_check.as_ref() { - Some(token_check) => token_check, - None => { - ::tracing::error!( - "Authentication requested but \ - we have no token checker" - ); - return; - } - }; let req; if let HandshakeData::DirSync(DirSync::Req(r)) = authinfo.handshake.data @@ -477,25 +465,36 @@ impl Worker { let req_data = match req.data { ReqInner::ClearText(req_data) => req_data, _ => { - ::tracing::error!( - "token_check: expected ClearText" - ); + ::tracing::error!("AuthNeeded: expected ClearText"); + assert!(false, "AuthNeeded: unreachable"); return; } }; // FIXME: This part can take a while, // we should just spawn it probably - let is_authenticated = { - let tk_check = token_check.lock().await; - tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await + let maybe_auth_check = { + match &self.token_check { + None => { + if req_data.auth.user == auth::USERID_ANONYMOUS + { + Ok(true) + } else { + Ok(false) + } + } + Some(token_check) => { + let tk_check = token_check.lock().await; + tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await + } + } }; - let is_authenticated = match is_authenticated { + let is_authenticated = match maybe_auth_check { Ok(is_authenticated) => is_authenticated, Err(_) => { ::tracing::error!("error in token auth"); @@ -545,9 +544,9 @@ impl Worker { client_key_id: req_data.client_key_id, data: RespInner::ClearText(resp_data), }; - let offset_to_encrypt = resp.encrypted_offset(); + let encrypt_from = ID::len() + resp.encrypted_offset(); let encrypt_until = - offset_to_encrypt + resp.encrypted_length() + tag_len.0; + encrypt_from + resp.encrypted_length(head_len, tag_len); let resp_handshake = Handshake::new( HandshakeData::DirSync(DirSync::Resp(resp)), ); @@ -556,14 +555,15 @@ impl Worker { id: ID::new_handshake(), data: PacketData::Handshake(resp_handshake), }; - let mut raw_out = - Vec::::with_capacity(packet.len(head_len, tag_len)); + let tot_len = packet.len(head_len, tag_len); + let mut raw_out = Vec::::with_capacity(tot_len); + raw_out.resize(tot_len, 0); packet.serialize(head_len, tag_len, &mut raw_out); - if let Err(e) = auth_conn.cipher_send.encrypt( - aad, - &mut raw_out[offset_to_encrypt..encrypt_until], - ) { + if let Err(e) = auth_conn + .cipher_send + .encrypt(aad, &mut raw_out[encrypt_from..encrypt_until]) + { ::tracing::error!("can't encrypt: {:?}", e); return; } @@ -588,43 +588,46 @@ impl Worker { ::tracing::error!( "ClientConnect on non DS::Resp::ClearText" ); - return; + unreachable!(); } + let auth_srv_conn = IDSend(resp_data.id); let mut conn = cci.connection; - conn.id_send = IDSend(resp_data.id); + conn.id_send = auth_srv_conn; let id_recv = conn.id_recv; let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server if self.connections.track(conn.into()).is_err() { ::tracing::error!("Could not track new connection"); self.connections.remove(id_recv); + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); return; } - if cci.service_id == auth::SERVICEID_AUTH { - // the user asked a single connection - // to the authentication server, without any additional - // service. No more connections to setup - return; + if cci.service_id != auth::SERVICEID_AUTH { + // create and track the connection to the service + // SECURITY: xor with secrets + //FIXME: the Secret should be XORed with the client + // stored secret (if any) + let hkdf = Hkdf::new( + HkdfKind::Sha3, + cci.service_id.as_bytes(), + resp_data.service_key, + ); + let mut service_connection = Connection::new( + hkdf, + cipher, + connection::Role::Client, + &self.rand, + ); + service_connection.id_recv = cci.service_connection_id; + service_connection.id_send = + IDSend(resp_data.service_connection_id); + let _ = + self.connections.track(service_connection.into()); } - // create and track the connection to the service - // SECURITY: xor with secrets - //FIXME: the Secret should be XORed with the client stored - // secret (if any) - let hkdf = Hkdf::new( - HkdfKind::Sha3, - cci.service_id.as_bytes(), - resp_data.service_key, - ); - let mut service_connection = Connection::new( - hkdf, - cipher, - connection::Role::Client, - &self.rand, - ); - service_connection.id_recv = cci.service_connection_id; - service_connection.id_send = - IDSend(resp_data.service_connection_id); - let _ = self.connections.track(service_connection.into()); + let _ = + cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn))); } HandshakeAction::Nothing => {} }; @@ -644,11 +647,12 @@ impl Worker { Some(src_sock) => src_sock, None => { ::tracing::error!( - "Can't send packet: Server changed listening ip!" + "Can't send packet: Server changed listening ip{}!", + server.0 ); return; } }; - let _ = src_sock.send_to(&data, client.0).await; + let res = src_sock.send_to(&data, client.0).await; } } diff --git a/src/lib.rs b/src/lib.rs index d08ca2e..697fff8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ use crate::{ auth::{Domain, ServiceID, TokenChecker}, connection::{ handshake, - socket::{SocketList, UdpClient, UdpServer}, + socket::{SocketTracker, UdpClient, UdpServer}, AuthServerConnections, Packet, }, inner::{ @@ -86,7 +86,7 @@ pub struct Fenrir { /// library Configuration cfg: Config, /// listening udp sockets - sockets: SocketList, + sockets: Vec, /// DNSSEC resolver, with failovers dnssec: dnssec::Dnssec, /// Broadcast channel to tell workers to stop working @@ -100,9 +100,6 @@ pub struct Fenrir { // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, _thread_work: Arc>>, - // This can be different from cfg.listen since using port 0 will result - // in a random port assigned by the operative system - _listen_addrs: Vec<::std::net::SocketAddr>, } // TODO: graceful vs immediate stop @@ -127,16 +124,23 @@ impl Fenrir { } fn stop_sync( &mut self, - ) -> Option<(::tokio::sync::mpsc::Receiver, usize, usize)> - { - let listeners_num = self.sockets.list.len(); + ) -> Option<( + ::tokio::sync::mpsc::Receiver, + Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>, + usize, + )> { let workers_num = self._thread_work.len(); - if self.sockets.list.len() > 0 || self._thread_work.len() > 0 { + if self.sockets.len() > 0 || self._thread_work.len() > 0 { let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4); let _ = self.stop_working.send(ch_send); - let _ = self.sockets.rm_all(); + let mut old_listeners = Vec::with_capacity(self.sockets.len()); + ::core::mem::swap(&mut old_listeners, &mut self.sockets); self._thread_pool.clear(); - Some((ch_recv, listeners_num, workers_num)) + let listeners = old_listeners + .into_iter() + .map(|(_, joinable)| joinable) + .collect(); + Some((ch_recv, listeners, workers_num)) } else { None } @@ -144,9 +148,10 @@ impl Fenrir { async fn stop_wait( &mut self, mut ch: ::tokio::sync::mpsc::Receiver, - mut listeners_num: usize, + listeners: Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>, mut workers_num: usize, ) { + let mut listeners_num = listeners.len(); while listeners_num > 0 && workers_num > 0 { match ch.recv().await { Some(stopped) => match stopped { @@ -158,6 +163,11 @@ impl Fenrir { _ => break, } } + for l in listeners.into_iter() { + if let Err(e) = l.await { + ::tracing::error!("Unclean shutdown of listener: {:?}", e); + } + } } /// Create a new Fenrir endpoint /// spawn threads pinned to cpus in our own way with tokio's runtime @@ -167,22 +177,32 @@ impl Fenrir { ) -> Result { let (sender, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; + // bind sockets early so we can change "port 0" (aka: random) + // in the config + let binded_sockets = Self::bind_sockets(&config).await?; + let socket_addrs = binded_sockets + .iter() + .map(|s| s.local_addr().unwrap()) + .collect(); + let cfg = { + let mut tmp = config.clone(); + tmp.listen = socket_addrs; + tmp + }; let mut endpoint = Self { - cfg: config.clone(), - sockets: SocketList::new(), + cfg, + sockets: Vec::with_capacity(config.listen.len()), dnssec, stop_working: sender, token_check: None, conn_auth_srv: Mutex::new(AuthServerConnections::new()), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), - _listen_addrs: Vec::with_capacity(config.listen.len()), }; - endpoint.start_work_threads_pinned(tokio_rt).await?; - match endpoint.add_sockets().await { - Ok(addrs) => endpoint._listen_addrs = addrs, - Err(e) => return Err(e.into()), - } + endpoint + .start_work_threads_pinned(tokio_rt, binded_sockets.clone()) + .await?; + endpoint.run_listeners(binded_sockets).await?; Ok(endpoint) } /// Create a new Fenrir endpoint @@ -192,41 +212,39 @@ impl Fenrir { /// * make sure that the threads are pinned on the cpu pub async fn with_workers( config: &Config, - ) -> Result< - ( - Self, - Vec>>, - ), - Error, - > { + ) -> Result<(Self, Vec), Error> { let (stop_working, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; - let cfg = config.clone(); - let sockets = SocketList::new(); - let conn_auth_srv = Mutex::new(AuthServerConnections::new()); - let thread_pool = Vec::new(); - let thread_work = Arc::new(Vec::new()); - let listen_addrs = Vec::with_capacity(config.listen.len()); + // bind sockets early so we can change "port 0" (aka: random) + // in the config + let binded_sockets = Self::bind_sockets(&config).await?; + let socket_addrs = binded_sockets + .iter() + .map(|s| s.local_addr().unwrap()) + .collect(); + let cfg = { + let mut tmp = config.clone(); + tmp.listen = socket_addrs; + tmp + }; let mut endpoint = Self { cfg, - sockets, + sockets: Vec::with_capacity(config.listen.len()), dnssec, stop_working: stop_working.clone(), token_check: None, - conn_auth_srv, - _thread_pool: thread_pool, - _thread_work: thread_work, - _listen_addrs: listen_addrs, + conn_auth_srv: Mutex::new(AuthServerConnections::new()), + _thread_pool: Vec::new(), + _thread_work: Arc::new(Vec::new()), }; let worker_num = config.threads.unwrap().get(); let mut workers = Vec::with_capacity(worker_num); for _ in 0..worker_num { - workers.push(endpoint.start_single_worker().await?); - } - match endpoint.add_sockets().await { - Ok(addrs) => endpoint._listen_addrs = addrs, - Err(e) => return Err(e.into()), + workers.push( + endpoint.start_single_worker(binded_sockets.clone()).await?, + ); } + endpoint.run_listeners(binded_sockets).await?; Ok((endpoint, workers)) } /// Returns the list of the actual addresses we are listening on @@ -234,57 +252,56 @@ impl Fenrir { /// if you specified UDP port 0 a random one has been assigned to you /// by the operating system. pub fn addresses(&self) -> Vec<::std::net::SocketAddr> { - self._listen_addrs.clone() + self.sockets.iter().map(|(s, _)| s.clone()).collect() } - // only call **after** starting all threads - /// Add all UDP sockets found in config - /// and start listening for packets - async fn add_sockets( - &mut self, - ) -> ::std::io::Result> { + // only call **before** starting all threads + /// bind all UDP sockets found in config + async fn bind_sockets(cfg: &Config) -> Result>, Error> { // try to bind multiple sockets in parallel let mut sock_set = ::tokio::task::JoinSet::new(); - self.cfg.listen.iter().for_each(|s_addr| { + cfg.listen.iter().for_each(|s_addr| { let socket_address = s_addr.clone(); - let stop_working = self.stop_working.subscribe(); - let th_work = self._thread_work.clone(); sock_set.spawn(async move { - let s = connection::socket::bind_udp(socket_address).await?; - let arc_s = Arc::new(s); - let join = ::tokio::spawn(Self::listen_udp( - stop_working, - th_work, - arc_s.clone(), - )); - Ok((arc_s, join)) + connection::socket::bind_udp(socket_address).await }); }); - - // make sure we either add all of them, or none - let mut all_socks = Vec::with_capacity(self.cfg.listen.len()); + // make sure we either return all of them, or none + let mut all_socks = Vec::with_capacity(cfg.listen.len()); while let Some(join_res) = sock_set.join_next().await { match join_res { Ok(s_res) => match s_res { Ok(s) => { - all_socks.push(s); + all_socks.push(Arc::new(s)); } Err(e) => { - return Err(e); + return Err(e.into()); } }, Err(e) => { - return Err(e.into()); + return Err(Error::Setup(e.to_string())); } } } - - let mut ret = Vec::with_capacity(self.cfg.listen.len()); - for (arc_s, join) in all_socks.into_iter() { - ret.push(arc_s.local_addr().unwrap()); - self.sockets.add_socket(arc_s, join).await; + assert!(all_socks.len() == cfg.listen.len(), "missing socks"); + Ok(all_socks) + } + // only call **after** starting all threads + /// spawn all listeners + async fn run_listeners( + &mut self, + socks: Vec>, + ) -> Result<(), Error> { + for sock in socks.into_iter() { + let sockaddr = sock.local_addr().unwrap(); + let stop_working = self.stop_working.subscribe(); + let th_work = self._thread_work.clone(); + let joinable = ::tokio::spawn(async move { + Self::listen_udp(stop_working, th_work, sock.clone()).await + }); + self.sockets.push((sockaddr, joinable)); } - Ok(ret) + Ok(()) } /// Run a dedicated loop to read packets on the listening socket @@ -301,12 +318,15 @@ impl Fenrir { let (bytes, sock_sender) = ::tokio::select! { tell_stopped = stop_working.recv() => { drop(socket); - let _ = tell_stopped.unwrap() - .send(StopWorking::ListenerStopped).await; + if let Ok(stop_ch) = tell_stopped { + let _ = stop_ch + .send(StopWorking::ListenerStopped).await; + } return Ok(()); } result = socket.recv_from(&mut buffer) => { - result? + let (bytes, from) = result?; + (bytes, UdpClient(from)) } }; let data: Vec = buffer[..bytes].to_vec(); @@ -324,17 +344,15 @@ impl Fenrir { use connection::packet::ConnectionID; match packet.id { ConnectionID::Handshake => { - let send_port = sock_sender.port() as u64; - ((send_port % queues_num) - 1) as usize - } - ConnectionID::ID(id) => { - ((id.get() % queues_num) - 1) as usize + let send_port = sock_sender.0.port() as u64; + (send_port % queues_num) as usize } + ConnectionID::ID(id) => (id.get() % queues_num) as usize, } }; let _ = work_queues[thread_idx] .send(Work::Recv(RawUdp { - src: UdpClient(sock_sender), + src: sock_sender, dst: sock_receiver, packet, data, @@ -431,7 +449,7 @@ impl Fenrir { .unwrap(); // and tell that thread to connect somewhere - let (send, recv) = ::tokio::sync::oneshot::channel(); + let (send, mut recv) = ::tokio::sync::oneshot::channel(); let _ = self._thread_work[thread_idx] .send(Work::Connect(ConnectInfo { answer: send, @@ -450,10 +468,15 @@ impl Fenrir { conn_auth_lock.remove_reserved(&resolved); Err(e) } - Ok((pubkey, id_send)) => { + Ok((key_id, id_send)) => { + let key = resolved + .public_keys + .iter() + .find(|k| k.0 == key_id) + .unwrap(); let mut conn_auth_lock = self.conn_auth_srv.lock().await; - conn_auth_lock.add(&pubkey, id_send, &resolved); + conn_auth_lock.add(&key.1, id_send, &resolved); //FIXME: user needs to somehow track the connection Ok(()) @@ -472,13 +495,11 @@ impl Fenrir { } } - // needs to be called before add_sockets + // needs to be called before run_listeners async fn start_single_worker( &mut self, - ) -> ::std::result::Result< - impl futures::Future>, - Error, - > { + socks: Vec>, + ) -> ::std::result::Result { let thread_idx = self._thread_work.len() as u16; let max_threads = self.cfg.threads.unwrap().get() as u16; if thread_idx >= max_threads { @@ -496,17 +517,18 @@ impl Fenrir { total: max_threads, }; let (work_send, work_recv) = ::async_channel::unbounded::(); - let worker = Worker::new_and_loop( + let worker = Worker::new( self.cfg.clone(), thread_id, self.stop_working.subscribe(), self.token_check.clone(), - self.cfg.listen.clone(), + socks, work_recv, - ); + ) + .await?; // don't keep around private keys too much if (thread_idx + 1) == max_threads { - self.cfg.keys.clear(); + self.cfg.server_keys.clear(); } loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { @@ -533,6 +555,7 @@ impl Fenrir { async fn start_work_threads_pinned( &mut self, tokio_rt: Arc<::tokio::runtime::Runtime>, + sockets: Vec>, ) -> ::std::result::Result<(), Error> { use ::std::sync::Mutex; let hw_topology = match ::hwloc2::Topology::new() { @@ -568,7 +591,7 @@ impl Fenrir { let (work_send, work_recv) = ::async_channel::unbounded::(); let th_stop_working = self.stop_working.subscribe(); let th_token_check = self.token_check.clone(); - let th_socket_addrs = self.cfg.listen.clone(); + let th_sockets = sockets.clone(); let thread_id = ThreadTracker { total: cores as u16, id: 1 + (core as u16), @@ -598,17 +621,22 @@ impl Fenrir { // finally run the main worker. // make sure things stay on this thread let tk_local = ::tokio::task::LocalSet::new(); - let _ = tk_local.block_on( - &th_tokio_rt, - Worker::new_and_loop( + let _ = tk_local.block_on(&th_tokio_rt, async move { + let mut worker = match Worker::new( th_config, thread_id, th_stop_working, th_token_check, - th_socket_addrs, + th_sockets, work_recv, - ), - ); + ) + .await + { + Ok(worker) => worker, + Err(_) => return, + }; + worker.work_loop().await + }); }); loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { @@ -627,7 +655,7 @@ impl Fenrir { self._thread_pool.push(join_handle); } // don't keep around private keys too much - self.cfg.keys.clear(); + self.cfg.server_keys.clear(); Ok(()) } } diff --git a/src/tests.rs b/src/tests.rs index 7d19b4f..acf57cc 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -21,23 +21,46 @@ async fn test_connection_dirsync() { cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap()); cfg }; + let test_domain: Domain = "example.com".into(); let cfg_server = { let mut cfg = cfg_client.clone(); - cfg.keys = [(KeyID(42), priv_exchange_key, pub_exchange_key)].to_vec(); + cfg.server_keys = [config::ServerKey { + id: KeyID(42), + priv_key: priv_exchange_key, + pub_key: pub_exchange_key, + }] + .to_vec(); + cfg.servers = [config::AuthServer { + fqdn: test_domain.clone(), + keys: [KeyID(42)].to_vec(), + }] + .to_vec(); cfg }; let (server, mut srv_workers) = Fenrir::with_workers(&cfg_server).await.unwrap(); - - let srv_worker = srv_workers.pop().unwrap(); - let local_thread = ::tokio::task::LocalSet::new(); - local_thread.spawn_local(async move { srv_worker.await }); - let (client, mut cli_workers) = Fenrir::with_workers(&cfg_client).await.unwrap(); - let cli_worker = cli_workers.pop().unwrap(); - local_thread.spawn_local(async move { cli_worker.await }); + let mut srv_worker = srv_workers.pop().unwrap(); + let mut cli_worker = cli_workers.pop().unwrap(); + + ::std::thread::spawn(move || { + let rt = ::tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + let local_thread = ::tokio::task::LocalSet::new(); + local_thread.spawn_local(async move { + srv_worker.work_loop().await; + }); + + local_thread.spawn_local(async move { + ::tokio::time::sleep(::std::time::Duration::from_millis(100)).await; + cli_worker.work_loop().await; + }); + rt.block_on(local_thread); + }); use crate::{ connection::handshake::HandshakeID, @@ -63,17 +86,17 @@ async fn test_connection_dirsync() { ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), }; - server.graceful_stop().await; - client.graceful_stop().await; - return; + ::tokio::time::sleep(::std::time::Duration::from_millis(500)).await; + match client + .connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH) + .await + { + Ok(()) => {} + Err(e) => { + assert!(false, "Err on client connection: {:?}", e); + } + } - let _ = client - .connect_resolved( - dnssec_record, - &Domain("example.com".to_owned()), - auth::SERVICEID_AUTH, - ) - .await; server.graceful_stop().await; client.graceful_stop().await; }