#![deny( missing_docs, missing_debug_implementations, missing_copy_implementations, trivial_casts, trivial_numeric_casts, unsafe_code, unstable_features, unused_import_braces, unused_qualifications )] //! //! libFenrir is the official rust library implementing the Fenrir protocol pub mod auth; mod config; pub mod connection; pub mod dnssec; pub mod enc; mod inner; use ::std::{sync::Arc, vec::Vec}; use ::tokio::{net::UdpSocket, sync::Mutex}; use crate::{ auth::{Domain, ServiceID, TokenChecker}, connection::{ handshake, socket::{SocketList, UdpClient, UdpServer}, AuthServerConnections, Packet, }, inner::{ worker::{ConnectInfo, RawUdp, Work, Worker}, ThreadTracker, }, }; pub use config::Config; /// Main fenrir library errors #[derive(::thiserror::Error, Debug)] pub enum Error { /// The library was not initialized (run .start()) #[error("not initialized")] NotInitialized, /// Error in setting up worker threads #[error("Setup err: {0}")] Setup(String), /// General I/O error #[error("IO: {0:?}")] IO(#[from] ::std::io::Error), /// Dnssec errors #[error("Dnssec: {0:?}")] Dnssec(#[from] dnssec::Error), /// Handshake errors #[error("Handshake: {0:?}")] Handshake(#[from] handshake::Error), /// Key error #[error("key: {0:?}")] Key(#[from] crate::enc::Error), /// Resolution problems. wrong or incomplete DNSSEC data #[error("DNSSEC resolution: {0}")] Resolution(String), /// Wrapper on encryption errors #[error("Encrypt: {0}")] Encrypt(enc::Error), } /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { /// library Configuration cfg: Config, /// listening udp sockets sockets: SocketList, /// DNSSEC resolver, with failovers dnssec: Option, /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, /// where to ask for token check token_check: Option>>, /// tracks the connections to authentication servers /// so that we can reuse them conn_auth_srv: Mutex, // TODO: find a way to both increase and decrease these two in a thread-safe // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, _thread_work: Arc>>, } // TODO: graceful vs immediate stop impl Drop for Fenrir { fn drop(&mut self) { self.stop_sync() } } impl Fenrir { /// Create a new Fenrir endpoint pub fn new(config: &Config) -> Result { let (sender, _) = ::tokio::sync::broadcast::channel(1); let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), dnssec: None, stop_working: sender, token_check: None, conn_auth_srv: Mutex::new(AuthServerConnections::new()), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), }; Ok(endpoint) } /// Start all workers, listeners pub async fn start( &mut self, tokio_rt: Arc<::tokio::runtime::Runtime>, ) -> Result<(), Error> { self.start_work_threads_pinned(tokio_rt).await?; if let Err(e) = self.add_sockets().await { self.stop().await; return Err(e.into()); } self.dnssec = Some(dnssec::Dnssec::new(&self.cfg.resolvers).await?); Ok(()) } /// Stop all workers, listeners /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); let toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); let _ = old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); let toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); let _ = old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Add all UDP sockets found in config /// and start listening for packets async fn add_sockets(&self) -> ::std::io::Result<()> { let sockets = self.cfg.listen.iter().map(|s_addr| async { let socket = ::tokio::spawn(connection::socket::bind_udp(s_addr.clone())) .await??; Ok(socket) }); let sockets = ::futures::future::join_all(sockets).await; for s_res in sockets.into_iter() { match s_res { Ok(s) => { let stop_working = self.stop_working.subscribe(); let arc_s = Arc::new(s); let join = ::tokio::spawn(Self::listen_udp( stop_working, self._thread_work.clone(), arc_s.clone(), )); self.sockets.add_socket(arc_s, join).await; } Err(e) => { return Err(e); } } } Ok(()) } /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( mut stop_working: ::tokio::sync::broadcast::Receiver, work_queues: Arc>>, socket: Arc, ) -> ::std::io::Result<()> { // jumbo frames are 9K max let sock_receiver = UdpServer(socket.local_addr()?); let mut buffer: [u8; 9000] = [0; 9000]; let queues_num = work_queues.len() as u64; loop { let (bytes, sock_sender) = ::tokio::select! { _done = stop_working.recv() => { break; } result = socket.recv_from(&mut buffer) => { result? } }; let data: Vec = buffer[..bytes].to_vec(); // we very likely have multiple threads, pinned to different cpus. // use the ConnectionID to send the same connection // to the same thread. // Handshakes have connection ID 0, so we use the sender's UDP port let packet = match Packet::deserialize_id(&data) { Ok(packet) => packet, Err(_) => continue, // packet way too short, ignore. }; let thread_idx: usize = { use connection::packet::ConnectionID; match packet.id { ConnectionID::Handshake => { let send_port = sock_sender.port() as u64; ((send_port % queues_num) - 1) as usize } ConnectionID::ID(id) => { ((id.get() % queues_num) - 1) as usize } } }; let _ = work_queues[thread_idx] .send(Work::Recv(RawUdp { src: UdpClient(sock_sender), dst: sock_receiver, packet, data, })) .await; } Ok(()) } /// Get the raw TXT record of a Fenrir domain pub async fn resolv_txt(&self, domain: &Domain) -> Result { match &self.dnssec { Some(dnssec) => Ok(dnssec.resolv(domain).await?), None => Err(Error::NotInitialized), } } /// Get the raw TXT record of a Fenrir domain pub async fn resolv( &self, domain: &Domain, ) -> Result { let record_str = self.resolv_txt(domain).await?; Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } /// Connect to a service pub async fn connect( &self, domain: &Domain, service: ServiceID, ) -> Result<(), Error> { let resolved = self.resolv(domain).await?; loop { // check if we already have a connection to that auth. srv let is_reserved = { let mut conn_auth_lock = self.conn_auth_srv.lock().await; conn_auth_lock.get_or_reserve(&resolved) }; use connection::Reservation; match is_reserved { Reservation::Waiting => { use ::std::time::Duration; use ::tokio::time::sleep; // PERF: exponential backoff. // or we can have a broadcast channel sleep(Duration::from_millis(50)).await; continue; } Reservation::Reserved => break, Reservation::Present(_id_send) => { //TODO: reuse connection todo!() } } } // Spot reserved for the connection // find the thread with less connections let th_num = self._thread_work.len(); let mut conn_count = Vec::::with_capacity(th_num); let mut wait_res = Vec::<::tokio::sync::oneshot::Receiver>::with_capacity( th_num, ); for th in self._thread_work.iter() { let (send, recv) = ::tokio::sync::oneshot::channel(); wait_res.push(recv); let _ = th.send(Work::CountConnections(send)).await; } for ch in wait_res.into_iter() { if let Ok(conn_num) = ch.await { conn_count.push(conn_num); } } if conn_count.len() != th_num { return Err(Error::IO(::std::io::Error::new( ::std::io::ErrorKind::NotConnected, "can't connect to a thread", ))); } let thread_idx = conn_count .iter() .enumerate() .max_by(|(_, a), (_, b)| a.cmp(b)) .map(|(index, _)| index) .unwrap(); // and tell that thread to connect somewhere let (send, recv) = ::tokio::sync::oneshot::channel(); let _ = self._thread_work[thread_idx] .send(Work::Connect(ConnectInfo { answer: send, resolved: resolved.clone(), service_id: service, domain: domain.clone(), })) .await; match recv.await { Ok(res) => { match res { Err(e) => { let mut conn_auth_lock = self.conn_auth_srv.lock().await; conn_auth_lock.remove_reserved(&resolved); Err(e) } Ok((pubkey, id_send)) => { let mut conn_auth_lock = self.conn_auth_srv.lock().await; conn_auth_lock.add(&pubkey, id_send, &resolved); //FIXME: user needs to somehow track the connection Ok(()) } } } Err(e) => { // Thread dropped the sender. no more thread? let mut conn_auth_lock = self.conn_auth_srv.lock().await; conn_auth_lock.remove_reserved(&resolved); Err(Error::IO(::std::io::Error::new( ::std::io::ErrorKind::Interrupted, "recv failure on connect: ".to_owned() + &e.to_string(), ))) } } } // TODO: start work on a LocalSet provided by the user /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. /// Work will be divided and rerouted so that there is no need to lock async fn start_work_threads_pinned( &mut self, tokio_rt: Arc<::tokio::runtime::Runtime>, ) -> ::std::result::Result<(), Error> { use ::std::sync::Mutex; let hw_topology = match ::hwloc2::Topology::new() { Some(hw_topology) => Arc::new(Mutex::new(hw_topology)), None => { return Err(Error::Setup( "Can't get hardware topology".to_owned(), )) } }; let cores; { let topology_lock = hw_topology.lock().unwrap(); let all_cores = match topology_lock .objects_with_type(&::hwloc2::ObjectType::Core) { Ok(all_cores) => all_cores, Err(_) => { return Err(Error::Setup("can't list cores".to_owned())) } }; cores = all_cores.len(); if cores <= 0 || !topology_lock.support().cpu().set_thread() { ::tracing::error!("No support for CPU pinning"); return Err(Error::Setup("No cpu pinning support".to_owned())); } } for core in 0..cores { ::tracing::debug!("Spawning thread {}", core); let th_topology = hw_topology.clone(); let th_tokio_rt = tokio_rt.clone(); let th_config = self.cfg.clone(); let (work_send, work_recv) = ::async_channel::unbounded::(); let th_stop_working = self.stop_working.subscribe(); let th_token_check = self.token_check.clone(); let th_socket_addrs = self.cfg.listen.clone(); let thread_id = ThreadTracker { total: cores as u16, id: 1 + (core as u16), }; let join_handle = ::std::thread::spawn(move || { // bind to a specific core let th_pinning; { let mut th_topology_lock = th_topology.lock().unwrap(); let th_cores = th_topology_lock .objects_with_type(&::hwloc2::ObjectType::Core) .unwrap(); let cpuset = th_cores.get(core).unwrap().cpuset().unwrap(); th_pinning = th_topology_lock.set_cpubind( cpuset, ::hwloc2::CpuBindFlags::CPUBIND_THREAD, ); } match th_pinning { Ok(_) => {} Err(_) => { ::tracing::error!("Can't bind thread to cpu"); return; } } // finally run the main worker. // make sure things stay on this thread let tk_local = ::tokio::task::LocalSet::new(); let _ = tk_local.block_on( &th_tokio_rt, Worker::new_and_loop( th_config, thread_id, th_stop_working, th_token_check, th_socket_addrs, work_recv, ), ); }); loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { Some(queues_lock) => queues_lock, None => { ::tokio::time::sleep( ::std::time::Duration::from_millis(50), ) .await; continue; } }; queues_lock.push(work_send); break; } self._thread_pool.push(join_handle); } Ok(()) } }