#![deny( missing_docs, missing_debug_implementations, missing_copy_implementations, trivial_casts, trivial_numeric_casts, unsafe_code, unstable_features, unused_import_braces, unused_qualifications )] //! //! libFenrir is the official rust library implementing the Fenrir protocol pub mod auth; mod config; pub mod connection; pub mod dnssec; pub mod enc; use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::std::{net::SocketAddr, pin::Pin, sync::Arc, vec, vec::Vec}; use ::tokio::{ macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle, }; use crate::enc::{ asym, hkdf::HkdfSha3, sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, }; pub use config::Config; use connection::{ handshake::{self, Handshake, HandshakeClient, HandshakeServer}, Connection, }; /// Main fenrir library errors #[derive(::thiserror::Error, Debug)] pub enum Error { /// The library was not initialized (run .start()) #[error("not initialized")] NotInitialized, /// General I/O error #[error("IO: {0:?}")] IO(#[from] ::std::io::Error), /// Dnssec errors #[error("Dnssec: {0:?}")] Dnssec(#[from] dnssec::Error), /// Handshake errors #[error("Handshake: {0:?}")] Handshake(#[from] handshake::Error), /// Key error #[error("key: {0:?}")] Key(#[from] crate::enc::Error), } /// Information needed to reply after the key exchange #[derive(Debug, Clone)] pub struct AuthNeededInfo { /// Parsed handshake pub handshake: Handshake, /// hkdf generated from the handshake pub hkdf: HkdfSha3, /// cipher to be used in both directions pub cipher: CipherKind, } /// Client information needed to fully establish the conenction #[derive(Debug)] pub struct ClientConnectInfo { /// Parsed handshake pub handshake: Handshake, /// hkdf generated from the handshake pub hkdf: HkdfSha3, /// cipher to be used in both directions pub cipher_recv: CipherRecv, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] pub enum HandshakeAction { /// Parsing finished, all ok, nothing to do None, /// Packet parsed, now go perform authentication AuthNeeded(AuthNeededInfo), /// the client can fully establish a connection with this info ClientConnect(ClientConnectInfo), } // No async here struct FenrirInner { key_exchanges: ArcSwapAny>>, ciphers: ArcSwapAny>>, /// ephemeral keys used server side in key exchange keys_srv: ArcSwapAny>>, /// ephemeral keys used client side in key exchange hshake_cli: ArcSwapAny>>, } #[allow(unsafe_code)] unsafe impl Send for FenrirInner {} #[allow(unsafe_code)] unsafe impl Sync for FenrirInner {} // No async here impl FenrirInner { fn recv_handshake( &self, mut handshake: Handshake, handshake_raw: &mut [u8], ) -> Result { use connection::handshake::{ dirsync::{self, DirSync}, HandshakeData, }; match handshake.data { HandshakeData::DirSync(ref mut ds) => match ds { DirSync::Req(ref mut req) => { let ephemeral_key = { // Keep this block short to avoid contention // on self.keys_srv let keys = self.keys_srv.load(); if let Some(h_k) = keys.iter().find(|k| k.id == req.key_id) { use enc::asym::PrivKey; // Directory synchronized can only use keys // for key exchange, not signing keys if let PrivKey::Exchange(k) = &h_k.key { Some(k.clone()) } else { None } } else { None } }; if ephemeral_key.is_none() { ::tracing::debug!( "No such server key id: {:?}", req.key_id ); return Err(handshake::Error::UnknownKeyID.into()); } let ephemeral_key = ephemeral_key.unwrap(); { let exchanges = self.key_exchanges.load(); if None == exchanges.iter().find(|&x| { *x == (ephemeral_key.kind(), req.exchange) }) { return Err( enc::Error::UnsupportedKeyExchange.into() ); } } { let ciphers = self.ciphers.load(); if None == ciphers.iter().find(|&x| *x == req.cipher) { return Err(enc::Error::UnsupportedCipher.into()); } } let shared_key = match ephemeral_key .key_exchange(req.exchange, req.exchange_key) { Ok(shared_key) => shared_key, Err(e) => return Err(handshake::Error::Key(e).into()), }; let hkdf = HkdfSha3::new(b"fenrir", shared_key); let secret_recv = hkdf.get_secret(b"to_server"); let cipher_recv = CipherRecv::new(req.cipher, secret_recv); use crate::enc::sym::AAD; let aad = AAD(&mut []); // no aad for now match cipher_recv.decrypt( aad, &mut handshake_raw[req.encrypted_offset()..], ) { Ok(cleartext) => { req.data.deserialize_as_cleartext(cleartext) } Err(e) => { return Err(handshake::Error::Key(e).into()); } } let cipher = req.cipher; return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { handshake, hkdf, cipher, })); } DirSync::Resp(resp) => { let hshake = { // Keep this block short to avoid contention // on self.hshake_cli let hshake_cli_lock = self.hshake_cli.load(); match hshake_cli_lock .iter() .find(|h| h.id == resp.client_key_id) { Some(h) => Some(h.clone()), None => None, } }; if hshake.is_none() { ::tracing::debug!( "No such client key id: {:?}", resp.client_key_id ); return Err(handshake::Error::UnknownKeyID.into()); } let hshake = hshake.unwrap(); let secret_recv = hshake.hkdf.get_secret(b"to_client"); let cipher_recv = CipherRecv::new(hshake.cipher, secret_recv); use crate::enc::sym::AAD; // no aad for now let aad = AAD(&mut []); let mut raw_data = &mut handshake_raw[resp .encrypted_offset() ..(resp.encrypted_offset() + resp.encrypted_length())]; match cipher_recv.decrypt(aad, &mut raw_data) { Ok(cleartext) => { resp.data.deserialize_as_cleartext(&cleartext) } Err(e) => { return Err(handshake::Error::Key(e).into()); } } return Ok(HandshakeAction::ClientConnect( ClientConnectInfo { handshake, hkdf: hshake.hkdf, cipher_recv, }, )); } }, } } } type TokenChecker = fn( user: auth::UserID, token: auth::Token, service_id: auth::ServiceID, domain: auth::Domain, ) -> ::futures::future::BoxFuture<'static, Result>; // PERF: move to arcswap: // We use multiple UDP sockets. // We need a list of them all and to scan this list. // But we want to be able to update this list without locking everyone // So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and // `from_raw_fd` to update the list. // This means that we will have multiple `UdpSocket` per actual socket // so we have to handle `drop()` manually, and garbage-collect the ones we // are no longer using in the background. sigh. // Just go with a ArcSwapAny, Arc>>); struct SocketList { list: ArcSwap>, } impl SocketList { fn new() -> Self { Self { list: ArcSwap::new(Arc::new(Vec::new())), } } // TODO: fn rm_socket() fn rm_all(&self) -> Self { let new_list = Arc::new(Vec::new()); let old_list = self.list.swap(new_list); Self { list: old_list.into(), } } async fn add_socket( &self, socket: Arc, handle: JoinHandle<::std::io::Result<()>>, ) { // we could simplify this into just a `.swap` instead of `.rcu` but // it is not yet guaranteed that only one thread will call this fn // ...we don't need performance here anyway let arc_handle = Arc::new(handle); self.list.rcu(|old_list| { let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); new_list = old_list.to_vec().into(); Arc::get_mut(&mut new_list) .unwrap() .push((socket.clone(), arc_handle.clone())); new_list }); } /// This method assumes no other `add_sockets` are being run async fn stop_all(mut self) { let mut arc_list = self.list.into_inner(); let list = loop { match Arc::try_unwrap(arc_list) { Ok(list) => break list, Err(arc_retry) => { arc_list = arc_retry; ::tokio::time::sleep(::core::time::Duration::from_millis( 50, )) .await; } } }; for (_socket, mut handle) in list.into_iter() { Arc::get_mut(&mut handle).unwrap().await; } } fn lock(&self) -> SocketListRef { SocketListRef { list: self.list.load_full(), } } } // TODO: impl Drop for SocketList struct SocketListRef { list: Arc>, } impl SocketListRef { fn find(&self, sock: UdpServer) -> Option> { match self .list .iter() .find(|&(s, _)| s.local_addr().unwrap() == sock.0) { Some((sock_srv, _)) => Some(sock_srv.clone()), None => None, } } } #[derive(Debug, Copy, Clone)] struct UdpClient(SocketAddr); #[derive(Debug, Copy, Clone)] struct UdpServer(SocketAddr); struct RawUdp { data: Vec, src: UdpClient, dst: UdpServer, } enum Work { Recv(RawUdp), } // PERF: Arc> loks a bit too much, need to find // faster ways to do this struct ConnList { connections: Vec>>, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } impl ConnList { fn new() -> Self { let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); bitmap_id.set(0, true); // ID(0) == handshake Self { connections: Vec::with_capacity(128), ids_used: vec![bitmap_id], } } fn reserve_first(&mut self, mut conn: Connection) -> Arc { // uhm... bad things are going on here: // * id must be initialized, but only because: // * rust does not understand that after the `!found` id is always // initialized // * `ID::new_u64` is really safe only with >0, but here it always is // ...we should probably rewrite it in better, safer rust let mut id: u64 = 0; let mut found = false; for (i, b) in self.ids_used.iter_mut().enumerate() { match b.first_false_index() { Some(idx) => { b.set(idx, true); id = ((i as u64) * 1024) + (idx as u64); found = true; break; } None => {} } } if !found { let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); new_bitmap.set(0, true); id = (self.ids_used.len() as u64) * 1024; self.ids_used.push(new_bitmap); } let new_id = connection::ID::new_u64(id); conn.id = new_id; let conn = Arc::new(conn); if (self.connections.len() as u64) < id { self.connections.push(Some(conn.clone())); } else { // very probably redundant self.connections[id as usize] = Some(conn.clone()); } conn } } /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { /// library Configuration cfg: Config, /// listening udp sockets sockets: SocketList, /// DNSSEC resolver, with failovers dnssec: Option, /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, /// Private keys used in the handshake _inner: Arc, /// where to ask for token check token_check: Arc>, /// MPMC work queue. sender work_send: Arc<::async_channel::Sender>, /// MPMC work queue. receiver work_recv: Arc<::async_channel::Receiver>, // PERF: rand uses syscalls. should we do that async? rand: ::ring::rand::SystemRandom, /// list of Established connections connections: Arc>, } // TODO: graceful vs immediate stop impl Drop for Fenrir { fn drop(&mut self) { self.stop_sync() } } impl Fenrir { /// Create a new Fenrir endpoint pub fn new(config: &Config) -> Result { let listen_num = config.listen.len(); let (sender, _) = ::tokio::sync::broadcast::channel(1); let (work_send, work_recv) = ::async_channel::unbounded::(); let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), dnssec: None, stop_working: sender, _inner: Arc::new(FenrirInner { ciphers: ArcSwapAny::new(Arc::new(Vec::new())), key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), }), token_check: Arc::new(ArcSwapOption::from(None)), work_send: Arc::new(work_send), work_recv: Arc::new(work_recv), rand: ::ring::rand::SystemRandom::new(), connections: Arc::new(RwLock::new(ConnList::new())), }; Ok(endpoint) } /// Start all workers, listeners pub async fn start(&mut self) -> Result<(), Error> { if let Err(e) = self.add_sockets().await { self.stop().await; return Err(e.into()); } self.dnssec = Some(dnssec::Dnssec::new(&self.cfg.resolvers).await?); Ok(()) } /// Stop all workers, listeners /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); let mut toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); self.dnssec = None; } /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); let mut toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; self.dnssec = None; } /// Enable some common socket options. This is just the unsafe part fn enable_sock_opt( fd: ::std::os::fd::RawFd, option: ::libc::c_int, value: ::libc::c_int, ) -> ::std::io::Result<()> { #[allow(unsafe_code)] unsafe { #[allow(trivial_casts)] let val = &value as *const _ as *const ::libc::c_void; let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; // always clear the error bit before doing a new syscall let _ = ::std::io::Error::last_os_error(); let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); if ret != 0 { return Err(::std::io::Error::last_os_error()); } } Ok(()) } /// Add all UDP sockets found in config /// and start listening for packets async fn add_sockets(&self) -> ::std::io::Result<()> { let sockets = self.cfg.listen.iter().map(|s_addr| async { let socket = ::tokio::spawn(Self::bind_udp(s_addr.clone())).await??; Ok(socket) }); let sockets = ::futures::future::join_all(sockets).await; for s_res in sockets.into_iter() { match s_res { Ok(s) => { let stop_working = self.stop_working.subscribe(); let arc_s = Arc::new(s); let join = ::tokio::spawn(Self::listen_udp( stop_working, self.work_send.clone(), arc_s.clone(), )); self.sockets.add_socket(arc_s, join); } Err(e) => { return Err(e); } } } Ok(()) } /// Add an async udp listener async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { let socket = UdpSocket::bind(sock).await?; use ::std::os::fd::AsRawFd; let fd = socket.as_raw_fd(); // can be useful later on for reloads Self::enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; Self::enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; // We will do path MTU discovery by ourselves, // always set the "don't fragment" bit if sock.is_ipv6() { Self::enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; } else { // FIXME: linux only Self::enable_sock_opt( fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO, )?; } Ok(socket) } /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( mut stop_working: ::tokio::sync::broadcast::Receiver, work_queue: Arc<::async_channel::Sender>, socket: Arc, ) -> ::std::io::Result<()> { // jumbo frames are 9K max let sock_receiver = UdpServer(socket.local_addr()?); let mut buffer: [u8; 9000] = [0; 9000]; loop { let (bytes, sock_sender) = ::tokio::select! { _done = stop_working.recv() => { break; } result = socket.recv_from(&mut buffer) => { result? } }; let data: Vec = buffer[..bytes].to_vec(); work_queue.send(Work::Recv(RawUdp { data, src: UdpClient(sock_sender), dst: sock_receiver, })); } Ok(()) } /// Get the raw TXT record of a Fenrir domain pub async fn resolv_str(&self, domain: &str) -> Result { match &self.dnssec { Some(dnssec) => Ok(dnssec.resolv(domain).await?), None => Err(Error::NotInitialized), } } /// Get the raw TXT record of a Fenrir domain pub async fn resolv(&self, domain: &str) -> Result { let record_str = self.resolv_str(domain).await?; Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } /// Loop continuously and parse packets and other work pub async fn work_loop(&self) { let mut stop_working = self.stop_working.subscribe(); loop { let work = ::tokio::select! { _done = stop_working.recv() => { break; } maybe_work = self.work_recv.recv() => { match maybe_work { Ok(work) => work, Err(_) => break, } } }; match work { Work::Recv(pkt) => { self.recv(pkt).await; } } } } const MIN_PACKET_BYTES: usize = 8; /// Read and do stuff with the raw udp packet async fn recv(&self, mut udp: RawUdp) { if udp.data.len() < Self::MIN_PACKET_BYTES { return; } use connection::ID; let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable"); if ID::from(raw_id).is_handshake() { use connection::handshake::Handshake; let handshake = match Handshake::deserialize(&udp.data[8..]) { Ok(handshake) => handshake, Err(e) => { ::tracing::warn!("Handshake parsing: {}", e); return; } }; let action = match self._inner.recv_handshake(handshake, &mut udp.data[8..]) { Ok(action) => action, Err(err) => { ::tracing::debug!("Handshake recv error {}", err); return; } }; match action { HandshakeAction::AuthNeeded(authinfo) => { let tk_check = match self.token_check.load_full() { Some(tk_check) => tk_check, None => { ::tracing::error!( "Handshake received, but no tocken_checker" ); return; } }; use handshake::{ dirsync::{self, DirSync}, HandshakeData, }; match authinfo.handshake.data { HandshakeData::DirSync(ds) => match ds { DirSync::Req(req) => { use dirsync::ReqInner; let req_data = match req.data { ReqInner::ClearText(req_data) => req_data, _ => { ::tracing::error!( "token_check: expected ClearText" ); return; } }; let is_authenticated = match tk_check( req_data.auth.user, req_data.auth.token, req_data.auth.service_id, req_data.auth.domain, ) .await { Ok(is_authenticated) => is_authenticated, Err(_) => { ::tracing::error!( "error in token auth" ); // TODO: retry? return; } }; if !is_authenticated { ::tracing::warn!( "Wrong authentication for user {:?}", req_data.auth.user ); // TODO: error response return; } // Client has correctly authenticated // TODO: contact the service, get the key and // connection ID let srv_conn_id = connection::ID::new_rand(&self.rand); let srv_secret = enc::sym::Secret::new_rand(&self.rand); let head_len = req.cipher.nonce_len(); let tag_len = req.cipher.tag_len(); let raw_conn = Connection::new( authinfo.hkdf, req.cipher, connection::Role::Server, &self.rand, ); // track connection let auth_conn = { let mut lock = self.connections.write().await; lock.reserve_first(raw_conn) }; let resp_data = dirsync::RespData { client_nonce: req_data.nonce, id: auth_conn.id, service_id: srv_conn_id, service_key: srv_secret, }; use crate::enc::sym::AAD; // no aad for now let aad = AAD(&mut []); use dirsync::RespInner; let resp = dirsync::Resp { client_key_id: req_data.client_key_id, data: RespInner::ClearText(resp_data), }; let offset_to_encrypt = resp.encrypted_offset(); let encrypt_until = offset_to_encrypt + resp.encrypted_length() + tag_len.0; let resp_handshake = Handshake::new( HandshakeData::DirSync(DirSync::Resp(resp)), ); use connection::{Packet, PacketData, ID}; let packet = Packet { id: ID::new_handshake(), data: PacketData::Handshake(resp_handshake), }; let mut raw_out = Vec::::with_capacity(packet.len()); packet.serialize( head_len, tag_len, &mut raw_out, ); if let Err(e) = auth_conn.cipher_send.encrypt( aad, &mut raw_out [offset_to_encrypt..encrypt_until], ) { ::tracing::error!("can't encrypt: {:?}", e); return; } self.send_packet(raw_out, udp.src, udp.dst) .await; } DirSync::Resp(resp) => { todo!() } _ => { todo!() } }, } } _ => {} }; } // copy packet, spawn todo!(); } async fn send_packet( &self, data: Vec, client: UdpClient, server: UdpServer, ) { let src_sock; { let sockets = self.sockets.lock(); src_sock = match sockets.find(server) { Some(src_sock) => src_sock, None => { ::tracing::error!( "Can't send packet: Server changed listening ip!" ); return; } }; } src_sock.send_to(&data, client.0); } }