Refactor, more pinned-thread work

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-05-24 15:45:37 +02:00
parent c0d6cf1824
commit 9b33ed8828
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
5 changed files with 385 additions and 293 deletions

View File

@ -1,4 +1,4 @@
//! Authentication reslated struct definitions
//! Authentication related struct definitions
use ::ring::rand::SecureRandom;
use ::zeroize::Zeroize;
@ -53,6 +53,16 @@ impl ::core::fmt::Debug for Token {
}
}
/// Type of the function used to check the validity of the tokens
/// Reimplement this to use whatever database you want
pub type TokenChecker =
fn(
user: UserID,
token: Token,
service_id: ServiceID,
domain: Domain,
) -> ::futures::future::BoxFuture<'static, Result<bool, ()>>;
/// domain representation
/// Security notice: internal representation is utf8, but we will
/// further limit to a "safe" subset of utf8

View File

@ -8,17 +8,8 @@ use ::std::{
};
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
// PERF: move to arcswap:
// We use multiple UDP sockets.
// We need a list of them all and to scan this list.
// But we want to be able to update this list without locking everyone
// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and
// `from_raw_fd` to update the list.
// This means that we will have multiple `UdpSocket` per actual socket
// so we have to handle `drop()` manually, and garbage-collect the ones we
// are no longer using in the background. sigh.
// Just go with a ArcSwapAny<Arc<Vec<Arc< ...
type SocketTracker = (Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
/// Pair to easily track the socket and its async listening handle
pub type SocketTracker = (Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
/// async free socket list
pub(crate) struct SocketList {

View File

@ -2,7 +2,10 @@
//! This is meant to be **async-free** so that others might use it
//! without the tokio runtime
pub(crate) mod worker;
use crate::{
auth,
connection::{
self,
handshake::{self, Handshake, HandshakeClient, HandshakeServer},
@ -49,7 +52,7 @@ pub enum HandshakeAction {
}
/// Async free but thread safe tracking of handhsakes and conenctions
pub struct Tracker {
pub struct HandshakeTracker {
key_exchanges: ArcSwapAny<Arc<Vec<(asym::Key, asym::KeyExchange)>>>,
ciphers: ArcSwapAny<Arc<Vec<CipherKind>>>,
/// ephemeral keys used server side in key exchange
@ -58,11 +61,11 @@ pub struct Tracker {
hshake_cli: ArcSwapAny<Arc<Vec<HandshakeClient>>>,
}
#[allow(unsafe_code)]
unsafe impl Send for Tracker {}
unsafe impl Send for HandshakeTracker {}
#[allow(unsafe_code)]
unsafe impl Sync for Tracker {}
unsafe impl Sync for HandshakeTracker {}
impl Tracker {
impl HandshakeTracker {
pub fn new() -> Self {
Self {
ciphers: ArcSwapAny::new(Arc::new(Vec::new())),

307
src/inner/worker.rs Normal file
View File

@ -0,0 +1,307 @@
//! Worker thread implementation
use crate::{
auth::TokenChecker,
connection::{
self,
handshake::{
self,
dirsync::{self, DirSync},
Handshake, HandshakeClient, HandshakeData,
},
socket::{UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet, ID,
},
enc::sym::Secret,
inner::{HandshakeAction, HandshakeTracker},
};
use ::std::{sync::Arc, vec::Vec};
/// This worker must be cpu-pinned
use ::tokio::{net::UdpSocket, sync::Mutex};
use std::net::SocketAddr;
/// Track a raw Udp packet
pub(crate) struct RawUdp {
pub src: UdpClient,
pub dst: UdpServer,
pub data: Vec<u8>,
pub packet: Packet,
}
pub(crate) enum Work {
Recv(RawUdp),
}
/// Actual worker implementation.
pub(crate) struct Worker {
// PERF: rand uses syscalls. how to do that async?
rand: ::ring::rand::SystemRandom,
stop_working: ::tokio::sync::broadcast::Receiver<bool>,
token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<UdpSocket>,
queue: ::async_channel::Receiver<Work>,
thread_channels: Vec<::async_channel::Sender<Work>>,
connections: ConnList,
handshakes: HandshakeTracker,
}
impl Worker {
pub(crate) async fn new(
stop_working: ::tokio::sync::broadcast::Receiver<bool>,
token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>,
queue: ::async_channel::Receiver<Work>,
) -> ::std::io::Result<Self> {
// bind all sockets again so that we can easily
// send without sharing resources
// in the future we will want to have a thread-local listener too,
// but before that we need ebpf to pin a connection to a thread
// directly from the kernel
let socket_binding =
socket_addrs.into_iter().map(|s_addr| async move {
let socket = ::tokio::spawn(connection::socket::bind_udp(
s_addr.clone(),
))
.await??;
Ok(socket)
});
let sockets_bind_res =
::futures::future::join_all(socket_binding).await;
let sockets: Result<Vec<UdpSocket>, ::std::io::Error> =
sockets_bind_res
.into_iter()
.map(|s_res| match s_res {
Ok(s) => Ok(s),
Err(e) => {
::tracing::error!("Worker can't bind on socket: {}", e);
Err(e)
}
})
.collect();
let sockets = match sockets {
Ok(sockets) => sockets,
Err(e) => {
return Err(e);
}
};
Ok(Self {
rand: ::ring::rand::SystemRandom::new(),
stop_working,
token_check,
sockets,
queue,
thread_channels: Vec::new(),
connections: ConnList::new(),
handshakes: HandshakeTracker::new(),
})
}
pub(crate) async fn work_loop(&mut self) {
loop {
let work = ::tokio::select! {
_done = self.stop_working.recv() => {
break;
}
maybe_work = self.queue.recv() => {
match maybe_work {
Ok(work) => work,
Err(_) => break,
}
}
};
match work {
//TODO: reconf message to add channels
Work::Recv(pkt) => {
self.recv(pkt).await;
}
}
}
}
/// Read and do stuff with the raw udp packet
async fn recv(&mut self, mut udp: RawUdp) {
if udp.packet.id.is_handshake() {
let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake,
Err(e) => {
::tracing::warn!("Handshake parsing: {}", e);
return;
}
};
let action = match self
.handshakes
.recv_handshake(handshake, &mut udp.data[8..])
{
Ok(action) => action,
Err(err) => {
::tracing::debug!("Handshake recv error {}", err);
return;
}
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let token_check = match self.token_check.as_ref() {
Some(token_check) => token_check,
None => {
::tracing::error!(
"Authentication requested but \
we have no token checker"
);
return;
}
};
let req;
if let HandshakeData::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
{
req = r;
} else {
::tracing::error!("AuthInfo on non DS::Req");
return;
}
use dirsync::ReqInner;
let req_data = match req.data {
ReqInner::ClearText(req_data) => req_data,
_ => {
::tracing::error!(
"token_check: expected ClearText"
);
return;
}
};
let is_authenticated = {
let tk_check = token_check.lock().await;
tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
};
let is_authenticated = match is_authenticated {
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id = ID::new_rand(&self.rand);
let srv_secret = Secret::new_rand(&self.rand);
let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len();
let mut raw_conn = Connection::new(
authinfo.hkdf,
req.cipher,
connection::Role::Server,
&self.rand,
);
raw_conn.id_send = IDSend(req_data.id);
// track connection
let auth_conn = self.connections.reserve_first(raw_conn);
let resp_data = dirsync::RespData {
client_nonce: req_data.nonce,
id: auth_conn.id_recv.0,
service_id: srv_conn_id,
service_key: srv_secret,
};
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
use dirsync::RespInner;
let resp = dirsync::Resp {
client_key_id: req_data.client_key_id,
data: RespInner::ClearText(resp_data),
};
let offset_to_encrypt = resp.encrypted_offset();
let encrypt_until =
offset_to_encrypt + resp.encrypted_length() + tag_len.0;
let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)),
);
use connection::{PacketData, ID};
let packet = Packet {
id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake),
};
let mut raw_out = Vec::<u8>::with_capacity(packet.len());
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn.cipher_send.encrypt(
aad,
&mut raw_out[offset_to_encrypt..encrypt_until],
) {
::tracing::error!("can't encrypt: {:?}", e);
return;
}
self.send_packet(raw_out, udp.src, udp.dst).await;
return;
}
HandshakeAction::ClientConnect(mut cci) => {
let ds_resp;
if let HandshakeData::DirSync(DirSync::Resp(resp)) =
cci.handshake.data
{
ds_resp = resp;
} else {
::tracing::error!("ClientConnect on non DS::Resp");
return;
}
// track connection
use handshake::dirsync;
let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{
resp_data = r_data;
} else {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
return;
}
// FIXME: conn tracking and arc counting
let conn = Arc::get_mut(&mut cci.connection).unwrap();
conn.id_send = IDSend(resp_data.id);
todo!();
}
_ => {}
};
}
// copy packet, spawn
todo!();
}
async fn send_packet(
&self,
data: Vec<u8>,
client: UdpClient,
server: UdpServer,
) {
let src_sock = match self
.sockets
.iter()
.find(|&s| s.local_addr().unwrap() == server.0)
{
Some(src_sock) => src_sock,
None => {
::tracing::error!(
"Can't send packet: Server changed listening ip!"
);
return;
}
};
src_sock.send_to(&data, client.0);
}
}

View File

@ -20,32 +20,21 @@ pub mod dnssec;
pub mod enc;
mod inner;
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
use ::std::{
net::SocketAddr,
pin::Pin,
sync::{Arc, Weak},
vec::{self, Vec},
};
use ::tokio::{
macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle,
vec::Vec,
};
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use crate::{
auth::TokenChecker,
connection::{
handshake::{
self, dirsync::DirSync, Handshake, HandshakeClient, HandshakeData,
HandshakeServer,
},
handshake,
socket::{SocketList, UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet,
Packet,
},
enc::{
asym,
hkdf::HkdfSha3,
sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen},
},
inner::HandshakeAction,
inner::worker::{RawUdp, Work, Worker},
};
pub use config::Config;
@ -55,6 +44,9 @@ pub enum Error {
/// The library was not initialized (run .start())
#[error("not initialized")]
NotInitialized,
/// Error in setting up worker threads
#[error("Setup err: {0}")]
Setup(String),
/// General I/O error
#[error("IO: {0:?}")]
IO(#[from] ::std::io::Error),
@ -69,26 +61,6 @@ pub enum Error {
Key(#[from] crate::enc::Error),
}
type TokenChecker =
fn(
user: auth::UserID,
token: auth::Token,
service_id: auth::ServiceID,
domain: auth::Domain,
) -> ::futures::future::BoxFuture<'static, Result<bool, ()>>;
/// Track a raw Udp packet
struct RawUdp {
src: UdpClient,
dst: UdpServer,
data: Vec<u8>,
packet: Packet,
}
enum Work {
Recv(RawUdp),
}
/// Instance of a fenrir endpoint
#[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir {
@ -101,14 +73,11 @@ pub struct Fenrir {
/// Broadcast channel to tell workers to stop working
stop_working: ::tokio::sync::broadcast::Sender<bool>,
/// Private keys used in the handshake
_inner: Arc<inner::Tracker>,
_inner: Arc<inner::HandshakeTracker>,
/// where to ask for token check
token_check: Arc<ArcSwapOption<TokenChecker>>,
token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>,
// PERF: rand uses syscalls. should we do that async?
rand: ::ring::rand::SystemRandom,
/// list of Established connections
connections: Arc<RwLock<ConnList>>,
_myself: Weak<Self>,
// TODO: find a way to both increase and decrease these two in a thread-safe
// manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
@ -125,28 +94,30 @@ impl Drop for Fenrir {
impl Fenrir {
/// Create a new Fenrir endpoint
pub fn new(config: &Config) -> Result<Arc<Self>, Error> {
pub fn new(config: &Config) -> Result<Self, Error> {
let listen_num = config.listen.len();
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let endpoint = Arc::new_cyclic(|myself| Fenrir {
let endpoint = Fenrir {
cfg: config.clone(),
sockets: SocketList::new(),
dnssec: None,
stop_working: sender,
_inner: Arc::new(inner::Tracker::new()),
token_check: Arc::new(ArcSwapOption::from(None)),
_inner: Arc::new(inner::HandshakeTracker::new()),
token_check: None,
rand: ::ring::rand::SystemRandom::new(),
connections: Arc::new(RwLock::new(ConnList::new())),
_myself: myself.clone(),
_thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()),
});
};
Ok(endpoint)
}
/// Start all workers, listeners
pub async fn start(&mut self) -> Result<(), Error> {
pub async fn start(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<(), Error> {
self.start_work_threads_pinned(tokio_rt).await?;
if let Err(e) = self.add_sockets().await {
self.stop().await;
return Err(e.into());
@ -159,17 +130,25 @@ impl Fenrir {
/// asyncronous version for Drop
fn stop_sync(&mut self) {
let _ = self.stop_working.send(true);
// FIXME: wait for thread pool to actually stop
let mut toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(toempty_sockets.stop_all());
let _ = ::futures::executor::block_on(task);
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None;
}
/// Stop all workers, listeners
pub async fn stop(&mut self) {
let _ = self.stop_working.send(true);
// FIXME: wait for thread pool to actually stop
let mut toempty_sockets = self.sockets.rm_all();
toempty_sockets.stop_all().await;
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None;
}
/// Add all UDP sockets found in config
@ -271,14 +250,18 @@ impl Fenrir {
/// Start one working thread for each physical cpu
/// threads are pinned to each cpu core.
/// Work will be divided and rerouted so that there is no need to lock
pub async fn start_work_threads_pinned(
async fn start_work_threads_pinned(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> ::std::result::Result<(), ()> {
) -> ::std::result::Result<(), Error> {
use ::std::sync::Mutex;
let hw_topology = match ::hwloc2::Topology::new() {
Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
None => return Err(()),
None => {
return Err(Error::Setup(
"Can't get hardware topology".to_owned(),
))
}
};
let cores;
{
@ -287,20 +270,36 @@ impl Fenrir {
.objects_with_type(&::hwloc2::ObjectType::Core)
{
Ok(all_cores) => all_cores,
Err(_) => return Err(()),
Err(_) => {
return Err(Error::Setup("can't list cores".to_owned()))
}
};
cores = all_cores.len();
if cores <= 0 || !topology_lock.support().cpu().set_thread() {
::tracing::error!("No support for CPU pinning");
return Err(());
return Err(Error::Setup("No cpu pinning support".to_owned()));
}
}
for core in 0..cores {
::tracing::debug!("Spawning thread {}", core);
let th_topology = hw_topology.clone();
let th_tokio_rt = tokio_rt.clone();
let th_myself = self._myself.upgrade().unwrap();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let mut worker = match Worker::new(
self.stop_working.subscribe(),
self.token_check.clone(),
self.cfg.listen.clone(),
work_recv,
)
.await
{
Ok(worker) => worker,
Err(e) => {
::tracing::error!("can't start worker");
return Err(Error::IO(e));
}
};
let join_handle = ::std::thread::spawn(move || {
// bind to a specific core
let th_pinning;
@ -322,13 +321,10 @@ impl Fenrir {
return;
}
}
// finally run the main listener. make sure things stay on this
// thread
// finally run the main worker.
// make sure things stay on this thread
let tk_local = ::tokio::task::LocalSet::new();
let _ = tk_local.block_on(
&th_tokio_rt,
Self::work_loop_thread(th_myself, work_recv),
);
let _ = tk_local.block_on(&th_tokio_rt, worker.work_loop());
});
loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
@ -348,219 +344,4 @@ impl Fenrir {
}
Ok(())
}
async fn work_loop_thread(
self: Arc<Self>,
work_recv: ::async_channel::Receiver<Work>,
) {
let mut stop_working = self.stop_working.subscribe();
loop {
let work = ::tokio::select! {
_done = stop_working.recv() => {
break;
}
maybe_work = work_recv.recv() => {
match maybe_work {
Ok(work) => work,
Err(_) => break,
}
}
};
match work {
Work::Recv(pkt) => {
self.recv(pkt).await;
}
}
}
}
/// Read and do stuff with the raw udp packet
async fn recv(&self, mut udp: RawUdp) {
if udp.packet.id.is_handshake() {
let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake,
Err(e) => {
::tracing::warn!("Handshake parsing: {}", e);
return;
}
};
let action =
match self._inner.recv_handshake(handshake, &mut udp.data[8..])
{
Ok(action) => action,
Err(err) => {
::tracing::debug!("Handshake recv error {}", err);
return;
}
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let tk_check = match self.token_check.load_full() {
Some(tk_check) => tk_check,
None => {
::tracing::error!(
"Handshake received, but no tocken_checker"
);
return;
}
};
use handshake::{
dirsync::{self, DirSync},
HandshakeData,
};
let req;
if let HandshakeData::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
{
req = r;
} else {
::tracing::error!("AuthInfo on non DS::Req");
return;
}
use dirsync::ReqInner;
let req_data = match req.data {
ReqInner::ClearText(req_data) => req_data,
_ => {
::tracing::error!(
"token_check: expected ClearText"
);
return;
}
};
let is_authenticated = match tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
{
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id = connection::ID::new_rand(&self.rand);
let srv_secret = enc::sym::Secret::new_rand(&self.rand);
let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len();
let mut raw_conn = Connection::new(
authinfo.hkdf,
req.cipher,
connection::Role::Server,
&self.rand,
);
raw_conn.id_send = IDSend(req_data.id);
// track connection
let auth_conn = {
let mut lock = self.connections.write().await;
lock.reserve_first(raw_conn)
};
let resp_data = dirsync::RespData {
client_nonce: req_data.nonce,
id: auth_conn.id_recv.0,
service_id: srv_conn_id,
service_key: srv_secret,
};
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
use dirsync::RespInner;
let resp = dirsync::Resp {
client_key_id: req_data.client_key_id,
data: RespInner::ClearText(resp_data),
};
let offset_to_encrypt = resp.encrypted_offset();
let encrypt_until =
offset_to_encrypt + resp.encrypted_length() + tag_len.0;
let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)),
);
use connection::{PacketData, ID};
let packet = Packet {
id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake),
};
let mut raw_out = Vec::<u8>::with_capacity(packet.len());
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn.cipher_send.encrypt(
aad,
&mut raw_out[offset_to_encrypt..encrypt_until],
) {
::tracing::error!("can't encrypt: {:?}", e);
return;
}
self.send_packet(raw_out, udp.src, udp.dst).await;
return;
}
HandshakeAction::ClientConnect(mut cci) => {
let ds_resp;
if let HandshakeData::DirSync(DirSync::Resp(resp)) =
cci.handshake.data
{
ds_resp = resp;
} else {
::tracing::error!("ClientConnect on non DS::Resp");
return;
}
// track connection
use handshake::dirsync;
let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{
resp_data = r_data;
} else {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
return;
}
// FIXME: conn tracking and arc counting
let conn = Arc::get_mut(&mut cci.connection).unwrap();
conn.id_send = IDSend(resp_data.id);
todo!();
}
_ => {}
};
}
// copy packet, spawn
todo!();
}
async fn send_packet(
&self,
data: Vec<u8>,
client: UdpClient,
server: UdpServer,
) {
let src_sock;
{
let sockets = self.sockets.lock();
src_sock = match sockets.find(server) {
Some(src_sock) => src_sock,
None => {
::tracing::error!(
"Can't send packet: Server changed listening ip!"
);
return;
}
};
}
src_sock.send_to(&data, client.0);
}
}