#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unstable_features,
unused_import_braces,
unused_qualifications
)]
//!
//! libFenrir is the official rust library implementing the Fenrir protocol
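//!
//! A minimal usage sketch. Hedged: `Config::default()` is an assumption
//! for illustration, not a guarantee of this crate; error handling elided.
//! ```ignore
//! let rt = ::std::sync::Arc::new(::tokio::runtime::Runtime::new()?);
//! let cfg = Config::default(); // hypothetical default configuration
//! let endpoint = rt.block_on(Fenrir::with_threads(&cfg, rt.clone()))?;
//! // ... connect(), resolv(), ...
//! rt.block_on(endpoint.graceful_stop());
//! ```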
pub mod auth;
mod config;
pub mod connection;
pub mod dnssec;
pub mod enc;
mod inner;
#[cfg(test)]
mod tests;
use ::std::{sync::Arc, vec::Vec};
use ::tokio::{net::UdpSocket, sync::Mutex};
use crate::{
auth::{Domain, ServiceID, TokenChecker},
connection::{
handshake,
socket::{SocketTracker, UdpClient, UdpServer},
AuthServerConnections, Packet,
},
inner::{
worker::{ConnectInfo, RawUdp, Work, Worker},
ThreadTracker,
},
};
pub use config::Config;
pub use connection::{AuthSrvConn, ServiceConn};
/// Main fenrir library errors
#[derive(::thiserror::Error, Debug)]
pub enum Error {
/// The library was not initialized (run .start())
#[error("not initialized")]
NotInitialized,
/// Error in setting up worker threads
#[error("Setup err: {0}")]
Setup(String),
/// General I/O error
#[error("IO: {0:?}")]
IO(#[from] ::std::io::Error),
/// Dnssec errors
#[error("Dnssec: {0:?}")]
Dnssec(#[from] dnssec::Error),
/// Handshake errors
#[error("Handshake: {0:?}")]
Handshake(#[from] handshake::Error),
/// Key error
#[error("key: {0:?}")]
Key(#[from] crate::enc::Error),
/// Resolution problems: wrong or incomplete DNSSEC data
#[error("DNSSEC resolution: {0}")]
Resolution(String),
/// Wrapper on encryption errors
#[error("Encrypt: {0}")]
Encrypt(enc::Error),
}
pub(crate) enum StopWorking {
WorkerStopped,
ListenerStopped,
}
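// Shutdown handshake: the endpoint broadcasts an mpsc `Sender` to every
// listener and worker, and each of them confirms its own shutdown by
// sending a `StopWorking` message back on that channel.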
pub(crate) type StopWorkingSendCh =
::tokio::sync::broadcast::Sender<::tokio::sync::mpsc::Sender<StopWorking>>;
pub(crate) type StopWorkingRecvCh = ::tokio::sync::broadcast::Receiver<
::tokio::sync::mpsc::Sender<StopWorking>,
>;
/// Instance of a fenrir endpoint
#[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir {
/// library configuration
cfg: Config,
/// listening UDP sockets
sockets: Vec<SocketTracker>,
/// DNSSEC resolver, with failovers
dnssec: dnssec::Dnssec,
/// Broadcast channel to tell workers to stop working
stop_working: StopWorkingSendCh,
/// where to ask for token check
token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>,
/// tracks the connections to authentication servers
/// so that we can reuse them
conn_auth_srv: Mutex<AuthServerConnections>,
// TODO: find a way to both increase and decrease these two in a thread-safe
// manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
}
// TODO: graceful vs immediate stop
impl Drop for Fenrir {
fn drop(&mut self) {
::tracing::debug!(
"Fenrir fast shutdown. \
Some threads might remain a bit longer"
);
let _ = self.stop_sync();
}
}
impl Fenrir {
/// Gracefully stop all listeners and workers;
/// returns only once all resources have been deallocated
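///
/// A one-line sketch (assuming an already running `endpoint`):
/// ```ignore
/// endpoint.graceful_stop().await; // consumes the endpoint
/// ```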
pub async fn graceful_stop(mut self) {
::tracing::debug!("Fenrir full shut down");
if let Some((ch, listeners, workers)) = self.stop_sync() {
self.stop_wait(ch, listeners, workers).await;
}
}
fn stop_sync(
&mut self,
) -> Option<(
::tokio::sync::mpsc::Receiver<StopWorking>,
Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
usize,
)> {
let workers_num = self._thread_work.len();
if !self.sockets.is_empty() || workers_num > 0 {
let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4);
let _ = self.stop_working.send(ch_send);
let mut old_listeners = Vec::with_capacity(self.sockets.len());
::core::mem::swap(&mut old_listeners, &mut self.sockets);
self._thread_pool.clear();
let listeners = old_listeners
.into_iter()
.map(|(_, joinable)| joinable)
.collect();
Some((ch_recv, listeners, workers_num))
} else {
None
}
}
async fn stop_wait(
&mut self,
mut ch: ::tokio::sync::mpsc::Receiver<StopWorking>,
listeners: Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
mut workers_num: usize,
) {
let mut listeners_num = listeners.len();
// wait until every listener *and* every worker has confirmed its shutdown
while listeners_num > 0 || workers_num > 0 {
match ch.recv().await {
Some(stopped) => match stopped {
StopWorking::WorkerStopped => workers_num -= 1,
StopWorking::ListenerStopped => listeners_num -= 1,
},
_ => break,
}
}
for l in listeners.into_iter() {
if let Err(e) = l.await {
::tracing::error!("Unclean shutdown of listener: {:?}", e);
}
}
}
/// Create a new Fenrir endpoint.
/// Spawns worker threads pinned to CPU cores, driven by the given tokio runtime
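///
/// A minimal sketch (`cfg` is assumed to be a valid [`Config`]):
/// ```ignore
/// let rt = Arc::new(::tokio::runtime::Runtime::new()?);
/// let endpoint = Fenrir::with_threads(&cfg, rt.clone()).await?;
/// ```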
pub async fn with_threads(
config: &Config,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<Self, Error> {
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random)
// in the config
let bound_sockets = Self::bind_sockets(&config).await?;
let socket_addrs = bound_sockets
.iter()
.map(|s| s.local_addr().unwrap())
.collect();
let cfg = {
let mut tmp = config.clone();
tmp.listen = socket_addrs;
tmp
};
let mut endpoint = Self {
cfg,
sockets: Vec::with_capacity(config.listen.len()),
dnssec,
stop_working: sender,
token_check: None,
conn_auth_srv: Mutex::new(AuthServerConnections::new()),
_thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()),
};
endpoint
.start_work_threads_pinned(tokio_rt, bound_sockets.clone())
.await?;
endpoint.run_listeners(bound_sockets).await?;
Ok(endpoint)
}
/// Create a new Fenrir endpoint
/// and get the workers that you can run in a tokio `LocalSet`.
/// You should:
/// * move each worker into its own thread
/// * make sure that each thread is pinned to a CPU core
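///
/// A sketch of driving the returned workers; thread pinning is elided and
/// `work_loop()` is the worker's main loop as used internally:
/// ```ignore
/// let (endpoint, workers) = Fenrir::with_workers(&cfg).await?;
/// for mut worker in workers {
///     ::std::thread::spawn(move || {
///         // pin this thread to a CPU core here, then drive the worker:
///         let rt = ::tokio::runtime::Builder::new_current_thread()
///             .enable_all()
///             .build()
///             .unwrap();
///         let local = ::tokio::task::LocalSet::new();
///         local.block_on(&rt, async move { worker.work_loop().await });
///     });
/// }
/// ```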
pub async fn with_workers(
config: &Config,
) -> Result<(Self, Vec<Worker>), Error> {
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random)
// in the config
let bound_sockets = Self::bind_sockets(&config).await?;
let socket_addrs = bound_sockets
.iter()
.map(|s| s.local_addr().unwrap())
.collect();
let cfg = {
let mut tmp = config.clone();
tmp.listen = socket_addrs;
tmp
};
let mut endpoint = Self {
cfg,
sockets: Vec::with_capacity(config.listen.len()),
dnssec,
stop_working: stop_working.clone(),
token_check: None,
conn_auth_srv: Mutex::new(AuthServerConnections::new()),
_thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()),
};
let worker_num = config.threads.unwrap().get();
let mut workers = Vec::with_capacity(worker_num);
for _ in 0..worker_num {
workers.push(
endpoint.start_single_worker(bound_sockets.clone()).await?,
);
}
endpoint.run_listeners(bound_sockets).await?;
Ok((endpoint, workers))
}
/// Returns the list of the actual addresses we are listening on
/// Note that this can be different from what was configured:
/// if you specified UDP port 0, a random one has been assigned to you
/// by the operating system.
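///
/// E.g., a sketch of the "port 0" behaviour described above:
/// ```ignore
/// // with a single `0.0.0.0:0` entry in `Config::listen`:
/// let addr = endpoint.addresses()[0];
/// assert_ne!(addr.port(), 0); // the OS assigned a real port
/// ```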
pub fn addresses(&self) -> Vec<::std::net::SocketAddr> {
self.sockets.iter().map(|(s, _)| *s).collect()
}
// only call **before** starting all threads
/// bind all UDP sockets found in config
async fn bind_sockets(cfg: &Config) -> Result<Vec<Arc<UdpSocket>>, Error> {
// try to bind multiple sockets in parallel
let mut sock_set = ::tokio::task::JoinSet::new();
cfg.listen.iter().for_each(|s_addr| {
let socket_address = *s_addr;
sock_set.spawn(async move {
connection::socket::bind_udp(socket_address).await
});
});
// make sure we either return all of them, or none
let mut all_socks = Vec::with_capacity(cfg.listen.len());
while let Some(join_res) = sock_set.join_next().await {
match join_res {
Ok(s_res) => match s_res {
Ok(s) => {
all_socks.push(Arc::new(s));
}
Err(e) => {
return Err(e.into());
}
},
Err(e) => {
return Err(Error::Setup(e.to_string()));
}
}
}
assert_eq!(all_socks.len(), cfg.listen.len(), "missing socks");
Ok(all_socks)
}
// only call **after** starting all threads
/// spawn all listeners
async fn run_listeners(
&mut self,
socks: Vec<Arc<UdpSocket>>,
) -> Result<(), Error> {
for sock in socks.into_iter() {
let sockaddr = sock.local_addr().unwrap();
let stop_working = self.stop_working.subscribe();
let th_work = self._thread_work.clone();
let joinable = ::tokio::spawn(async move {
Self::listen_udp(stop_working, th_work, sock.clone()).await
});
self.sockets.push((sockaddr, joinable));
}
Ok(())
}
/// Run a dedicated loop to read packets on the listening socket
async fn listen_udp(
mut stop_working: StopWorkingRecvCh,
work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
socket: Arc<UdpSocket>,
) -> ::std::io::Result<()> {
// jumbo frames are 9K max
let sock_receiver = UdpServer(socket.local_addr()?);
let mut buffer: [u8; 9000] = [0; 9000];
let queues_num = work_queues.len() as u64;
loop {
let (bytes, sock_sender) = ::tokio::select! {
tell_stopped = stop_working.recv() => {
drop(socket);
if let Ok(stop_ch) = tell_stopped {
let _ = stop_ch
.send(StopWorking::ListenerStopped).await;
}
return Ok(());
}
result = socket.recv_from(&mut buffer) => {
let (bytes, from) = result?;
(bytes, UdpClient(from))
}
};
let data: Vec<u8> = buffer[..bytes].to_vec();
// We very likely have multiple threads, pinned to different CPUs.
// Use the connection::ID to route packets of the same connection
// to the same thread.
// Handshakes have connection ID 0, so we use the sender's UDP port
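// (e.g. with 4 worker queues, connection ID 6 is routed to queue
// 6 % 4 = 2, a handshake from UDP port 50000 to queue 50000 % 4 = 0)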
let packet = match Packet::deserialize_id(&data) {
Ok(packet) => packet,
Err(_) => continue, // packet way too short, ignore.
};
let thread_idx: usize = {
match packet.id {
connection::ID::Handshake => {
let send_port = sock_sender.0.port() as u64;
(send_port % queues_num) as usize
}
connection::ID::ID(id) => (id.get() % queues_num) as usize,
}
};
let _ = work_queues[thread_idx]
.send(Work::Recv(RawUdp {
src: sock_sender,
dst: sock_receiver,
packet,
data,
}))
.await;
}
}
/// Get the raw TXT record of a Fenrir domain
pub async fn resolv_txt(&self, domain: &Domain) -> Result<String, Error> {
Ok(self.dnssec.resolv(domain).await?)
}
/// Resolve a Fenrir domain and parse its TXT record
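///
/// A sketch (`domain` is a placeholder [`Domain`] value):
/// ```ignore
/// let record = endpoint.resolv(&domain).await?;
/// println!("found {} public keys", record.public_keys.len());
/// ```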
pub async fn resolv(
&self,
domain: &Domain,
) -> Result<dnssec::Record, Error> {
let record_str = self.resolv_txt(domain).await?;
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
}
/// Connect to a service, doing the dnssec resolution ourselves
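///
/// A hedged sketch (`domain` and `service_id` are placeholder values):
/// ```ignore
/// let (auth_conn, maybe_service_conn) =
///     endpoint.connect(&domain, service_id).await?;
/// ```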
pub async fn connect(
&self,
domain: &Domain,
service: ServiceID,
) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
let resolved = self.resolv(domain).await?;
self.connect_resolved(resolved, domain, service).await
}
/// Connect to a service, with the user provided details
pub async fn connect_resolved(
&self,
resolved: dnssec::Record,
domain: &Domain,
service: ServiceID,
) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
loop {
// check if we already have a connection to that auth. srv
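// NOTE: the lock is scoped so that the mutex is not held
// across the `.await`s below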
let is_reserved = {
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.get_or_reserve(&resolved)
};
use connection::Reservation;
match is_reserved {
Reservation::Waiting => {
use ::std::time::Duration;
use ::tokio::time::sleep;
// PERF: use exponential backoff,
// or a broadcast channel to wake us when the reservation is resolved
sleep(Duration::from_millis(50)).await;
continue;
}
Reservation::Reserved => break,
Reservation::Present(_id_send) => {
//TODO: reuse connection
todo!()
}
}
}
// Spot reserved for the connection
// find the thread with the fewest connections
let th_num = self._thread_work.len();
let mut conn_count = Vec::<usize>::with_capacity(th_num);
let mut wait_res =
Vec::<::tokio::sync::oneshot::Receiver<usize>>::with_capacity(
th_num,
);
for th in self._thread_work.iter() {
let (send, recv) = ::tokio::sync::oneshot::channel();
wait_res.push(recv);
let _ = th.send(Work::CountConnections(send)).await;
}
for ch in wait_res.into_iter() {
if let Ok(conn_num) = ch.await {
conn_count.push(conn_num);
}
}
if conn_count.len() != th_num {
return Err(Error::IO(::std::io::Error::new(
::std::io::ErrorKind::NotConnected,
"can't connect to a thread",
)));
}
let thread_idx = conn_count
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| a.cmp(b))
.map(|(index, _)| index)
.unwrap();
// and tell that thread to connect somewhere
let (send, recv) = ::tokio::sync::oneshot::channel();
let _ = self._thread_work[thread_idx]
.send(Work::Connect(ConnectInfo {
answer: send,
resolved: resolved.clone(),
service_id: service,
domain: domain.clone(),
}))
.await;
match recv.await {
Ok(res) => match res {
Err(e) => {
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.remove_reserved(&resolved);
Err(e)
}
Ok(connections) => {
let key = resolved
.public_keys
.iter()
.find(|k| k.0 == connections.auth_key_id)
.unwrap();
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.add(
&key.1,
connections.auth_id_send,
&resolved,
);
Ok((connections.authsrv_conn, connections.service_conn))
}
},
Err(e) => {
// The worker thread dropped the sender: no more threads?
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.remove_reserved(&resolved);
Err(Error::IO(::std::io::Error::new(
::std::io::ErrorKind::Interrupted,
"recv failure on connect: ".to_owned() + &e.to_string(),
)))
}
}
}
// needs to be called before run_listeners
async fn start_single_worker(
&mut self,
socks: Vec<Arc<UdpSocket>>,
) -> ::std::result::Result<Worker, Error> {
let thread_idx = self._thread_work.len() as u16;
let max_threads = self.cfg.threads.unwrap().get() as u16;
if thread_idx >= max_threads {
::tracing::error!(
"thread id higher than number of threads in config"
);
// panic in debug builds; in release, fall through to the error below
debug_assert!(
false,
"thread_idx is an index that can't reach cfg.threads"
);
return Err(Error::Setup("Thread id > threads_max".to_owned()));
}
let thread_id = ThreadTracker {
id: thread_idx,
total: max_threads,
};
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let worker = Worker::new(
self.cfg.clone(),
thread_id,
self.stop_working.subscribe(),
self.token_check.clone(),
socks,
work_recv,
work_send.clone(),
)
.await?;
// don't keep private keys around longer than needed
if (thread_idx + 1) == max_threads {
self.cfg.server_keys.clear();
}
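// `Arc::get_mut` succeeds only while we hold the sole strong reference;
// listeners (which clone the queue list) are started later, so this
// normally succeeds on the first try.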
loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
Some(queues_lock) => queues_lock,
None => {
// should never happen: workers are started before listeners
::tokio::time::sleep(::std::time::Duration::from_millis(
50,
))
.await;
continue;
}
};
queues_lock.push(work_send);
break;
}
Ok(worker)
}
// needs to be called before run_listeners
/// Start one worker thread for each physical CPU core;
/// threads are pinned to their cores.
/// Work will be divided and rerouted so that there is no need to lock
async fn start_work_threads_pinned(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
sockets: Vec<Arc<UdpSocket>>,
) -> ::std::result::Result<(), Error> {
use ::std::sync::Mutex;
let hw_topology = match ::hwloc2::Topology::new() {
Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
None => {
return Err(Error::Setup(
"Can't get hardware topology".to_owned(),
))
}
};
let cores;
{
let topology_lock = hw_topology.lock().unwrap();
let all_cores = match topology_lock
.objects_with_type(&::hwloc2::ObjectType::Core)
{
Ok(all_cores) => all_cores,
Err(_) => {
return Err(Error::Setup("can't list cores".to_owned()))
}
};
cores = all_cores.len();
if cores == 0 || !topology_lock.support().cpu().set_thread() {
::tracing::error!("No support for CPU pinning");
return Err(Error::Setup("No cpu pinning support".to_owned()));
}
}
for core in 0..cores {
::tracing::debug!("Spawning thread {}", core);
let th_topology = hw_topology.clone();
let th_tokio_rt = tokio_rt.clone();
let th_config = self.cfg.clone();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let th_work_send = work_send.clone();
let th_stop_working = self.stop_working.subscribe();
let th_token_check = self.token_check.clone();
let th_sockets = sockets.clone();
let thread_id = ThreadTracker {
total: cores as u16,
id: 1 + (core as u16),
};
let join_handle = ::std::thread::spawn(move || {
// bind to a specific core
let th_pinning;
{
let mut th_topology_lock = th_topology.lock().unwrap();
let th_cores = th_topology_lock
.objects_with_type(&::hwloc2::ObjectType::Core)
.unwrap();
let cpuset = th_cores.get(core).unwrap().cpuset().unwrap();
th_pinning = th_topology_lock.set_cpubind(
cpuset,
::hwloc2::CpuBindFlags::CPUBIND_THREAD,
);
}
match th_pinning {
Ok(_) => {}
Err(_) => {
::tracing::error!("Can't bind thread to cpu");
return;
}
}
// finally run the main worker.
// make sure things stay on this thread
let tk_local = ::tokio::task::LocalSet::new();
let _ = tk_local.block_on(&th_tokio_rt, async move {
let mut worker = match Worker::new(
th_config,
thread_id,
th_stop_working,
th_token_check,
th_sockets,
work_recv,
th_work_send,
)
.await
{
Ok(worker) => worker,
Err(_) => return,
};
worker.work_loop().await
});
});
loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
Some(queues_lock) => queues_lock,
None => {
::tokio::time::sleep(
::std::time::Duration::from_millis(50),
)
.await;
continue;
}
};
queues_lock.push(work_send);
break;
}
self._thread_pool.push(join_handle);
}
// don't keep private keys around longer than needed
self.cfg.server_keys.clear();
Ok(())
}
}