diff --git a/src/enc/errors.rs b/src/enc/errors.rs index 6effb50..ded1fc0 100644 --- a/src/enc/errors.rs +++ b/src/enc/errors.rs @@ -25,4 +25,7 @@ pub enum Error { /// Can not decrypt. Either corrupted or malicious data #[error("decrypt: corrupted data")] Decrypt, + /// Can not encrypt. library failure + #[error("can't encrypt")] + Encrypt, } diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 6e3d3e5..f54cd0b 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -151,28 +151,44 @@ impl Cipher { } } } + fn overhead(&self) -> usize { + match self { + Cipher::XChaCha20Poly1305(cipher) => { + let cipher = CipherKind::XChaCha20Poly1305; + cipher.nonce_len() + cipher.tag_len() + } + } + } fn encrypt( &self, + nonce: &Nonce, aad: AAD, - nonce: Nonce, - data: &mut [u8], + data: &mut Data, ) -> Result<(), Error> { + // No need to check for minimum buffer size since `Data` assures we + // already went through that match self { Cipher::XChaCha20Poly1305(cipher) => { use ::chacha20poly1305::{ aead::generic_array::GenericArray, AeadInPlace, }; - let min_len: usize = CipherKind::XChaCha20Poly1305.nonce_len() - + CipherKind::XChaCha20Poly1305.tag_len() - + 1; - if data.len() < min_len { - return Err(Error::InsufficientBuffer); - } - // write Nonce, then advance it + // write nonce + data.get_slice_full()[..Nonce::len()] + .copy_from_slice(nonce.as_bytes()); // encrypt data - - // add tag + match cipher.cipher.encrypt_in_place_detached( + nonce.as_bytes().into(), + aad.0, + data.get_slice(), + ) { + Ok(tag) => { + // add tag + data.get_tag_slice().copy_from_slice(tag.as_slice()); + Ok(()) + } + Err(_) => Err(Error::Encrypt), + }; } } todo!() @@ -203,6 +219,35 @@ impl CipherRecv { } } +/// Allocate some data, with additional indexes to track +/// where nonce and tags are +#[derive(Debug, Clone)] +pub struct Data { + data: Vec, + skip_start: usize, + skip_end: usize, +} + +impl Data { + /// Get the slice where you will write the actual data + /// this will skip the actual nonce and AEAD tag and give you + /// only the space for the data + pub fn get_slice(&mut self) -> &mut [u8] { + &mut self.data[self.skip_start..self.skip_end] + } + fn get_tag_slice(&mut self) -> &mut [u8] { + let start = self.data.len() - self.skip_end; + &mut self.data[start..] + } + fn get_slice_full(&mut self) -> &mut [u8] { + &mut self.data + } + /// Consume the data and return the whole raw vector + pub fn get_raw(self) -> Vec { + self.data + } +} + /// Send only cipher #[allow(missing_debug_implementations)] pub struct CipherSend { @@ -218,9 +263,19 @@ impl CipherSend { cipher: Cipher::new(kind, secret), } } - /// Get the current nonce as &[u8] - pub fn nonce_as_bytes(&self) -> &[u8] { - self.nonce.as_bytes() + /// Allocate the memory for the data that will be encrypted + pub fn make_data(&self, length: usize) -> Data { + Data { + data: Vec::with_capacity(length + self.cipher.overhead()), + skip_start: self.cipher.nonce_len(), + skip_end: self.cipher.tag_len(), + } + } + /// Encrypt the given data + pub fn encrypt(&mut self, aad: AAD, data: &mut Data) -> Result<(), Error> { + self.cipher.encrypt(&self.nonce, aad, data)?; + self.nonce.advance(); + Ok(()) } } diff --git a/src/lib.rs b/src/lib.rs index ad8fea1..9fcce95 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -179,13 +179,85 @@ type TokenChecker = domain: auth::Domain, ) -> ::futures::future::BoxFuture<'static, Result>; +// PERF: move to arcswap: +// We use multiple UDP sockets. +// We need a list of them all and to scan this list. +// But we want to be able to update this list without locking everyone +// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and +// `from_raw_fd` to update the list. +// This means that we will have multiple `UdpSocket` per actual socket +// so we have to handle `drop()` manually, and garbage-collect the ones we +// are no longer using in the background. sigh. +// Just go with a ArcSwapAny, Arc>>)>>, +} +impl SocketList { + fn new() -> Self { + Self { + sockets: ArcSwap::new(Arc::new(Vec::new())), + } + } + // TODO: fn rm_socket() + fn rm_all(&self) -> Self { + let new_list = Arc::new(Vec::new()); + let old_list = self.sockets.swap(new_list); + Self { + sockets: old_list.into(), + } + } + async fn add_socket( + &self, + socket: Arc, + handle: JoinHandle<::std::io::Result<()>>, + ) { + let mut new_list; + { + let old_list = self.sockets.load(); + new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); + new_list = old_list.to_vec().into(); + } + Arc::get_mut(&mut new_list) + .unwrap() + .push((socket, Arc::new(handle))); + self.sockets.swap(new_list); + } + async fn find(&self, sock: SocketAddr) -> Option> { + let list = self.sockets.load(); + match list.iter().find(|&(s, _)| s.local_addr().unwrap() == sock) { + Some((sock, _)) => Some(sock.clone()), + None => None, + } + } + async fn stop_all(mut self) { + let mut arc_list = self.sockets.into_inner(); + let list = loop { + match Arc::try_unwrap(arc_list) { + Ok(list) => break list, + Err(arc_retry) => { + arc_list = arc_retry; + ::tokio::time::sleep(::core::time::Duration::from_millis( + 50, + )) + .await; + } + } + }; + for (_socket, mut handle) in list.into_iter() { + Arc::get_mut(&mut handle).unwrap().await; + } + } +} + /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { /// library Configuration cfg: Config, /// listening udp sockets - sockets: Vec<(Arc, JoinHandle<::std::io::Result<()>>)>, + //sockets: Vec<(Arc, JoinHandle<::std::io::Result<()>>)>, + sockets: SocketList, /// DNSSEC resolver, with failovers dnssec: Option, /// Broadcast channel to tell workers to stop working @@ -211,7 +283,7 @@ impl Fenrir { let (sender, _) = ::tokio::sync::broadcast::channel(1); let endpoint = Fenrir { cfg: config.clone(), - sockets: Vec::with_capacity(listen_num), + sockets: SocketList::new(), dnssec: None, stop_working: sender, _inner: Arc::new(FenrirInner { @@ -239,9 +311,11 @@ impl Fenrir { /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); - let mut toempty_socket = Vec::new(); - ::std::mem::swap(&mut self.sockets, &mut toempty_socket); - let task = ::tokio::task::spawn(Self::stop_sockets(toempty_socket)); + let mut toempty_sockets = self.sockets.rm_all(); + let task = ::tokio::task::spawn(Self::stop_sockets(toempty_sockets)); + //let mut toempty_socket = Vec::new(); + //::std::mem::swap(&mut self.sockets, &mut toempty_socket); + //let task = ::tokio::task::spawn(Self::stop_sockets(toempty_socket)); let _ = ::futures::executor::block_on(task); self.dnssec = None; } @@ -249,19 +323,19 @@ impl Fenrir { /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); - let mut toempty_socket = Vec::new(); - ::std::mem::swap(&mut self.sockets, &mut toempty_socket); - Self::stop_sockets(toempty_socket).await; + let mut toempty_sockets = self.sockets.rm_all(); + Self::stop_sockets(toempty_sockets).await; self.dnssec = None; } /// actually do the work of stopping resolvers and listeners - async fn stop_sockets( - sockets: Vec<(Arc, JoinHandle<::std::io::Result<()>>)>, - ) { + async fn stop_sockets(sockets: SocketList) { + sockets.stop_all().await; + /* for s in sockets.into_iter() { let _ = s.1.await; } + */ } /// Enable some common socket options. This is just the unsafe part @@ -285,6 +359,37 @@ impl Fenrir { } Ok(()) } + + /// Add all UDP sockets found in config + /// and start listening for packets + async fn add_sockets(&mut self) -> ::std::io::Result<()> { + let sockets = self.cfg.listen.iter().map(|s_addr| async { + let socket = + ::tokio::spawn(Self::bind_udp(s_addr.clone())).await??; + Ok(socket) + }); + let sockets = ::futures::future::join_all(sockets).await; + for s_res in sockets.into_iter() { + match s_res { + Ok(s) => { + let stop_working = self.stop_working.subscribe(); + let arc_s = Arc::new(s); + let join = ::tokio::spawn(Self::listen_udp( + stop_working, + self._inner.clone(), + self.token_check.clone(), + arc_s.clone(), + )); + self.sockets.add_socket(arc_s, join); + } + Err(e) => { + return Err(e); + } + } + } + Ok(()) + } + /// Add an async udp listener async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { let socket = UdpSocket::bind(sock).await?; @@ -319,9 +424,10 @@ impl Fenrir { socket: Arc, ) -> ::std::io::Result<()> { // jumbo frames are 9K max + let sock_receiver = socket.local_addr()?; let mut buffer: [u8; 9000] = [0; 9000]; loop { - let (bytes, sock_from) = ::tokio::select! { + let (bytes, sock_sender) = ::tokio::select! { _done = stop_working.recv() => { break; } @@ -333,42 +439,14 @@ impl Fenrir { fenrir.clone(), token_check.clone(), &buffer[0..bytes], - sock_from, + sock_receiver, + sock_sender, ) .await; } Ok(()) } - /// Add all UDP sockets found in config - /// and start listening for packets - async fn add_sockets(&mut self) -> ::std::io::Result<()> { - let sockets = self.cfg.listen.iter().map(|s_addr| async { - let socket = - ::tokio::spawn(Self::bind_udp(s_addr.clone())).await??; - Ok(Arc::new(socket)) - }); - let sockets = ::futures::future::join_all(sockets).await; - for s_res in sockets.into_iter() { - match s_res { - Ok(s) => { - let stop_working = self.stop_working.subscribe(); - let join = ::tokio::spawn(Self::listen_udp( - stop_working, - self._inner.clone(), - self.token_check.clone(), - s.clone(), - )); - self.sockets.push((s, join)); - } - Err(e) => { - return Err(e); - } - } - } - Ok(()) - } - /// Get the raw TXT record of a Fenrir domain pub async fn resolv_str(&self, domain: &str) -> Result { match &self.dnssec { @@ -389,7 +467,8 @@ impl Fenrir { fenrir: Arc, token_check: Arc>, buffer: &[u8], - _sock_from: SocketAddr, + _sock_receiver: SocketAddr, + _sock_sender: SocketAddr, ) { if buffer.len() < Self::MIN_PACKET_BYTES { return; @@ -484,20 +563,25 @@ impl Fenrir { // build response let secret_send = authinfo.hkdf.get_secret(b"to_client"); - let cipher_send = CipherRecv::new( + let mut cipher_send = CipherSend::new( authinfo.cipher, secret_send, ); use crate::enc::sym::AAD; let aad = AAD(&mut []); // no aad for now - /* - match cipher_send.encrypt(aad, &mut req.data.ciphertext()) { - Ok(()) => req.data.mark_as_cleartext(), - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - */ + let mut data = cipher_send + .make_data(dirsync::RespData::len()); + + if let Err(e) = + cipher_send.encrypt(aad, &mut data) + { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + let resp = dirsync::Resp { + client_key_id: req_data.client_key_id, + enc: data.get_raw(), + }; todo!() } _ => {