libFenrir/src/lib.rs
Luca Fulchir 59394959bd
MPMC queue to distribute work on threads
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-02-24 22:00:56 +01:00

629 lines
22 KiB
Rust

#![deny(
missing_docs,
missing_debug_implementations,
missing_copy_implementations,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unstable_features,
unused_import_braces,
unused_qualifications
)]
//!
//! libFenrir is the official rust library implementing the Fenrir protocol
pub mod auth;
mod config;
pub mod connection;
pub mod dnssec;
pub mod enc;
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
use ::std::{net::SocketAddr, pin::Pin, sync::Arc};
use ::tokio::macros::support::Future;
use ::tokio::{net::UdpSocket, task::JoinHandle};
use crate::enc::{
asym,
hkdf::HkdfSha3,
sym::{CipherKind, CipherRecv, CipherSend},
};
pub use config::Config;
use connection::handshake::{self, Handshake, HandshakeKey};
/// Main fenrir library errors
#[derive(::thiserror::Error, Debug)]
pub enum Error {
/// The library was not initialized (run .start())
#[error("not initialized")]
NotInitialized,
/// General I/O error
#[error("IO: {0:?}")]
IO(#[from] ::std::io::Error),
/// Dnssec errors
#[error("Dnssec: {0:?}")]
Dnssec(#[from] dnssec::Error),
/// Handshake errors
#[error("Handshake: {0:?}")]
Handshake(#[from] handshake::Error),
/// Key error
#[error("key: {0:?}")]
Key(#[from] crate::enc::Error),
}
/// Shared handshake state, readable from any worker without locking.
///
/// Each list lives behind an `ArcSwapAny`, so readers take a cheap snapshot
/// with `load()`. Nothing in here is async.
struct FenrirInner {
    /// Accepted (key kind, key exchange) pairs; anything else is rejected
    key_exchanges: ArcSwapAny<Arc<Vec<(asym::Key, asym::KeyExchange)>>>,
    /// Accepted symmetric ciphers; anything else is rejected
    ciphers: ArcSwapAny<Arc<Vec<CipherKind>>>,
    /// Our private keys, looked up by id when a handshake arrives
    keys: ArcSwapAny<Arc<Vec<HandshakeKey>>>,
}
// NOTE(review): these manual impls assert that FenrirInner may be moved to
// and shared between threads.
// SAFETY: FenrirInner only holds `ArcSwapAny<Arc<Vec<_>>>` fields, which are
// read/replaced through atomic load/swap. This is sound only if the element
// types (asym::Key, asym::KeyExchange, CipherKind, HandshakeKey) are
// themselves Send + Sync — TODO confirm. If they are, these impls are
// redundant and could be removed so the compiler derives the auto traits.
#[allow(unsafe_code)]
unsafe impl Send for FenrirInner {}
#[allow(unsafe_code)]
unsafe impl Sync for FenrirInner {}
/// State produced by a successful key exchange, needed to answer the client
/// once it has been authenticated
#[derive(Debug, Clone)]
pub struct AuthNeededInfo {
    /// The handshake, with its request payload already decrypted and parsed
    pub handshake: Handshake,
    /// Key derivation state seeded with the exchanged shared key
    pub hkdf: HkdfSha3,
    /// Symmetric cipher chosen by the client, used in both directions
    pub cipher: CipherKind,
}
/// What the caller must do next after a handshake packet has been processed
#[derive(Debug, Clone)]
pub enum HandshakeAction {
    /// The packet was fully handled; nothing else to do
    None,
    /// The packet was decrypted and parsed; the client must now be
    /// authenticated before replying
    AuthNeeded(AuthNeededInfo),
}
// No async here
impl FenrirInner {
fn recv_handshake(
&self,
mut handshake: Handshake,
) -> Result<HandshakeAction, Error> {
use connection::handshake::{
dirsync::{self, DirSync},
HandshakeData,
};
match handshake.data {
HandshakeData::DirSync(ref mut ds) => match ds {
DirSync::Req(ref mut req) => {
let ephemeral_key = {
// Keep this block short to avoid contention
// on self.keys
let keys = self.keys.load();
if let Some(h_k) =
keys.iter().find(|k| k.id == req.key_id)
{
use enc::asym::PrivKey;
// Directory synchronized can only use keys
// for key exchange, not signing keys
if let PrivKey::Exchange(k) = &h_k.key {
Some(k.clone())
} else {
None
}
} else {
None
}
};
if ephemeral_key.is_none() {
::tracing::debug!("No such key id: {:?}", req.key_id);
return Err(handshake::Error::UnknownKeyID.into());
}
let ephemeral_key = ephemeral_key.unwrap();
{
let exchanges = self.key_exchanges.load();
if None
== exchanges.iter().find(|&x| {
*x == (ephemeral_key.kind(), req.exchange)
})
{
return Err(
enc::Error::UnsupportedKeyExchange.into()
);
}
}
{
let ciphers = self.ciphers.load();
if None == ciphers.iter().find(|&x| *x == req.cipher) {
return Err(enc::Error::UnsupportedCipher.into());
}
}
let shared_key = match ephemeral_key
.key_exchange(req.exchange, req.exchange_key)
{
Ok(shared_key) => shared_key,
Err(e) => return Err(handshake::Error::Key(e).into()),
};
let hkdf = HkdfSha3::new(b"fenrir", shared_key);
let secret_recv = hkdf.get_secret(b"to_server");
let cipher_recv = CipherRecv::new(req.cipher, secret_recv);
use crate::enc::sym::AAD;
let aad = AAD(&mut []); // no aad for now
match cipher_recv.decrypt(aad, &mut req.data.ciphertext()) {
Ok(()) => req.data.mark_as_cleartext(),
Err(e) => {
return Err(handshake::Error::Key(e).into());
}
}
req.set_data(dirsync::ReqData::deserialize(&req.data)?);
let cipher = req.cipher;
return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo {
handshake,
hkdf,
cipher,
}));
}
DirSync::Resp(resp) => {
todo!();
}
},
}
}
}
/// Callback used to verify a client's authentication token.
///
/// It receives the user, the token, the requested service and the domain,
/// and resolves to `Ok(true)` when the client is authenticated, `Ok(false)`
/// when it is not, or `Err(())` when the check itself failed.
type TokenChecker = fn(
    user: auth::UserID,
    token: auth::Token,
    service_id: auth::ServiceID,
    domain: auth::Domain,
) -> ::futures::future::BoxFuture<'static, Result<bool, ()>>;
/// All the UDP sockets we listen on, each paired with its reader task.
///
/// The list is kept behind an `ArcSwap`, so it can be scanned without
/// locking: writers build a new `Vec` and swap it in.
///
/// PERF: sockets are shared as `Arc<UdpSocket>`. An alternative would be to
/// give each user its own `UdpSocket` built with `as_raw_fd()` /
/// `from_raw_fd()` (unsafe). That means several `UdpSocket`s per actual OS
/// socket, so `drop()` would have to be handled manually and unused copies
/// garbage-collected in the background.
struct SocketList {
    sockets:
        ArcSwap<Vec<(Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>)>>,
}
impl SocketList {
    /// Create an empty socket list
    fn new() -> Self {
        Self {
            sockets: ArcSwap::new(Arc::new(Vec::new())),
        }
    }
    // TODO: fn rm_socket()
    /// Atomically replace the list with an empty one and return the previous
    /// contents as a standalone `SocketList` (e.g. for `stop_all`).
    fn rm_all(&self) -> Self {
        let old_list = self.sockets.swap(Arc::new(Vec::new()));
        Self {
            sockets: old_list.into(),
        }
    }
    /// Register a socket together with the task that reads from it.
    ///
    /// Uses `rcu` so that concurrent callers cannot overwrite each other's
    /// additions: the copy-and-swap is retried if the list changed meanwhile.
    async fn add_socket(
        &self,
        socket: Arc<UdpSocket>,
        handle: JoinHandle<::std::io::Result<()>>,
    ) {
        let handle = Arc::new(handle);
        let _ = self.sockets.rcu(|old_list| {
            let mut new_list = Vec::with_capacity(old_list.len() + 1);
            new_list.extend(old_list.iter().cloned());
            new_list.push((Arc::clone(&socket), Arc::clone(&handle)));
            Arc::new(new_list)
        });
    }
    /// Find the socket bound to the local address `sock`.
    ///
    /// Sockets whose local address cannot be read are skipped instead of
    /// panicking.
    async fn find(&self, sock: SocketAddr) -> Option<Arc<UdpSocket>> {
        let list = self.sockets.load();
        list.iter()
            .find(|&(s, _)| s.local_addr().ok() == Some(sock))
            .map(|(s, _)| Arc::clone(s))
    }
    /// Wait for every listener task in this list to finish.
    ///
    /// The listeners are expected to have been told to stop already; this
    /// only joins them. Their own results are currently discarded.
    async fn stop_all(self) {
        let list = Self::wait_unique(self.sockets.into_inner()).await;
        for (_socket, handle) in list {
            // Older snapshots of the list may still briefly hold a clone of
            // this handle: wait for them instead of panicking.
            let handle = Self::wait_unique(handle).await;
            let _ = handle.await;
        }
    }
    /// Poll every 50ms until `arc` is the last strong reference, then
    /// return the owned value.
    async fn wait_unique<T>(mut arc: Arc<T>) -> T {
        loop {
            match Arc::try_unwrap(arc) {
                Ok(inner) => return inner,
                Err(still_shared) => {
                    arc = still_shared;
                    ::tokio::time::sleep(::core::time::Duration::from_millis(
                        50,
                    ))
                    .await;
                }
            }
        }
    }
}
/// A UDP datagram as read from one of our listening sockets
struct RawUdp {
    /// Payload bytes actually received
    data: Vec<u8>,
    /// Address of the peer that sent the datagram
    src: SocketAddr,
    /// Our local address the datagram arrived on
    dst: SocketAddr,
}
/// A unit of work passed through the MPMC work queue
enum Work {
    /// A datagram received by one of the UDP listeners
    Recv(RawUdp),
}
/// Instance of a fenrir endpoint
///
/// Created idle by `Fenrir::new`; `start()` binds the sockets and the
/// resolver, `stop()` (or dropping the endpoint) shuts them down.
#[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir {
    /// Configuration this endpoint was created with
    cfg: Config,
    /// UDP sockets we listen on, each with its reader task
    sockets: SocketList,
    /// DNSSEC resolver (with failovers); `None` until `start()` sets it
    dnssec: Option<dnssec::Dnssec>,
    /// Broadcast channel telling listeners and workers to stop
    stop_working: ::tokio::sync::broadcast::Sender<bool>,
    /// Handshake state: enabled key exchanges, ciphers and private keys
    _inner: Arc<FenrirInner>,
    /// Callback used to check client tokens; with `None`, handshakes that
    /// need authentication are dropped
    token_check: Arc<ArcSwapOption<TokenChecker>>,
    /// Sending half of the MPMC work queue, fed by the UDP listeners
    work_send: Arc<::async_channel::Sender<Work>>,
    /// Receiving half of the MPMC work queue, drained by `work_loop`
    work_recv: Arc<::async_channel::Receiver<Work>>,
    /// Randomness source for connection IDs and secrets
    // PERF: rand uses syscalls. should we do that async?
    rand: ::ring::rand::SystemRandom,
}
// TODO: graceful vs immediate stop
impl Drop for Fenrir {
    /// Shut the endpoint down synchronously when it goes out of scope.
    fn drop(&mut self) {
        self.stop_sync();
    }
}
impl Fenrir {
/// Create a new Fenrir endpoint
pub fn new(config: &Config) -> Result<Self, Error> {
let listen_num = config.listen.len();
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let endpoint = Fenrir {
cfg: config.clone(),
sockets: SocketList::new(),
dnssec: None,
stop_working: sender,
_inner: Arc::new(FenrirInner {
ciphers: ArcSwapAny::new(Arc::new(Vec::new())),
key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())),
keys: ArcSwapAny::new(Arc::new(Vec::new())),
}),
token_check: Arc::new(ArcSwapOption::from(None)),
work_send: Arc::new(work_send),
work_recv: Arc::new(work_recv),
rand: ::ring::rand::SystemRandom::new(),
};
Ok(endpoint)
}
/// Start all workers, listeners
pub async fn start(&mut self) -> Result<(), Error> {
if let Err(e) = self.add_sockets().await {
self.stop().await;
return Err(e.into());
}
self.dnssec = Some(dnssec::Dnssec::new(&self.cfg.resolvers).await?);
Ok(())
}
/// Stop all workers, listeners
/// asyncronous version for Drop
fn stop_sync(&mut self) {
let _ = self.stop_working.send(true);
let mut toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(Self::stop_sockets(toempty_sockets));
//let mut toempty_socket = Vec::new();
//::std::mem::swap(&mut self.sockets, &mut toempty_socket);
//let task = ::tokio::task::spawn(Self::stop_sockets(toempty_socket));
let _ = ::futures::executor::block_on(task);
self.dnssec = None;
}
/// Stop all workers, listeners
pub async fn stop(&mut self) {
let _ = self.stop_working.send(true);
let mut toempty_sockets = self.sockets.rm_all();
Self::stop_sockets(toempty_sockets).await;
self.dnssec = None;
}
/// actually do the work of stopping resolvers and listeners
async fn stop_sockets(sockets: SocketList) {
sockets.stop_all().await;
}
/// Enable some common socket options. This is just the unsafe part
fn enable_sock_opt(
fd: ::std::os::fd::RawFd,
option: ::libc::c_int,
value: ::libc::c_int,
) -> ::std::io::Result<()> {
#[allow(unsafe_code)]
unsafe {
#[allow(trivial_casts)]
let val = &value as *const _ as *const ::libc::c_void;
let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t;
// always clear the error bit before doing a new syscall
let _ = ::std::io::Error::last_os_error();
let ret =
::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size);
if ret != 0 {
return Err(::std::io::Error::last_os_error());
}
}
Ok(())
}
/// Add all UDP sockets found in config
/// and start listening for packets
async fn add_sockets(&mut self) -> ::std::io::Result<()> {
let sockets = self.cfg.listen.iter().map(|s_addr| async {
let socket =
::tokio::spawn(Self::bind_udp(s_addr.clone())).await??;
Ok(socket)
});
let sockets = ::futures::future::join_all(sockets).await;
for s_res in sockets.into_iter() {
match s_res {
Ok(s) => {
let stop_working = self.stop_working.subscribe();
let arc_s = Arc::new(s);
let join = ::tokio::spawn(Self::listen_udp(
stop_working,
self.work_send.clone(),
arc_s.clone(),
));
self.sockets.add_socket(arc_s, join);
}
Err(e) => {
return Err(e);
}
}
}
Ok(())
}
/// Add an async udp listener
async fn bind_udp(sock: SocketAddr) -> ::std::io::Result<UdpSocket> {
let socket = UdpSocket::bind(sock).await?;
use ::std::os::fd::AsRawFd;
let fd = socket.as_raw_fd();
// can be useful later on for reloads
Self::enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?;
Self::enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?;
// We will do path MTU discovery by ourselves,
// always set the "don't fragment" bit
if sock.is_ipv6() {
Self::enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?;
} else {
// FIXME: linux only
Self::enable_sock_opt(
fd,
::libc::IP_MTU_DISCOVER,
::libc::IP_PMTUDISC_DO,
)?;
}
Ok(socket)
}
/// Run a dedicated loop to read packets on the listening socket
async fn listen_udp(
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
work_queue: Arc<::async_channel::Sender<Work>>,
socket: Arc<UdpSocket>,
) -> ::std::io::Result<()> {
// jumbo frames are 9K max
let sock_receiver = socket.local_addr()?;
let mut buffer: [u8; 9000] = [0; 9000];
loop {
let (bytes, sock_sender) = ::tokio::select! {
_done = stop_working.recv() => {
break;
}
result = socket.recv_from(&mut buffer) => {
result?
}
};
let data: Vec<u8> = buffer[..bytes].to_vec();
work_queue.send(Work::Recv(RawUdp {
data,
src: sock_sender,
dst: sock_receiver.clone(),
}));
}
Ok(())
}
/// Get the raw TXT record of a Fenrir domain
pub async fn resolv_str(&self, domain: &str) -> Result<String, Error> {
match &self.dnssec {
Some(dnssec) => Ok(dnssec.resolv(domain).await?),
None => Err(Error::NotInitialized),
}
}
/// Get the raw TXT record of a Fenrir domain
pub async fn resolv(&self, domain: &str) -> Result<dnssec::Record, Error> {
let record_str = self.resolv_str(domain).await?;
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
}
/// Loop continuously and parse packets and other work
pub async fn work_loop(&self) {
let mut stop_working = self.stop_working.subscribe();
loop {
let work = ::tokio::select! {
_done = stop_working.recv() => {
break;
}
maybe_work = self.work_recv.recv() => {
match maybe_work {
Ok(work) => work,
Err(_) => break,
}
}
};
match work {
Work::Recv(pkt) => {
self.recv(pkt).await;
}
}
}
}
const MIN_PACKET_BYTES: usize = 8;
/// Read and do stuff with the raw udp packet
async fn recv(&self, udp: RawUdp) {
if udp.data.len() < Self::MIN_PACKET_BYTES {
return;
}
use connection::ID;
let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable");
if ID::from(raw_id).is_handshake() {
use connection::handshake::Handshake;
let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake,
Err(e) => {
::tracing::warn!("Handshake parsing: {}", e);
return;
}
};
let action = match self._inner.recv_handshake(handshake) {
Ok(action) => action,
Err(err) => {
::tracing::debug!("Handshake recv error {}", err);
return;
}
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let tk_check = match self.token_check.load_full() {
Some(tk_check) => tk_check,
None => {
::tracing::error!(
"Handshake received, but no tocken_checker"
);
return;
}
};
use handshake::{
dirsync::{self, DirSync},
HandshakeData,
};
match authinfo.handshake.data {
HandshakeData::DirSync(ds) => match ds {
DirSync::Req(req) => {
use dirsync::ReqInner;
let req_data = match req.data {
ReqInner::Data(req_data) => req_data,
_ => {
::tracing::error!(
"token_check: expected Data"
);
return;
}
};
let is_authenticated = match tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
{
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!(
"error in token auth"
);
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id =
connection::ID::new_rand(&self.rand);
let auth_conn_id =
connection::ID::new_rand(&self.rand);
let srv_secret =
enc::sym::Secret::new_rand(&self.rand);
let resp_data = dirsync::RespData {
client_nonce: req_data.nonce,
id: auth_conn_id,
service_id: srv_conn_id,
service_key: srv_secret,
};
// build response
let secret_send =
authinfo.hkdf.get_secret(b"to_client");
let mut cipher_send = CipherSend::new(
authinfo.cipher,
secret_send,
);
use crate::enc::sym::AAD;
let aad = AAD(&mut []); // no aad for now
let mut data = cipher_send
.make_data(dirsync::RespData::len());
if let Err(e) =
cipher_send.encrypt(aad, &mut data)
{
::tracing::error!("can't encrypt: {:?}", e);
return;
}
let resp = dirsync::Resp {
client_key_id: req_data.client_key_id,
enc: data.get_raw(),
};
todo!()
}
_ => {
todo!()
}
},
}
}
_ => {}
};
}
// copy packet, spawn
todo!();
}
}