diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 8db39c4..b247cbc 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -2,12 +2,14 @@ pub mod handshake; mod packet; +pub mod socket; -use ::std::vec::Vec; +use ::std::{sync::Arc, vec::Vec}; -pub use handshake::Handshake; -pub use packet::ConnectionID as ID; -pub use packet::{Packet, PacketData}; +pub use crate::connection::{ + handshake::Handshake, + packet::{ConnectionID as ID, Packet, PacketData}, +}; use crate::enc::{ hkdf::HkdfSha3, @@ -83,3 +85,62 @@ impl Connection { } } } + +// PERF: Arc> loks a bit too much, need to find +// faster ways to do this +pub(crate) struct ConnList { + connections: Vec>>, + /// Bitmap to track which connection ids are used or free + ids_used: Vec<::bitmaps::Bitmap<1024>>, +} + +impl ConnList { + pub(crate) fn new() -> Self { + let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + bitmap_id.set(0, true); // ID(0) == handshake + Self { + connections: Vec::with_capacity(128), + ids_used: vec![bitmap_id], + } + } + pub(crate) fn reserve_first( + &mut self, + mut conn: Connection, + ) -> Arc { + // uhm... bad things are going on here: + // * id must be initialized, but only because: + // * rust does not understand that after the `!found` id is always + // initialized + // * `ID::new_u64` is really safe only with >0, but here it always is + // ...we should probably rewrite it in better, safer rust + let mut id: u64 = 0; + let mut found = false; + for (i, b) in self.ids_used.iter_mut().enumerate() { + match b.first_false_index() { + Some(idx) => { + b.set(idx, true); + id = ((i as u64) * 1024) + (idx as u64); + found = true; + break; + } + None => {} + } + } + if !found { + let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); + new_bitmap.set(0, true); + id = (self.ids_used.len() as u64) * 1024; + self.ids_used.push(new_bitmap); + } + let new_id = ID::new_u64(id); + conn.id = new_id; + let conn = Arc::new(conn); + if (self.connections.len() as u64) < id { + self.connections.push(Some(conn.clone())); + } else { + // very probably redundant + self.connections[id as usize] = Some(conn.clone()); + } + conn + } +} diff --git a/src/connection/socket.rs b/src/connection/socket.rs new file mode 100644 index 0000000..455c567 --- /dev/null +++ b/src/connection/socket.rs @@ -0,0 +1,108 @@ +//! Socket related types and functions + +use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; +use ::std::{ + net::SocketAddr, + sync::Arc, + vec::{self, Vec}, +}; +use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; + +// PERF: move to arcswap: +// We use multiple UDP sockets. +// We need a list of them all and to scan this list. +// But we want to be able to update this list without locking everyone +// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and +// `from_raw_fd` to update the list. +// This means that we will have multiple `UdpSocket` per actual socket +// so we have to handle `drop()` manually, and garbage-collect the ones we +// are no longer using in the background. sigh. +// Just go with a ArcSwapAny, Arc>>); + +/// async free socket list +pub(crate) struct SocketList { + pub list: ArcSwap>, +} +impl SocketList { + pub(crate) fn new() -> Self { + Self { + list: ArcSwap::new(Arc::new(Vec::new())), + } + } + // TODO: fn rm_socket() + pub(crate) fn rm_all(&self) -> Self { + let new_list = Arc::new(Vec::new()); + let old_list = self.list.swap(new_list); + Self { + list: old_list.into(), + } + } + pub(crate) async fn add_socket( + &self, + socket: Arc, + handle: JoinHandle<::std::io::Result<()>>, + ) { + // we could simplify this into just a `.swap` instead of `.rcu` but + // it is not yet guaranteed that only one thread will call this fn + // ...we don't need performance here anyway + let arc_handle = Arc::new(handle); + self.list.rcu(|old_list| { + let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); + new_list = old_list.to_vec().into(); + Arc::get_mut(&mut new_list) + .unwrap() + .push((socket.clone(), arc_handle.clone())); + new_list + }); + } + /// This method assumes no other `add_sockets` are being run + pub(crate) async fn stop_all(mut self) { + let mut arc_list = self.list.into_inner(); + let list = loop { + match Arc::try_unwrap(arc_list) { + Ok(list) => break list, + Err(arc_retry) => { + arc_list = arc_retry; + ::tokio::time::sleep(::core::time::Duration::from_millis( + 50, + )) + .await; + } + } + }; + for (_socket, mut handle) in list.into_iter() { + Arc::get_mut(&mut handle).unwrap().await; + } + } + pub(crate) fn lock(&self) -> SocketListRef { + SocketListRef { + list: self.list.load_full(), + } + } +} + +/// Reference to a locked SocketList +// TODO: impl Drop for SocketList +pub(crate) struct SocketListRef { + list: Arc>, +} +impl SocketListRef { + pub(crate) fn find(&self, sock: UdpServer) -> Option> { + match self + .list + .iter() + .find(|&(s, _)| s.local_addr().unwrap() == sock.0) + { + Some((sock_srv, _)) => Some(sock_srv.clone()), + None => None, + } + } +} + +/// Strong typedef for a client socket address +#[derive(Debug, Copy, Clone)] +pub(crate) struct UdpClient(pub SocketAddr); +/// Strong typedef for a server socket address +#[derive(Debug, Copy, Clone)] +pub(crate) struct UdpServer(pub SocketAddr); diff --git a/src/inner/mod.rs b/src/inner/mod.rs new file mode 100644 index 0000000..22cafc8 --- /dev/null +++ b/src/inner/mod.rs @@ -0,0 +1,213 @@ +//! Inner Fenrir tracking +//! This is meant to be **async-free** so that others might use it +//! without the tokio runtime + +use crate::{ + connection::{ + self, + handshake::{self, Handshake, HandshakeClient, HandshakeServer}, + Connection, + }, + enc::{ + self, asym, + hkdf::HkdfSha3, + sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, + }, + Error, +}; +use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; +use ::std::{sync::Arc, vec::Vec}; + +/// Information needed to reply after the key exchange +#[derive(Debug, Clone)] +pub struct AuthNeededInfo { + /// Parsed handshake + pub handshake: Handshake, + /// hkdf generated from the handshake + pub hkdf: HkdfSha3, + /// cipher to be used in both directions + pub cipher: CipherKind, +} + +/// Client information needed to fully establish the conenction +#[derive(Debug)] +pub struct ClientConnectInfo { + /// Parsed handshake + pub handshake: Handshake, + /// hkdf generated from the handshake + pub hkdf: HkdfSha3, + /// cipher to be used in both directions + pub cipher_recv: CipherRecv, +} +/// Intermediate actions to be taken while parsing the handshake +#[derive(Debug)] +pub enum HandshakeAction { + /// Parsing finished, all ok, nothing to do + None, + /// Packet parsed, now go perform authentication + AuthNeeded(AuthNeededInfo), + /// the client can fully establish a connection with this info + ClientConnect(ClientConnectInfo), +} + +/// Async free but thread safe tracking of handhsakes and conenctions +pub struct Tracker { + key_exchanges: ArcSwapAny>>, + ciphers: ArcSwapAny>>, + /// ephemeral keys used server side in key exchange + keys_srv: ArcSwapAny>>, + /// ephemeral keys used client side in key exchange + hshake_cli: ArcSwapAny>>, +} +#[allow(unsafe_code)] +unsafe impl Send for Tracker {} +#[allow(unsafe_code)] +unsafe impl Sync for Tracker {} + +impl Tracker { + pub fn new() -> Self { + Self { + ciphers: ArcSwapAny::new(Arc::new(Vec::new())), + key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), + keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), + hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), + } + } + pub(crate) fn recv_handshake( + &self, + mut handshake: Handshake, + handshake_raw: &mut [u8], + ) -> Result { + use connection::handshake::{ + dirsync::{self, DirSync}, + HandshakeData, + }; + match handshake.data { + HandshakeData::DirSync(ref mut ds) => match ds { + DirSync::Req(ref mut req) => { + let ephemeral_key = { + // Keep this block short to avoid contention + // on self.keys_srv + let keys = self.keys_srv.load(); + if let Some(h_k) = + keys.iter().find(|k| k.id == req.key_id) + { + use enc::asym::PrivKey; + // Directory synchronized can only use keys + // for key exchange, not signing keys + if let PrivKey::Exchange(k) = &h_k.key { + Some(k.clone()) + } else { + None + } + } else { + None + } + }; + if ephemeral_key.is_none() { + ::tracing::debug!( + "No such server key id: {:?}", + req.key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + let ephemeral_key = ephemeral_key.unwrap(); + { + let exchanges = self.key_exchanges.load(); + if None + == exchanges.iter().find(|&x| { + *x == (ephemeral_key.kind(), req.exchange) + }) + { + return Err( + enc::Error::UnsupportedKeyExchange.into() + ); + } + } + { + let ciphers = self.ciphers.load(); + if None == ciphers.iter().find(|&x| *x == req.cipher) { + return Err(enc::Error::UnsupportedCipher.into()); + } + } + let shared_key = match ephemeral_key + .key_exchange(req.exchange, req.exchange_key) + { + Ok(shared_key) => shared_key, + Err(e) => return Err(handshake::Error::Key(e).into()), + }; + let hkdf = HkdfSha3::new(b"fenrir", shared_key); + let secret_recv = hkdf.get_secret(b"to_server"); + let cipher_recv = CipherRecv::new(req.cipher, secret_recv); + use crate::enc::sym::AAD; + let aad = AAD(&mut []); // no aad for now + match cipher_recv.decrypt( + aad, + &mut handshake_raw[req.encrypted_offset()..], + ) { + Ok(cleartext) => { + req.data.deserialize_as_cleartext(cleartext) + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + + let cipher = req.cipher; + + return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { + handshake, + hkdf, + cipher, + })); + } + DirSync::Resp(resp) => { + let hshake = { + // Keep this block short to avoid contention + // on self.hshake_cli + let hshake_cli_lock = self.hshake_cli.load(); + match hshake_cli_lock + .iter() + .find(|h| h.id == resp.client_key_id) + { + Some(h) => Some(h.clone()), + None => None, + } + }; + if hshake.is_none() { + ::tracing::debug!( + "No such client key id: {:?}", + resp.client_key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + let hshake = hshake.unwrap(); + let secret_recv = hshake.hkdf.get_secret(b"to_client"); + let cipher_recv = + CipherRecv::new(hshake.cipher, secret_recv); + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + let mut raw_data = &mut handshake_raw[resp + .encrypted_offset() + ..(resp.encrypted_offset() + resp.encrypted_length())]; + match cipher_recv.decrypt(aad, &mut raw_data) { + Ok(cleartext) => { + resp.data.deserialize_as_cleartext(&cleartext) + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + return Ok(HandshakeAction::ClientConnect( + ClientConnectInfo { + handshake, + hkdf: hshake.hkdf, + cipher_recv, + }, + )); + } + }, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 0c75750..6972d8b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,23 +18,33 @@ mod config; pub mod connection; pub mod dnssec; pub mod enc; +mod inner; use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{net::SocketAddr, pin::Pin, sync::Arc, vec, vec::Vec}; +use ::std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + vec::{self, Vec}, +}; use ::tokio::{ macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle, }; -use crate::enc::{ - asym, - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, +use crate::{ + connection::{ + handshake::{self, Handshake, HandshakeClient, HandshakeServer}, + socket::{SocketList, UdpClient, UdpServer}, + ConnList, Connection, + }, + enc::{ + asym, + hkdf::HkdfSha3, + sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, + }, + inner::HandshakeAction, }; pub use config::Config; -use connection::{ - handshake::{self, Handshake, HandshakeClient, HandshakeServer}, - Connection, -}; /// Main fenrir library errors #[derive(::thiserror::Error, Debug)] @@ -56,194 +66,6 @@ pub enum Error { Key(#[from] crate::enc::Error), } -/// Information needed to reply after the key exchange -#[derive(Debug, Clone)] -pub struct AuthNeededInfo { - /// Parsed handshake - pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: HkdfSha3, - /// cipher to be used in both directions - pub cipher: CipherKind, -} - -/// Client information needed to fully establish the conenction -#[derive(Debug)] -pub struct ClientConnectInfo { - /// Parsed handshake - pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: HkdfSha3, - /// cipher to be used in both directions - pub cipher_recv: CipherRecv, -} - -/// Intermediate actions to be taken while parsing the handshake -#[derive(Debug)] -pub enum HandshakeAction { - /// Parsing finished, all ok, nothing to do - None, - /// Packet parsed, now go perform authentication - AuthNeeded(AuthNeededInfo), - /// the client can fully establish a connection with this info - ClientConnect(ClientConnectInfo), -} -// No async here -struct FenrirInner { - key_exchanges: ArcSwapAny>>, - ciphers: ArcSwapAny>>, - /// ephemeral keys used server side in key exchange - keys_srv: ArcSwapAny>>, - /// ephemeral keys used client side in key exchange - hshake_cli: ArcSwapAny>>, -} - -#[allow(unsafe_code)] -unsafe impl Send for FenrirInner {} -#[allow(unsafe_code)] -unsafe impl Sync for FenrirInner {} - -// No async here -impl FenrirInner { - fn recv_handshake( - &self, - mut handshake: Handshake, - handshake_raw: &mut [u8], - ) -> Result { - use connection::handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; - match handshake.data { - HandshakeData::DirSync(ref mut ds) => match ds { - DirSync::Req(ref mut req) => { - let ephemeral_key = { - // Keep this block short to avoid contention - // on self.keys_srv - let keys = self.keys_srv.load(); - if let Some(h_k) = - keys.iter().find(|k| k.id == req.key_id) - { - use enc::asym::PrivKey; - // Directory synchronized can only use keys - // for key exchange, not signing keys - if let PrivKey::Exchange(k) = &h_k.key { - Some(k.clone()) - } else { - None - } - } else { - None - } - }; - if ephemeral_key.is_none() { - ::tracing::debug!( - "No such server key id: {:?}", - req.key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - let ephemeral_key = ephemeral_key.unwrap(); - { - let exchanges = self.key_exchanges.load(); - if None - == exchanges.iter().find(|&x| { - *x == (ephemeral_key.kind(), req.exchange) - }) - { - return Err( - enc::Error::UnsupportedKeyExchange.into() - ); - } - } - { - let ciphers = self.ciphers.load(); - if None == ciphers.iter().find(|&x| *x == req.cipher) { - return Err(enc::Error::UnsupportedCipher.into()); - } - } - let shared_key = match ephemeral_key - .key_exchange(req.exchange, req.exchange_key) - { - Ok(shared_key) => shared_key, - Err(e) => return Err(handshake::Error::Key(e).into()), - }; - let hkdf = HkdfSha3::new(b"fenrir", shared_key); - let secret_recv = hkdf.get_secret(b"to_server"); - let cipher_recv = CipherRecv::new(req.cipher, secret_recv); - use crate::enc::sym::AAD; - let aad = AAD(&mut []); // no aad for now - match cipher_recv.decrypt( - aad, - &mut handshake_raw[req.encrypted_offset()..], - ) { - Ok(cleartext) => { - req.data.deserialize_as_cleartext(cleartext) - } - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - - let cipher = req.cipher; - - return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { - handshake, - hkdf, - cipher, - })); - } - DirSync::Resp(resp) => { - let hshake = { - // Keep this block short to avoid contention - // on self.hshake_cli - let hshake_cli_lock = self.hshake_cli.load(); - match hshake_cli_lock - .iter() - .find(|h| h.id == resp.client_key_id) - { - Some(h) => Some(h.clone()), - None => None, - } - }; - if hshake.is_none() { - ::tracing::debug!( - "No such client key id: {:?}", - resp.client_key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - let hshake = hshake.unwrap(); - let secret_recv = hshake.hkdf.get_secret(b"to_client"); - let cipher_recv = - CipherRecv::new(hshake.cipher, secret_recv); - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - let mut raw_data = &mut handshake_raw[resp - .encrypted_offset() - ..(resp.encrypted_offset() + resp.encrypted_length())]; - match cipher_recv.decrypt(aad, &mut raw_data) { - Ok(cleartext) => { - resp.data.deserialize_as_cleartext(&cleartext) - } - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - return Ok(HandshakeAction::ClientConnect( - ClientConnectInfo { - handshake, - hkdf: hshake.hkdf, - cipher_recv, - }, - )); - } - }, - } - } -} - type TokenChecker = fn( user: auth::UserID, @@ -252,99 +74,7 @@ type TokenChecker = domain: auth::Domain, ) -> ::futures::future::BoxFuture<'static, Result>; -// PERF: move to arcswap: -// We use multiple UDP sockets. -// We need a list of them all and to scan this list. -// But we want to be able to update this list without locking everyone -// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and -// `from_raw_fd` to update the list. -// This means that we will have multiple `UdpSocket` per actual socket -// so we have to handle `drop()` manually, and garbage-collect the ones we -// are no longer using in the background. sigh. -// Just go with a ArcSwapAny, Arc>>); -struct SocketList { - list: ArcSwap>, -} -impl SocketList { - fn new() -> Self { - Self { - list: ArcSwap::new(Arc::new(Vec::new())), - } - } - // TODO: fn rm_socket() - fn rm_all(&self) -> Self { - let new_list = Arc::new(Vec::new()); - let old_list = self.list.swap(new_list); - Self { - list: old_list.into(), - } - } - async fn add_socket( - &self, - socket: Arc, - handle: JoinHandle<::std::io::Result<()>>, - ) { - // we could simplify this into just a `.swap` instead of `.rcu` but - // it is not yet guaranteed that only one thread will call this fn - // ...we don't need performance here anyway - let arc_handle = Arc::new(handle); - self.list.rcu(|old_list| { - let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); - new_list = old_list.to_vec().into(); - Arc::get_mut(&mut new_list) - .unwrap() - .push((socket.clone(), arc_handle.clone())); - new_list - }); - } - /// This method assumes no other `add_sockets` are being run - async fn stop_all(mut self) { - let mut arc_list = self.list.into_inner(); - let list = loop { - match Arc::try_unwrap(arc_list) { - Ok(list) => break list, - Err(arc_retry) => { - arc_list = arc_retry; - ::tokio::time::sleep(::core::time::Duration::from_millis( - 50, - )) - .await; - } - } - }; - for (_socket, mut handle) in list.into_iter() { - Arc::get_mut(&mut handle).unwrap().await; - } - } - fn lock(&self) -> SocketListRef { - SocketListRef { - list: self.list.load_full(), - } - } -} -// TODO: impl Drop for SocketList -struct SocketListRef { - list: Arc>, -} -impl SocketListRef { - fn find(&self, sock: UdpServer) -> Option> { - match self - .list - .iter() - .find(|&(s, _)| s.local_addr().unwrap() == sock.0) - { - Some((sock_srv, _)) => Some(sock_srv.clone()), - None => None, - } - } -} - -#[derive(Debug, Copy, Clone)] -struct UdpClient(SocketAddr); -#[derive(Debug, Copy, Clone)] -struct UdpServer(SocketAddr); - +/// Track a raw Udp packet struct RawUdp { data: Vec, src: UdpClient, @@ -355,62 +85,6 @@ enum Work { Recv(RawUdp), } -// PERF: Arc> loks a bit too much, need to find -// faster ways to do this -struct ConnList { - connections: Vec>>, - /// Bitmap to track which connection ids are used or free - ids_used: Vec<::bitmaps::Bitmap<1024>>, -} - -impl ConnList { - fn new() -> Self { - let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); - bitmap_id.set(0, true); // ID(0) == handshake - Self { - connections: Vec::with_capacity(128), - ids_used: vec![bitmap_id], - } - } - fn reserve_first(&mut self, mut conn: Connection) -> Arc { - // uhm... bad things are going on here: - // * id must be initialized, but only because: - // * rust does not understand that after the `!found` id is always - // initialized - // * `ID::new_u64` is really safe only with >0, but here it always is - // ...we should probably rewrite it in better, safer rust - let mut id: u64 = 0; - let mut found = false; - for (i, b) in self.ids_used.iter_mut().enumerate() { - match b.first_false_index() { - Some(idx) => { - b.set(idx, true); - id = ((i as u64) * 1024) + (idx as u64); - found = true; - break; - } - None => {} - } - } - if !found { - let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); - new_bitmap.set(0, true); - id = (self.ids_used.len() as u64) * 1024; - self.ids_used.push(new_bitmap); - } - let new_id = connection::ID::new_u64(id); - conn.id = new_id; - let conn = Arc::new(conn); - if (self.connections.len() as u64) < id { - self.connections.push(Some(conn.clone())); - } else { - // very probably redundant - self.connections[id as usize] = Some(conn.clone()); - } - conn - } -} - /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { @@ -423,7 +97,7 @@ pub struct Fenrir { /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, /// Private keys used in the handshake - _inner: Arc, + _inner: Arc, /// where to ask for token check token_check: Arc>, /// MPMC work queue. sender @@ -455,12 +129,7 @@ impl Fenrir { sockets: SocketList::new(), dnssec: None, stop_working: sender, - _inner: Arc::new(FenrirInner { - ciphers: ArcSwapAny::new(Arc::new(Vec::new())), - key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), - keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), - hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), - }), + _inner: Arc::new(inner::Tracker::new()), token_check: Arc::new(ArcSwapOption::from(None)), work_send: Arc::new(work_send), work_recv: Arc::new(work_recv),