//! Socket related types and functions use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::std::{ net::SocketAddr, sync::Arc, vec::{self, Vec}, }; use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; /// Pair to easily track the socket and its async listening handle pub type SocketTracker = (Arc, Arc>>); /// async free socket list pub(crate) struct SocketList { pub list: ArcSwap>, } impl SocketList { pub(crate) fn new() -> Self { Self { list: ArcSwap::new(Arc::new(Vec::new())), } } // TODO: fn rm_socket() pub(crate) fn rm_all(&self) -> Self { let new_list = Arc::new(Vec::new()); let old_list = self.list.swap(new_list); Self { list: old_list.into(), } } pub(crate) async fn add_socket( &self, socket: Arc, handle: JoinHandle<::std::io::Result<()>>, ) { // we could simplify this into just a `.swap` instead of `.rcu` but // it is not yet guaranteed that only one thread will call this fn // ...we don't need performance here anyway let arc_handle = Arc::new(handle); self.list.rcu(|old_list| { let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); new_list = old_list.to_vec().into(); Arc::get_mut(&mut new_list) .unwrap() .push((socket.clone(), arc_handle.clone())); new_list }); } /// This method assumes no other `add_sockets` are being run pub(crate) async fn stop_all(mut self) { let mut arc_list = self.list.into_inner(); let list = loop { match Arc::try_unwrap(arc_list) { Ok(list) => break list, Err(arc_retry) => { arc_list = arc_retry; ::tokio::time::sleep(::core::time::Duration::from_millis( 50, )) .await; } } }; for (_socket, mut handle) in list.into_iter() { Arc::get_mut(&mut handle).unwrap().await; } } pub(crate) fn lock(&self) -> SocketListRef { SocketListRef { list: self.list.load_full(), } } } /// Reference to a locked SocketList // TODO: impl Drop for SocketList pub(crate) struct SocketListRef { list: Arc>, } impl SocketListRef { pub(crate) fn find(&self, sock: UdpServer) -> Option> { match self .list .iter() .find(|&(s, _)| s.local_addr().unwrap() == sock.0) { Some((sock_srv, _)) => Some(sock_srv.clone()), None => None, } } } /// Strong typedef for a client socket address #[derive(Debug, Copy, Clone)] pub(crate) struct UdpClient(pub SocketAddr); /// Strong typedef for a server socket address #[derive(Debug, Copy, Clone)] pub(crate) struct UdpServer(pub SocketAddr); /// Enable some common socket options. This is just the unsafe part fn enable_sock_opt( fd: ::std::os::fd::RawFd, option: ::libc::c_int, value: ::libc::c_int, ) -> ::std::io::Result<()> { #[allow(unsafe_code)] unsafe { #[allow(trivial_casts)] let val = &value as *const _ as *const ::libc::c_void; let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; // always clear the error bit before doing a new syscall let _ = ::std::io::Error::last_os_error(); let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); if ret != 0 { return Err(::std::io::Error::last_os_error()); } } Ok(()) } /// Add an async udp listener pub async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { let socket = UdpSocket::bind(sock).await?; use ::std::os::fd::AsRawFd; let fd = socket.as_raw_fd(); // can be useful later on for reloads enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; // We will do path MTU discovery by ourselves, // always set the "don't fragment" bit if sock.is_ipv6() { enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; } else { // FIXME: linux only enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)?; } Ok(socket) }