#![deny(
    missing_docs,
    missing_debug_implementations,
    missing_copy_implementations,
    trivial_casts,
    trivial_numeric_casts,
    unsafe_code,
    unstable_features,
    unused_import_braces,
    unused_qualifications
)]

//!
//! libFenrir is the official Rust library implementing the Fenrir protocol

pub mod auth;
mod config;
pub mod connection;
pub mod dnssec;
pub mod enc;
mod inner;
#[cfg(test)]
mod tests;

use ::std::{sync::Arc, vec::Vec};
use ::tokio::{net::UdpSocket, sync::Mutex};

use crate::{
    auth::{Domain, ServiceID, TokenChecker},
    connection::{
        handshake,
        socket::{SocketTracker, UdpClient, UdpServer},
        AuthServerConnections, Packet,
    },
    inner::{
        worker::{ConnectInfo, RawUdp, Work, Worker},
        ThreadTracker,
    },
};
pub use config::Config;
pub use connection::Connection;

/// Main fenrir library errors
#[derive(::thiserror::Error, Debug)]
pub enum Error {
    /// The library was not initialized (run .start())
    #[error("not initialized")]
    NotInitialized,
    /// Error in setting up worker threads
    #[error("Setup err: {0}")]
    Setup(String),
    /// General I/O error
    #[error("IO: {0:?}")]
    IO(#[from] ::std::io::Error),
    /// Dnssec errors
    #[error("Dnssec: {0:?}")]
    Dnssec(#[from] dnssec::Error),
    /// Handshake errors
    #[error("Handshake: {0:?}")]
    Handshake(#[from] handshake::Error),
    /// Key error
    #[error("key: {0:?}")]
    Key(#[from] crate::enc::Error),
    /// Resolution problems: wrong or incomplete DNSSEC data
    #[error("DNSSEC resolution: {0}")]
    Resolution(String),
    /// Wrapper on encryption errors
    #[error("Encrypt: {0}")]
    Encrypt(enc::Error),
}

pub(crate) enum StopWorking {
    WorkerStopped,
    ListenerStopped,
}

pub(crate) type StopWorkingSendCh =
    ::tokio::sync::broadcast::Sender<::tokio::sync::mpsc::Sender<StopWorking>>;
pub(crate) type StopWorkingRecvCh = ::tokio::sync::broadcast::Receiver<
    ::tokio::sync::mpsc::Sender<StopWorking>,
>;

/// Instance of a fenrir endpoint
#[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir {
    /// library Configuration
    cfg: Config,
    /// listening udp sockets
    sockets: Vec<SocketTracker>,
    /// DNSSEC resolver, with failovers
    dnssec: dnssec::Dnssec,
    /// Broadcast channel to tell workers to stop working
    stop_working: StopWorkingSendCh,
    /// where to ask for token check
    token_check: Option<Arc<TokenChecker>>,
    /// tracks the connections to authentication servers
    /// so that we can reuse them
    conn_auth_srv: Mutex<AuthServerConnections>,
    // TODO: find a way to both increase and decrease these two
    // in a thread-safe manner
    _thread_pool: Vec<::std::thread::JoinHandle<()>>,
    _thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
}

// TODO: graceful vs immediate stop
impl Drop for Fenrir {
    fn drop(&mut self) {
        ::tracing::debug!(
            "Fenrir fast shutdown. \
             Some threads might remain a bit longer"
        );
        let _ = self.stop_sync();
    }
}

impl Fenrir {
    /// Gracefully stop all listeners and workers.
    /// Only returns once all resources have been deallocated
    pub async fn graceful_stop(mut self) {
        ::tracing::debug!("Fenrir full shut down");
        if let Some((ch, listeners, workers)) = self.stop_sync() {
            self.stop_wait(ch, listeners, workers).await;
        }
    }
    fn stop_sync(
        &mut self,
    ) -> Option<(
        ::tokio::sync::mpsc::Receiver<StopWorking>,
        Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
        usize,
    )> {
        let workers_num = self._thread_work.len();
        if self.sockets.len() > 0 || self._thread_work.len() > 0 {
            let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4);
            let _ = self.stop_working.send(ch_send);
            let mut old_listeners = Vec::with_capacity(self.sockets.len());
            ::core::mem::swap(&mut old_listeners, &mut self.sockets);
            self._thread_pool.clear();
            let listeners = old_listeners
                .into_iter()
                .map(|(_, joinable)| joinable)
                .collect();
            Some((ch_recv, listeners, workers_num))
        } else {
            None
        }
    }
    async fn stop_wait(
        &mut self,
        mut ch: ::tokio::sync::mpsc::Receiver<StopWorking>,
        listeners: Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
        mut workers_num: usize,
    ) {
        let mut listeners_num = listeners.len();
        while listeners_num > 0 && workers_num > 0 {
            match ch.recv().await {
                Some(stopped) => match stopped {
                    StopWorking::WorkerStopped => workers_num = workers_num - 1,
                    StopWorking::ListenerStopped => {
                        listeners_num = listeners_num - 1
                    }
                },
                _ => break,
            }
        }
        for l in listeners.into_iter() {
            if let Err(e) = l.await {
                ::tracing::error!("Unclean shutdown of listener: {:?}", e);
            }
        }
    }
    /// Create a new Fenrir endpoint.
    /// Spawns threads pinned to cpus in our own way, using tokio's runtime
    pub async fn with_threads(
        config: &Config,
        tokio_rt: Arc<::tokio::runtime::Runtime>,
    ) -> Result<Self, Error> {
        let (sender, _) = ::tokio::sync::broadcast::channel(1);
        let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
        // bind sockets early so we can change "port 0" (aka: random)
        // in the config
        let binded_sockets = Self::bind_sockets(&config).await?;
        let socket_addrs = binded_sockets
            .iter()
            .map(|s| s.local_addr().unwrap())
            .collect();
        let cfg = {
            let mut tmp = config.clone();
            tmp.listen = socket_addrs;
            tmp
        };
        let mut endpoint = Self {
            cfg,
            sockets: Vec::with_capacity(config.listen.len()),
            dnssec,
            stop_working: sender,
            token_check: None,
            conn_auth_srv: Mutex::new(AuthServerConnections::new()),
            _thread_pool: Vec::new(),
            _thread_work: Arc::new(Vec::new()),
        };
        endpoint
            .start_work_threads_pinned(tokio_rt, binded_sockets.clone())
            .await?;
        endpoint.run_listeners(binded_sockets).await?;
        Ok(endpoint)
    }
    /// Create a new Fenrir endpoint.
    /// Get the workers that you can use in a tokio LocalSet.
    /// You should:
    /// * move each of these workers into its own thread
    /// * make sure that the threads are pinned on the cpu
    ///
    /// A commented usage sketch follows this function.
    pub async fn with_workers(
        config: &Config,
    ) -> Result<(Self, Vec<Worker>), Error> {
        let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
        let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
        // bind sockets early so we can change "port 0" (aka: random)
        // in the config
        let binded_sockets = Self::bind_sockets(&config).await?;
        let socket_addrs = binded_sockets
            .iter()
            .map(|s| s.local_addr().unwrap())
            .collect();
        let cfg = {
            let mut tmp = config.clone();
            tmp.listen = socket_addrs;
            tmp
        };
        let mut endpoint = Self {
            cfg,
            sockets: Vec::with_capacity(config.listen.len()),
            dnssec,
            stop_working: stop_working.clone(),
            token_check: None,
            conn_auth_srv: Mutex::new(AuthServerConnections::new()),
            _thread_pool: Vec::new(),
            _thread_work: Arc::new(Vec::new()),
        };
        let worker_num = config.threads.unwrap().get();
        let mut workers = Vec::with_capacity(worker_num);
        for _ in 0..worker_num {
            workers.push(
                endpoint.start_single_worker(binded_sockets.clone()).await?,
            );
        }
        endpoint.run_listeners(binded_sockets).await?;
        Ok((endpoint, workers))
    }
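
    // A minimal, commented usage sketch for `with_workers` (not compiled):
    // `cfg` is assumed to be a `Config` already built by the caller, and the
    // per-thread runtime and cpu-pinning choices are assumptions, not part
    // of this library. The call to `work_loop` mirrors the internal
    // pinned-thread setup in `start_work_threads_pinned` further below.
    //
    // let (endpoint, workers) = Fenrir::with_workers(&cfg).await?;
    // for worker in workers.into_iter() {
    //     ::std::thread::spawn(move || {
    //         // pin this thread to a core here (e.g. via hwloc2)
    //         let rt = ::tokio::runtime::Builder::new_current_thread()
    //             .enable_all()
    //             .build()
    //             .unwrap();
    //         let local = ::tokio::task::LocalSet::new();
    //         local.block_on(&rt, async move {
    //             let mut worker = worker;
    //             worker.work_loop().await
    //         });
    //     });
    // }
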
    /// Returns the list of the actual addresses we are listening on.
    /// Note that this can be different from what was configured:
    /// if you specified UDP port 0, a random one has been assigned to you
    /// by the operating system.
    pub fn addresses(&self) -> Vec<::std::net::SocketAddr> {
        self.sockets.iter().map(|(s, _)| s.clone()).collect()
    }

    // only call **before** starting all threads
    /// bind all UDP sockets found in config
    async fn bind_sockets(cfg: &Config) -> Result<Vec<Arc<UdpSocket>>, Error> {
        // try to bind multiple sockets in parallel
        let mut sock_set = ::tokio::task::JoinSet::new();
        cfg.listen.iter().for_each(|s_addr| {
            let socket_address = s_addr.clone();
            sock_set.spawn(async move {
                connection::socket::bind_udp(socket_address).await
            });
        });
        // make sure we either return all of them, or none
        let mut all_socks = Vec::with_capacity(cfg.listen.len());
        while let Some(join_res) = sock_set.join_next().await {
            match join_res {
                Ok(s_res) => match s_res {
                    Ok(s) => {
                        all_socks.push(Arc::new(s));
                    }
                    Err(e) => {
                        return Err(e.into());
                    }
                },
                Err(e) => {
                    return Err(Error::Setup(e.to_string()));
                }
            }
        }
        assert!(all_socks.len() == cfg.listen.len(), "missing socks");
        Ok(all_socks)
    }

    // only call **after** starting all threads
    /// spawn all listeners
    async fn run_listeners(
        &mut self,
        socks: Vec<Arc<UdpSocket>>,
    ) -> Result<(), Error> {
        for sock in socks.into_iter() {
            let sockaddr = sock.local_addr().unwrap();
            let stop_working = self.stop_working.subscribe();
            let th_work = self._thread_work.clone();
            let joinable = ::tokio::spawn(async move {
                Self::listen_udp(stop_working, th_work, sock.clone()).await
            });
            self.sockets.push((sockaddr, joinable));
        }
        Ok(())
    }

    /// Run a dedicated loop to read packets on the listening socket
    async fn listen_udp(
        mut stop_working: StopWorkingRecvCh,
        work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
        socket: Arc<UdpSocket>,
    ) -> ::std::io::Result<()> {
        // jumbo frames are 9K max
        let sock_receiver = UdpServer(socket.local_addr()?);
        let mut buffer: [u8; 9000] = [0; 9000];
        let queues_num = work_queues.len() as u64;
        loop {
            let (bytes, sock_sender) = ::tokio::select! {
                tell_stopped = stop_working.recv() => {
                    drop(socket);
                    if let Ok(stop_ch) = tell_stopped {
                        let _ = stop_ch
                            .send(StopWorking::ListenerStopped).await;
                    }
                    return Ok(());
                }
                result = socket.recv_from(&mut buffer) => {
                    let (bytes, from) = result?;
                    (bytes, UdpClient(from))
                }
            };
            let data: Vec<u8> = buffer[..bytes].to_vec();

            // we very likely have multiple threads, pinned to different cpus.
            // use the connection::ID to send the same connection
            // to the same thread.
            // Handshakes have connection ID 0, so we use the sender's UDP port
            let packet = match Packet::deserialize_id(&data) {
                Ok(packet) => packet,
                Err(_) => continue, // packet way too short, ignore.
            };
            let thread_idx: usize = {
                match packet.id {
                    connection::ID::Handshake => {
                        let send_port = sock_sender.0.port() as u64;
                        (send_port % queues_num) as usize
                    }
                    connection::ID::ID(id) => (id.get() % queues_num) as usize,
                }
            };
            let _ = work_queues[thread_idx]
                .send(Work::Recv(RawUdp {
                    src: sock_sender,
                    dst: sock_receiver,
                    packet,
                    data,
                }))
                .await;
        }
    }

    /// Get the raw TXT record of a Fenrir domain
    pub async fn resolv_txt(&self, domain: &Domain) -> Result<String, Error> {
        match self.dnssec.resolv(domain).await {
            Ok(res) => Ok(res),
            Err(e) => Err(e.into()),
        }
    }

    /// Get the parsed TXT record of a Fenrir domain
    pub async fn resolv(
        &self,
        domain: &Domain,
    ) -> Result<dnssec::Record, Error> {
        let record_str = self.resolv_txt(domain).await?;
        Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
    }

    /// Connect to a service, doing the dnssec resolution ourselves
    pub async fn connect(
        &self,
        domain: &Domain,
        service: ServiceID,
    ) -> Result<(), Error> {
        let resolved = self.resolv(domain).await?;
        self.connect_resolved(resolved, domain, service).await
    }

    /// Connect to a service, with the user provided details
    pub async fn connect_resolved(
        &self,
        resolved: dnssec::Record,
        domain: &Domain,
        service: ServiceID,
    ) -> Result<(), Error> {
        loop {
            // check if we already have a connection to that auth. srv
            let is_reserved = {
                let mut conn_auth_lock = self.conn_auth_srv.lock().await;
                conn_auth_lock.get_or_reserve(&resolved)
            };
            use connection::Reservation;
            match is_reserved {
                Reservation::Waiting => {
                    use ::std::time::Duration;
                    use ::tokio::time::sleep;
                    // PERF: exponential backoff.
                    // or we can have a broadcast channel
                    sleep(Duration::from_millis(50)).await;
                    continue;
                }
                Reservation::Reserved => break,
                Reservation::Present(_id_send) => {
                    //TODO: reuse connection
                    todo!()
                }
            }
        }
        // Spot reserved for the connection

        // find the thread with the fewest connections
        let th_num = self._thread_work.len();
        let mut conn_count = Vec::<usize>::with_capacity(th_num);
        let mut wait_res =
            Vec::<::tokio::sync::oneshot::Receiver<usize>>::with_capacity(
                th_num,
            );
        for th in self._thread_work.iter() {
            let (send, recv) = ::tokio::sync::oneshot::channel();
            wait_res.push(recv);
            let _ = th.send(Work::CountConnections(send)).await;
        }
        for ch in wait_res.into_iter() {
            if let Ok(conn_num) = ch.await {
                conn_count.push(conn_num);
            }
        }
        if conn_count.len() != th_num {
            return Err(Error::IO(::std::io::Error::new(
                ::std::io::ErrorKind::NotConnected,
                "can't connect to a thread",
            )));
        }
        let thread_idx = conn_count
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.cmp(b))
            .map(|(index, _)| index)
            .unwrap();
        // and tell that thread to connect somewhere
        let (send, recv) = ::tokio::sync::oneshot::channel();
        let _ = self._thread_work[thread_idx]
            .send(Work::Connect(ConnectInfo {
                answer: send,
                resolved: resolved.clone(),
                service_id: service,
                domain: domain.clone(),
            }))
            .await;

        match recv.await {
            Ok(res) => {
                match res {
                    Err(e) => {
                        let mut conn_auth_lock =
                            self.conn_auth_srv.lock().await;
                        conn_auth_lock.remove_reserved(&resolved);
                        Err(e)
                    }
                    Ok((key_id, id_send)) => {
                        let key = resolved
                            .public_keys
                            .iter()
                            .find(|k| k.0 == key_id)
                            .unwrap();
                        let mut conn_auth_lock =
                            self.conn_auth_srv.lock().await;
                        conn_auth_lock.add(&key.1, id_send, &resolved);
                        //FIXME: user needs to somehow track the connection
                        Ok(())
                    }
                }
            }
            Err(e) => {
                // Thread dropped the sender. no more thread?
                let mut conn_auth_lock = self.conn_auth_srv.lock().await;
                conn_auth_lock.remove_reserved(&resolved);
                Err(Error::IO(::std::io::Error::new(
                    ::std::io::ErrorKind::Interrupted,
                    "recv failure on connect: ".to_owned() + &e.to_string(),
                )))
            }
        }
    }

    // needs to be called before run_listeners
    async fn start_single_worker(
        &mut self,
        socks: Vec<Arc<UdpSocket>>,
    ) -> ::std::result::Result<Worker, Error> {
        let thread_idx = self._thread_work.len() as u16;
        let max_threads = self.cfg.threads.unwrap().get() as u16;
        if thread_idx >= max_threads {
            ::tracing::error!(
                "thread id higher than number of threads in config"
            );
            assert!(
                false,
                "thread_idx is an index that can't reach cfg.threads"
            );
            return Err(Error::Setup("Thread id > threads_max".to_owned()));
        }
        let thread_id = ThreadTracker {
            id: thread_idx,
            total: max_threads,
        };
        let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
        let worker = Worker::new(
            self.cfg.clone(),
            thread_id,
            self.stop_working.subscribe(),
            self.token_check.clone(),
            socks,
            work_recv,
        )
        .await?;
        // don't keep around private keys too much
        if (thread_idx + 1) == max_threads {
            self.cfg.server_keys.clear();
        }
        loop {
            let queues_lock = match Arc::get_mut(&mut self._thread_work) {
                Some(queues_lock) => queues_lock,
                None => {
                    // should not even ever happen
                    ::tokio::time::sleep(::std::time::Duration::from_millis(
                        50,
                    ))
                    .await;
                    continue;
                }
            };
            queues_lock.push(work_send);
            break;
        }
        Ok(worker)
    }

    // needs to be called before run_listeners
    /// Start one worker thread for each physical cpu.
    /// Threads are pinned to their cpu core.
    /// Work will be divided and rerouted so that there is no need to lock
    async fn start_work_threads_pinned(
        &mut self,
        tokio_rt: Arc<::tokio::runtime::Runtime>,
        sockets: Vec<Arc<UdpSocket>>,
    ) -> ::std::result::Result<(), Error> {
        use ::std::sync::Mutex;
        let hw_topology = match ::hwloc2::Topology::new() {
            Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
            None => {
                return Err(Error::Setup(
                    "Can't get hardware topology".to_owned(),
                ))
            }
        };
        let cores;
        {
            let topology_lock = hw_topology.lock().unwrap();
            let all_cores = match topology_lock
                .objects_with_type(&::hwloc2::ObjectType::Core)
            {
                Ok(all_cores) => all_cores,
                Err(_) => {
                    return Err(Error::Setup("can't list cores".to_owned()))
                }
            };
            cores = all_cores.len();
            if cores <= 0 || !topology_lock.support().cpu().set_thread() {
                ::tracing::error!("No support for CPU pinning");
                return Err(Error::Setup("No cpu pinning support".to_owned()));
            }
        }
        for core in 0..cores {
            ::tracing::debug!("Spawning thread {}", core);
            let th_topology = hw_topology.clone();
            let th_tokio_rt = tokio_rt.clone();
            let th_config = self.cfg.clone();
            let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
            let th_stop_working = self.stop_working.subscribe();
            let th_token_check = self.token_check.clone();
            let th_sockets = sockets.clone();
            let thread_id = ThreadTracker {
                total: cores as u16,
                id: 1 + (core as u16),
            };
            let join_handle = ::std::thread::spawn(move || {
                // bind to a specific core
                let th_pinning;
                {
                    let mut th_topology_lock = th_topology.lock().unwrap();
                    let th_cores = th_topology_lock
                        .objects_with_type(&::hwloc2::ObjectType::Core)
                        .unwrap();
                    let cpuset = th_cores.get(core).unwrap().cpuset().unwrap();
                    th_pinning = th_topology_lock.set_cpubind(
                        cpuset,
                        ::hwloc2::CpuBindFlags::CPUBIND_THREAD,
                    );
                }
                match th_pinning {
                    Ok(_) => {}
                    Err(_) => {
                        ::tracing::error!("Can't bind thread to cpu");
                        return;
                    }
                }
                // finally run the main worker.
                // make sure things stay on this thread
                let tk_local = ::tokio::task::LocalSet::new();
                let _ = tk_local.block_on(&th_tokio_rt, async move {
                    let mut worker = match Worker::new(
                        th_config,
                        thread_id,
                        th_stop_working,
                        th_token_check,
                        th_sockets,
                        work_recv,
                    )
                    .await
                    {
                        Ok(worker) => worker,
                        Err(_) => return,
                    };
                    worker.work_loop().await
                });
            });
            loop {
                let queues_lock = match Arc::get_mut(&mut self._thread_work) {
                    Some(queues_lock) => queues_lock,
                    None => {
                        ::tokio::time::sleep(
                            ::std::time::Duration::from_millis(50),
                        )
                        .await;
                        continue;
                    }
                };
                queues_lock.push(work_send);
                break;
            }
            self._thread_pool.push(join_handle);
        }
        // don't keep around private keys too much
        self.cfg.server_keys.clear();
        Ok(())
    }
}
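
// A minimal, commented lifecycle sketch for `with_threads` (not compiled):
// `cfg` is assumed to be a `Config` already built by the caller; everything
// else uses only `Fenrir::with_threads`, `addresses` and `graceful_stop`
// from this file plus a standard tokio multi-thread runtime.
//
// let rt = Arc::new(
//     ::tokio::runtime::Builder::new_multi_thread()
//         .enable_all()
//         .build()?,
// );
// let endpoint = rt.block_on(Fenrir::with_threads(&cfg, rt.clone()))?;
// ::tracing::info!("listening on {:?}", endpoint.addresses());
// rt.block_on(endpoint.graceful_stop());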