libFenrir/src/lib.rs

#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unstable_features,
unused_import_braces,
unused_qualifications
)]
//!
//! libFenrir is the official Rust library implementing the Fenrir protocol
pub mod auth;
mod config;
pub mod connection;
pub mod dnssec;
pub mod enc;
mod inner;
use ::std::{sync::Arc, vec::Vec};
use ::tokio::{net::UdpSocket, sync::Mutex};
use crate::{
auth::{Domain, ServiceID, TokenChecker},
connection::{
handshake,
socket::{SocketList, UdpClient, UdpServer},
AuthServerConnections, Packet,
},
inner::{
worker::{ConnectInfo, RawUdp, Work, Worker},
ThreadTracker,
},
};
pub use config::Config;
/// Main fenrir library errors
#[derive(::thiserror::Error, Debug)]
pub enum Error {
/// The library was not initialized (run .start())
#[error("not initialized")]
NotInitialized,
/// Error in setting up worker threads
#[error("Setup err: {0}")]
Setup(String),
/// General I/O error
#[error("IO: {0:?}")]
IO(#[from] ::std::io::Error),
/// Dnssec errors
#[error("Dnssec: {0:?}")]
Dnssec(#[from] dnssec::Error),
/// Handshake errors
#[error("Handshake: {0:?}")]
Handshake(#[from] handshake::Error),
/// Key error
#[error("key: {0:?}")]
Key(#[from] crate::enc::Error),
/// Resolution problems: wrong or incomplete DNSSEC data
#[error("DNSSEC resolution: {0}")]
Resolution(String),
/// Wrapper on encryption errors
#[error("Encrypt: {0}")]
Encrypt(enc::Error),
}
/// Instance of a fenrir endpoint
#[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir {
/// library Configuration
cfg: Config,
/// listening udp sockets
sockets: SocketList,
/// DNSSEC resolver, with failovers
dnssec: Option<dnssec::Dnssec>,
/// Broadcast channel to tell workers to stop working
stop_working: ::tokio::sync::broadcast::Sender<bool>,
/// where to ask for token check
token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>,
/// tracks the connections to authentication servers
/// so that we can reuse them
conn_auth_srv: Mutex<AuthServerConnections>,
// TODO: find a way to both increase and decrease these two in a thread-safe
// manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
}
// TODO: graceful vs immediate stop
impl Drop for Fenrir {
fn drop(&mut self) {
self.stop_sync()
}
}
impl Fenrir {
/// Create a new Fenrir endpoint
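///
/// # Example
///
/// A minimal sketch; how the [`Config`] is obtained (defaults, a
/// configuration file, etc.) is up to the caller and not shown here:
///
/// ```ignore
/// let config: Config = unimplemented!("load your configuration");
/// let endpoint = Fenrir::new(&config)?;
/// ```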
pub fn new(config: &Config) -> Result<Self, Error> {
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let endpoint = Fenrir {
cfg: config.clone(),
sockets: SocketList::new(),
dnssec: None,
stop_working: sender,
token_check: None,
conn_auth_srv: Mutex::new(AuthServerConnections::new()),
_thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()),
};
Ok(endpoint)
}
/// Start all workers, listeners
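///
/// # Example
///
/// A sketch of the expected startup sequence, assuming a multi-threaded
/// tokio runtime built by the caller:
///
/// ```ignore
/// let rt = std::sync::Arc::new(
///     tokio::runtime::Builder::new_multi_thread().enable_all().build()?,
/// );
/// let mut endpoint = Fenrir::new(&config)?;
/// rt.block_on(endpoint.start(rt.clone()))?;
/// ```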
pub async fn start(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<(), Error> {
self.start_work_threads_pinned(tokio_rt).await?;
if let Err(e) = self.add_sockets().await {
self.stop().await;
return Err(e.into());
}
self.dnssec = Some(dnssec::Dnssec::new(&self.cfg.resolvers).await?);
Ok(())
}
/// Stop all workers, listeners
/// Synchronous version, for use in [`Drop`]
fn stop_sync(&mut self) {
let _ = self.stop_working.send(true);
let toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(toempty_sockets.stop_all());
let _ = ::futures::executor::block_on(task);
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
// `map` is lazy: iterate explicitly so the threads are actually joined
for th in old_thread_pool.into_iter() {
let _ = th.join();
}
self.dnssec = None;
}
/// Stop all workers, listeners
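///
/// # Example
///
/// A sketch; `endpoint` is a started [`Fenrir`] instance:
///
/// ```ignore
/// endpoint.stop().await; // sockets are closed, worker threads joined
/// ```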
pub async fn stop(&mut self) {
let _ = self.stop_working.send(true);
let toempty_sockets = self.sockets.rm_all();
toempty_sockets.stop_all().await;
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
// `map` is lazy: iterate explicitly so the threads are actually joined
for th in old_thread_pool.into_iter() {
let _ = th.join();
}
self.dnssec = None;
}
/// Add all UDP sockets found in config
/// and start listening for packets
async fn add_sockets(&self) -> ::std::io::Result<()> {
let sockets = self.cfg.listen.iter().map(|s_addr| async {
let socket =
::tokio::spawn(connection::socket::bind_udp(s_addr.clone()))
.await??;
Ok(socket)
});
let sockets = ::futures::future::join_all(sockets).await;
for s_res in sockets.into_iter() {
match s_res {
Ok(s) => {
let stop_working = self.stop_working.subscribe();
let arc_s = Arc::new(s);
let join = ::tokio::spawn(Self::listen_udp(
stop_working,
self._thread_work.clone(),
arc_s.clone(),
));
self.sockets.add_socket(arc_s, join).await;
}
Err(e) => {
return Err(e);
}
}
}
Ok(())
}
/// Run a dedicated loop to read packets on the listening socket
async fn listen_udp(
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
socket: Arc<UdpSocket>,
) -> ::std::io::Result<()> {
// jumbo frames are 9K max
let sock_receiver = UdpServer(socket.local_addr()?);
let mut buffer: [u8; 9000] = [0; 9000];
let queues_num = work_queues.len() as u64;
loop {
let (bytes, sock_sender) = ::tokio::select! {
_done = stop_working.recv() => {
break;
}
result = socket.recv_from(&mut buffer) => {
result?
}
};
let data: Vec<u8> = buffer[..bytes].to_vec();
// we very likely have multiple threads, pinned to different cpus.
// use the ConnectionID to send the same connection
// to the same thread.
// Handshakes have connection ID 0, so we use the sender's UDP port
let packet = match Packet::deserialize_id(&data) {
Ok(packet) => packet,
Err(_) => continue, // packet way too short, ignore.
};
let thread_idx: usize = {
use connection::packet::ConnectionID;
match packet.id {
ConnectionID::Handshake => {
// wrapping equivalent of `(x % n) - 1`, which would
// underflow whenever the remainder is 0
let send_port = sock_sender.port() as u64;
((send_port % queues_num + queues_num - 1) % queues_num)
as usize
}
ConnectionID::ID(id) => {
((id.get() % queues_num + queues_num - 1) % queues_num)
as usize
}
}
};
let _ = work_queues[thread_idx]
.send(Work::Recv(RawUdp {
src: UdpClient(sock_sender),
dst: sock_receiver,
packet,
data,
}))
.await;
}
Ok(())
}
/// Get the raw TXT record of a Fenrir domain
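///
/// # Example
///
/// A sketch; constructing a [`Domain`] is left to the caller (see
/// [`auth`]):
///
/// ```ignore
/// let raw_txt = endpoint.resolv_txt(&domain).await?;
/// println!("raw Fenrir TXT record: {raw_txt}");
/// ```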
pub async fn resolv_txt(&self, domain: &Domain) -> Result<String, Error> {
match &self.dnssec {
Some(dnssec) => Ok(dnssec.resolv(domain).await?),
None => Err(Error::NotInitialized),
}
}
/// Resolve and parse the TXT record of a Fenrir domain
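///
/// # Example
///
/// A sketch building on [`Self::resolv_txt`]:
///
/// ```ignore
/// let record = endpoint.resolv(&domain).await?;
/// // `record` is the parsed `dnssec::Record` describing the service
/// ```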
pub async fn resolv(
&self,
domain: &Domain,
) -> Result<dnssec::Record, Error> {
let record_str = self.resolv_txt(domain).await?;
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
}
/// Connect to a service
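///
/// # Example
///
/// A sketch; constructing the `domain` and `service_id` is left to the
/// caller (see [`auth`]). On success the connection is tracked inside
/// the endpoint:
///
/// ```ignore
/// endpoint.connect(&domain, service_id).await?;
/// ```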
pub async fn connect(
&self,
domain: &Domain,
service: ServiceID,
) -> Result<(), Error> {
let resolved = self.resolv(domain).await?;
loop {
// check if we already have a connection to that auth. srv
let is_reserved = {
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.get_or_reserve(&resolved)
};
use connection::Reservation;
match is_reserved {
Reservation::Waiting => {
use ::std::time::Duration;
use ::tokio::time::sleep;
// PERF: use exponential backoff here,
// or a broadcast channel to wake the waiters
sleep(Duration::from_millis(50)).await;
continue;
}
Reservation::Reserved => break,
Reservation::Present(_id_send) => {
//TODO: reuse connection
todo!()
}
}
}
// Spot reserved for the connection:
// find the thread with the fewest connections
let th_num = self._thread_work.len();
let mut conn_count = Vec::<usize>::with_capacity(th_num);
let mut wait_res =
Vec::<::tokio::sync::oneshot::Receiver<usize>>::with_capacity(
th_num,
);
for th in self._thread_work.iter() {
let (send, recv) = ::tokio::sync::oneshot::channel();
wait_res.push(recv);
let _ = th.send(Work::CountConnections(send)).await;
}
for ch in wait_res.into_iter() {
if let Ok(conn_num) = ch.await {
conn_count.push(conn_num);
}
}
if conn_count.len() != th_num {
return Err(Error::IO(::std::io::Error::new(
::std::io::ErrorKind::NotConnected,
"can't connect to a thread",
)));
}
// pick the least-loaded thread
let thread_idx = conn_count
.iter()
.enumerate()
.min_by(|(_, a), (_, b)| a.cmp(b))
.map(|(index, _)| index)
.unwrap();
// and tell that thread to connect somewhere
let (send, recv) = ::tokio::sync::oneshot::channel();
let _ = self._thread_work[thread_idx]
.send(Work::Connect(ConnectInfo {
answer: send,
resolved: resolved.clone(),
service_id: service,
domain: domain.clone(),
}))
.await;
match recv.await {
Ok(res) => {
match res {
Err(e) => {
let mut conn_auth_lock =
self.conn_auth_srv.lock().await;
conn_auth_lock.remove_reserved(&resolved);
Err(e)
}
Ok((pubkey, id_send)) => {
let mut conn_auth_lock =
self.conn_auth_srv.lock().await;
conn_auth_lock.add(&pubkey, id_send, &resolved);
//FIXME: user needs to somehow track the connection
Ok(())
}
}
}
Err(e) => {
// The worker dropped the sender. Has the thread died?
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.remove_reserved(&resolved);
Err(Error::IO(::std::io::Error::new(
::std::io::ErrorKind::Interrupted,
"recv failure on connect: ".to_owned() + &e.to_string(),
)))
}
}
}
// TODO: start work on a LocalSet provided by the user
/// Start one worker thread for each physical CPU core;
/// each thread is pinned to its core.
/// Work is divided and rerouted between them so that there is no need to lock
async fn start_work_threads_pinned(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> ::std::result::Result<(), Error> {
use ::std::sync::Mutex;
let hw_topology = match ::hwloc2::Topology::new() {
Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
None => {
return Err(Error::Setup(
"Can't get hardware topology".to_owned(),
))
}
};
let cores;
{
let topology_lock = hw_topology.lock().unwrap();
let all_cores = match topology_lock
.objects_with_type(&::hwloc2::ObjectType::Core)
{
Ok(all_cores) => all_cores,
Err(_) => {
return Err(Error::Setup("can't list cores".to_owned()))
}
};
cores = all_cores.len();
if cores == 0 || !topology_lock.support().cpu().set_thread() {
::tracing::error!("No support for CPU pinning");
return Err(Error::Setup("No cpu pinning support".to_owned()));
}
}
for core in 0..cores {
::tracing::debug!("Spawning thread {}", core);
let th_topology = hw_topology.clone();
let th_tokio_rt = tokio_rt.clone();
let th_config = self.cfg.clone();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let th_stop_working = self.stop_working.subscribe();
let th_token_check = self.token_check.clone();
let th_socket_addrs = self.cfg.listen.clone();
let thread_id = ThreadTracker {
total: cores as u16,
id: 1 + (core as u16),
};
let join_handle = ::std::thread::spawn(move || {
// bind to a specific core
let th_pinning;
{
let mut th_topology_lock = th_topology.lock().unwrap();
let th_cores = th_topology_lock
.objects_with_type(&::hwloc2::ObjectType::Core)
.unwrap();
let cpuset = th_cores.get(core).unwrap().cpuset().unwrap();
th_pinning = th_topology_lock.set_cpubind(
cpuset,
::hwloc2::CpuBindFlags::CPUBIND_THREAD,
);
}
match th_pinning {
Ok(_) => {}
Err(_) => {
::tracing::error!("Can't bind thread to cpu");
return;
}
}
// finally run the main worker.
// make sure things stay on this thread
let tk_local = ::tokio::task::LocalSet::new();
let _ = tk_local.block_on(
&th_tokio_rt,
Worker::new_and_loop(
th_config,
thread_id,
th_stop_working,
th_token_check,
th_socket_addrs,
work_recv,
),
);
});
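// `Arc::get_mut` only succeeds while no other clone of the
// work-queue list is alive; retry until we have exclusive
// access and can register the new worker's queue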
loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
Some(queues_lock) => queues_lock,
None => {
::tokio::time::sleep(
::std::time::Duration::from_millis(50),
)
.await;
continue;
}
};
queues_lock.push(work_send);
break;
}
self._thread_pool.push(join_handle);
}
Ok(())
}
}