//! Socket related types and functions use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::std::{ net::SocketAddr, sync::Arc, vec::{self, Vec}, }; use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; // PERF: move to arcswap: // We use multiple UDP sockets. // We need a list of them all and to scan this list. // But we want to be able to update this list without locking everyone // So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and // `from_raw_fd` to update the list. // This means that we will have multiple `UdpSocket` per actual socket // so we have to handle `drop()` manually, and garbage-collect the ones we // are no longer using in the background. sigh. // Just go with a ArcSwapAny, Arc>>); /// async free socket list pub(crate) struct SocketList { pub list: ArcSwap>, } impl SocketList { pub(crate) fn new() -> Self { Self { list: ArcSwap::new(Arc::new(Vec::new())), } } // TODO: fn rm_socket() pub(crate) fn rm_all(&self) -> Self { let new_list = Arc::new(Vec::new()); let old_list = self.list.swap(new_list); Self { list: old_list.into(), } } pub(crate) async fn add_socket( &self, socket: Arc, handle: JoinHandle<::std::io::Result<()>>, ) { // we could simplify this into just a `.swap` instead of `.rcu` but // it is not yet guaranteed that only one thread will call this fn // ...we don't need performance here anyway let arc_handle = Arc::new(handle); self.list.rcu(|old_list| { let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); new_list = old_list.to_vec().into(); Arc::get_mut(&mut new_list) .unwrap() .push((socket.clone(), arc_handle.clone())); new_list }); } /// This method assumes no other `add_sockets` are being run pub(crate) async fn stop_all(mut self) { let mut arc_list = self.list.into_inner(); let list = loop { match Arc::try_unwrap(arc_list) { Ok(list) => break list, Err(arc_retry) => { arc_list = arc_retry; ::tokio::time::sleep(::core::time::Duration::from_millis( 50, )) .await; } } }; for (_socket, mut handle) in list.into_iter() { Arc::get_mut(&mut handle).unwrap().await; } } pub(crate) fn lock(&self) -> SocketListRef { SocketListRef { list: self.list.load_full(), } } } /// Reference to a locked SocketList // TODO: impl Drop for SocketList pub(crate) struct SocketListRef { list: Arc>, } impl SocketListRef { pub(crate) fn find(&self, sock: UdpServer) -> Option> { match self .list .iter() .find(|&(s, _)| s.local_addr().unwrap() == sock.0) { Some((sock_srv, _)) => Some(sock_srv.clone()), None => None, } } } /// Strong typedef for a client socket address #[derive(Debug, Copy, Clone)] pub(crate) struct UdpClient(pub SocketAddr); /// Strong typedef for a server socket address #[derive(Debug, Copy, Clone)] pub(crate) struct UdpServer(pub SocketAddr);