Refactor, more pinned-thread work

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-05-24 15:45:37 +02:00
parent c0d6cf1824
commit 9b33ed8828
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
5 changed files with 385 additions and 293 deletions

View File

@ -1,4 +1,4 @@
//! Authentication reslated struct definitions //! Authentication related struct definitions
use ::ring::rand::SecureRandom; use ::ring::rand::SecureRandom;
use ::zeroize::Zeroize; use ::zeroize::Zeroize;
@ -53,6 +53,16 @@ impl ::core::fmt::Debug for Token {
} }
} }
/// Type of the function used to check the validity of the tokens
/// Reimplement this to use whatever database you want
pub type TokenChecker =
fn(
user: UserID,
token: Token,
service_id: ServiceID,
domain: Domain,
) -> ::futures::future::BoxFuture<'static, Result<bool, ()>>;
/// domain representation /// domain representation
/// Security notice: internal representation is utf8, but we will /// Security notice: internal representation is utf8, but we will
/// further limit to a "safe" subset of utf8 /// further limit to a "safe" subset of utf8

View File

@ -8,17 +8,8 @@ use ::std::{
}; };
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
// PERF: move to arcswap: /// Pair to easily track the socket and its async listening handle
// We use multiple UDP sockets. pub type SocketTracker = (Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
// We need a list of them all and to scan this list.
// But we want to be able to update this list without locking everyone
// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and
// `from_raw_fd` to update the list.
// This means that we will have multiple `UdpSocket` per actual socket
// so we have to handle `drop()` manually, and garbage-collect the ones we
// are no longer using in the background. sigh.
// Just go with a ArcSwapAny<Arc<Vec<Arc< ...
type SocketTracker = (Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
/// async free socket list /// async free socket list
pub(crate) struct SocketList { pub(crate) struct SocketList {

View File

@ -2,7 +2,10 @@
//! This is meant to be **async-free** so that others might use it //! This is meant to be **async-free** so that others might use it
//! without the tokio runtime //! without the tokio runtime
pub(crate) mod worker;
use crate::{ use crate::{
auth,
connection::{ connection::{
self, self,
handshake::{self, Handshake, HandshakeClient, HandshakeServer}, handshake::{self, Handshake, HandshakeClient, HandshakeServer},
@ -49,7 +52,7 @@ pub enum HandshakeAction {
} }
/// Async free but thread safe tracking of handhsakes and conenctions /// Async free but thread safe tracking of handhsakes and conenctions
pub struct Tracker { pub struct HandshakeTracker {
key_exchanges: ArcSwapAny<Arc<Vec<(asym::Key, asym::KeyExchange)>>>, key_exchanges: ArcSwapAny<Arc<Vec<(asym::Key, asym::KeyExchange)>>>,
ciphers: ArcSwapAny<Arc<Vec<CipherKind>>>, ciphers: ArcSwapAny<Arc<Vec<CipherKind>>>,
/// ephemeral keys used server side in key exchange /// ephemeral keys used server side in key exchange
@ -58,11 +61,11 @@ pub struct Tracker {
hshake_cli: ArcSwapAny<Arc<Vec<HandshakeClient>>>, hshake_cli: ArcSwapAny<Arc<Vec<HandshakeClient>>>,
} }
#[allow(unsafe_code)] #[allow(unsafe_code)]
unsafe impl Send for Tracker {} unsafe impl Send for HandshakeTracker {}
#[allow(unsafe_code)] #[allow(unsafe_code)]
unsafe impl Sync for Tracker {} unsafe impl Sync for HandshakeTracker {}
impl Tracker { impl HandshakeTracker {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
ciphers: ArcSwapAny::new(Arc::new(Vec::new())), ciphers: ArcSwapAny::new(Arc::new(Vec::new())),

307
src/inner/worker.rs Normal file
View File

@ -0,0 +1,307 @@
//! Worker thread implementation
use crate::{
auth::TokenChecker,
connection::{
self,
handshake::{
self,
dirsync::{self, DirSync},
Handshake, HandshakeClient, HandshakeData,
},
socket::{UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet, ID,
},
enc::sym::Secret,
inner::{HandshakeAction, HandshakeTracker},
};
use ::std::{sync::Arc, vec::Vec};
/// This worker must be cpu-pinned
use ::tokio::{net::UdpSocket, sync::Mutex};
use std::net::SocketAddr;
/// Track a raw Udp packet
pub(crate) struct RawUdp {
pub src: UdpClient,
pub dst: UdpServer,
pub data: Vec<u8>,
pub packet: Packet,
}
pub(crate) enum Work {
Recv(RawUdp),
}
/// Actual worker implementation.
pub(crate) struct Worker {
// PERF: rand uses syscalls. how to do that async?
rand: ::ring::rand::SystemRandom,
stop_working: ::tokio::sync::broadcast::Receiver<bool>,
token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<UdpSocket>,
queue: ::async_channel::Receiver<Work>,
thread_channels: Vec<::async_channel::Sender<Work>>,
connections: ConnList,
handshakes: HandshakeTracker,
}
impl Worker {
pub(crate) async fn new(
stop_working: ::tokio::sync::broadcast::Receiver<bool>,
token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>,
queue: ::async_channel::Receiver<Work>,
) -> ::std::io::Result<Self> {
// bind all sockets again so that we can easily
// send without sharing resources
// in the future we will want to have a thread-local listener too,
// but before that we need ebpf to pin a connection to a thread
// directly from the kernel
let socket_binding =
socket_addrs.into_iter().map(|s_addr| async move {
let socket = ::tokio::spawn(connection::socket::bind_udp(
s_addr.clone(),
))
.await??;
Ok(socket)
});
let sockets_bind_res =
::futures::future::join_all(socket_binding).await;
let sockets: Result<Vec<UdpSocket>, ::std::io::Error> =
sockets_bind_res
.into_iter()
.map(|s_res| match s_res {
Ok(s) => Ok(s),
Err(e) => {
::tracing::error!("Worker can't bind on socket: {}", e);
Err(e)
}
})
.collect();
let sockets = match sockets {
Ok(sockets) => sockets,
Err(e) => {
return Err(e);
}
};
Ok(Self {
rand: ::ring::rand::SystemRandom::new(),
stop_working,
token_check,
sockets,
queue,
thread_channels: Vec::new(),
connections: ConnList::new(),
handshakes: HandshakeTracker::new(),
})
}
pub(crate) async fn work_loop(&mut self) {
loop {
let work = ::tokio::select! {
_done = self.stop_working.recv() => {
break;
}
maybe_work = self.queue.recv() => {
match maybe_work {
Ok(work) => work,
Err(_) => break,
}
}
};
match work {
//TODO: reconf message to add channels
Work::Recv(pkt) => {
self.recv(pkt).await;
}
}
}
}
/// Read and do stuff with the raw udp packet
async fn recv(&mut self, mut udp: RawUdp) {
if udp.packet.id.is_handshake() {
let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake,
Err(e) => {
::tracing::warn!("Handshake parsing: {}", e);
return;
}
};
let action = match self
.handshakes
.recv_handshake(handshake, &mut udp.data[8..])
{
Ok(action) => action,
Err(err) => {
::tracing::debug!("Handshake recv error {}", err);
return;
}
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let token_check = match self.token_check.as_ref() {
Some(token_check) => token_check,
None => {
::tracing::error!(
"Authentication requested but \
we have no token checker"
);
return;
}
};
let req;
if let HandshakeData::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
{
req = r;
} else {
::tracing::error!("AuthInfo on non DS::Req");
return;
}
use dirsync::ReqInner;
let req_data = match req.data {
ReqInner::ClearText(req_data) => req_data,
_ => {
::tracing::error!(
"token_check: expected ClearText"
);
return;
}
};
let is_authenticated = {
let tk_check = token_check.lock().await;
tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
};
let is_authenticated = match is_authenticated {
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id = ID::new_rand(&self.rand);
let srv_secret = Secret::new_rand(&self.rand);
let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len();
let mut raw_conn = Connection::new(
authinfo.hkdf,
req.cipher,
connection::Role::Server,
&self.rand,
);
raw_conn.id_send = IDSend(req_data.id);
// track connection
let auth_conn = self.connections.reserve_first(raw_conn);
let resp_data = dirsync::RespData {
client_nonce: req_data.nonce,
id: auth_conn.id_recv.0,
service_id: srv_conn_id,
service_key: srv_secret,
};
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
use dirsync::RespInner;
let resp = dirsync::Resp {
client_key_id: req_data.client_key_id,
data: RespInner::ClearText(resp_data),
};
let offset_to_encrypt = resp.encrypted_offset();
let encrypt_until =
offset_to_encrypt + resp.encrypted_length() + tag_len.0;
let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)),
);
use connection::{PacketData, ID};
let packet = Packet {
id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake),
};
let mut raw_out = Vec::<u8>::with_capacity(packet.len());
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn.cipher_send.encrypt(
aad,
&mut raw_out[offset_to_encrypt..encrypt_until],
) {
::tracing::error!("can't encrypt: {:?}", e);
return;
}
self.send_packet(raw_out, udp.src, udp.dst).await;
return;
}
HandshakeAction::ClientConnect(mut cci) => {
let ds_resp;
if let HandshakeData::DirSync(DirSync::Resp(resp)) =
cci.handshake.data
{
ds_resp = resp;
} else {
::tracing::error!("ClientConnect on non DS::Resp");
return;
}
// track connection
use handshake::dirsync;
let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{
resp_data = r_data;
} else {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
return;
}
// FIXME: conn tracking and arc counting
let conn = Arc::get_mut(&mut cci.connection).unwrap();
conn.id_send = IDSend(resp_data.id);
todo!();
}
_ => {}
};
}
// copy packet, spawn
todo!();
}
async fn send_packet(
&self,
data: Vec<u8>,
client: UdpClient,
server: UdpServer,
) {
let src_sock = match self
.sockets
.iter()
.find(|&s| s.local_addr().unwrap() == server.0)
{
Some(src_sock) => src_sock,
None => {
::tracing::error!(
"Can't send packet: Server changed listening ip!"
);
return;
}
};
src_sock.send_to(&data, client.0);
}
}

View File

@ -20,32 +20,21 @@ pub mod dnssec;
pub mod enc; pub mod enc;
mod inner; mod inner;
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
use ::std::{ use ::std::{
net::SocketAddr, net::SocketAddr,
pin::Pin,
sync::{Arc, Weak}, sync::{Arc, Weak},
vec::{self, Vec}, vec::Vec,
};
use ::tokio::{
macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle,
}; };
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use crate::{ use crate::{
auth::TokenChecker,
connection::{ connection::{
handshake::{ handshake,
self, dirsync::DirSync, Handshake, HandshakeClient, HandshakeData,
HandshakeServer,
},
socket::{SocketList, UdpClient, UdpServer}, socket::{SocketList, UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet, Packet,
}, },
enc::{ inner::worker::{RawUdp, Work, Worker},
asym,
hkdf::HkdfSha3,
sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen},
},
inner::HandshakeAction,
}; };
pub use config::Config; pub use config::Config;
@ -55,6 +44,9 @@ pub enum Error {
/// The library was not initialized (run .start()) /// The library was not initialized (run .start())
#[error("not initialized")] #[error("not initialized")]
NotInitialized, NotInitialized,
/// Error in setting up worker threads
#[error("Setup err: {0}")]
Setup(String),
/// General I/O error /// General I/O error
#[error("IO: {0:?}")] #[error("IO: {0:?}")]
IO(#[from] ::std::io::Error), IO(#[from] ::std::io::Error),
@ -69,26 +61,6 @@ pub enum Error {
Key(#[from] crate::enc::Error), Key(#[from] crate::enc::Error),
} }
type TokenChecker =
fn(
user: auth::UserID,
token: auth::Token,
service_id: auth::ServiceID,
domain: auth::Domain,
) -> ::futures::future::BoxFuture<'static, Result<bool, ()>>;
/// Track a raw Udp packet
struct RawUdp {
src: UdpClient,
dst: UdpServer,
data: Vec<u8>,
packet: Packet,
}
enum Work {
Recv(RawUdp),
}
/// Instance of a fenrir endpoint /// Instance of a fenrir endpoint
#[allow(missing_copy_implementations, missing_debug_implementations)] #[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir { pub struct Fenrir {
@ -101,14 +73,11 @@ pub struct Fenrir {
/// Broadcast channel to tell workers to stop working /// Broadcast channel to tell workers to stop working
stop_working: ::tokio::sync::broadcast::Sender<bool>, stop_working: ::tokio::sync::broadcast::Sender<bool>,
/// Private keys used in the handshake /// Private keys used in the handshake
_inner: Arc<inner::Tracker>, _inner: Arc<inner::HandshakeTracker>,
/// where to ask for token check /// where to ask for token check
token_check: Arc<ArcSwapOption<TokenChecker>>, token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>,
// PERF: rand uses syscalls. should we do that async? // PERF: rand uses syscalls. should we do that async?
rand: ::ring::rand::SystemRandom, rand: ::ring::rand::SystemRandom,
/// list of Established connections
connections: Arc<RwLock<ConnList>>,
_myself: Weak<Self>,
// TODO: find a way to both increase and decrease these two in a thread-safe // TODO: find a way to both increase and decrease these two in a thread-safe
// manner // manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>, _thread_pool: Vec<::std::thread::JoinHandle<()>>,
@ -125,28 +94,30 @@ impl Drop for Fenrir {
impl Fenrir { impl Fenrir {
/// Create a new Fenrir endpoint /// Create a new Fenrir endpoint
pub fn new(config: &Config) -> Result<Arc<Self>, Error> { pub fn new(config: &Config) -> Result<Self, Error> {
let listen_num = config.listen.len(); let listen_num = config.listen.len();
let (sender, _) = ::tokio::sync::broadcast::channel(1); let (sender, _) = ::tokio::sync::broadcast::channel(1);
let (work_send, work_recv) = ::async_channel::unbounded::<Work>(); let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let endpoint = Arc::new_cyclic(|myself| Fenrir { let endpoint = Fenrir {
cfg: config.clone(), cfg: config.clone(),
sockets: SocketList::new(), sockets: SocketList::new(),
dnssec: None, dnssec: None,
stop_working: sender, stop_working: sender,
_inner: Arc::new(inner::Tracker::new()), _inner: Arc::new(inner::HandshakeTracker::new()),
token_check: Arc::new(ArcSwapOption::from(None)), token_check: None,
rand: ::ring::rand::SystemRandom::new(), rand: ::ring::rand::SystemRandom::new(),
connections: Arc::new(RwLock::new(ConnList::new())),
_myself: myself.clone(),
_thread_pool: Vec::new(), _thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()), _thread_work: Arc::new(Vec::new()),
}); };
Ok(endpoint) Ok(endpoint)
} }
/// Start all workers, listeners /// Start all workers, listeners
pub async fn start(&mut self) -> Result<(), Error> { pub async fn start(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<(), Error> {
self.start_work_threads_pinned(tokio_rt).await?;
if let Err(e) = self.add_sockets().await { if let Err(e) = self.add_sockets().await {
self.stop().await; self.stop().await;
return Err(e.into()); return Err(e.into());
@ -159,17 +130,25 @@ impl Fenrir {
/// asyncronous version for Drop /// asyncronous version for Drop
fn stop_sync(&mut self) { fn stop_sync(&mut self) {
let _ = self.stop_working.send(true); let _ = self.stop_working.send(true);
// FIXME: wait for thread pool to actually stop
let mut toempty_sockets = self.sockets.rm_all(); let mut toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let task = ::tokio::task::spawn(toempty_sockets.stop_all());
let _ = ::futures::executor::block_on(task); let _ = ::futures::executor::block_on(task);
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None; self.dnssec = None;
} }
/// Stop all workers, listeners /// Stop all workers, listeners
pub async fn stop(&mut self) { pub async fn stop(&mut self) {
let _ = self.stop_working.send(true); let _ = self.stop_working.send(true);
// FIXME: wait for thread pool to actually stop
let mut toempty_sockets = self.sockets.rm_all(); let mut toempty_sockets = self.sockets.rm_all();
toempty_sockets.stop_all().await; toempty_sockets.stop_all().await;
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None; self.dnssec = None;
} }
/// Add all UDP sockets found in config /// Add all UDP sockets found in config
@ -271,14 +250,18 @@ impl Fenrir {
/// Start one working thread for each physical cpu /// Start one working thread for each physical cpu
/// threads are pinned to each cpu core. /// threads are pinned to each cpu core.
/// Work will be divided and rerouted so that there is no need to lock /// Work will be divided and rerouted so that there is no need to lock
pub async fn start_work_threads_pinned( async fn start_work_threads_pinned(
&mut self, &mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>, tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> ::std::result::Result<(), ()> { ) -> ::std::result::Result<(), Error> {
use ::std::sync::Mutex; use ::std::sync::Mutex;
let hw_topology = match ::hwloc2::Topology::new() { let hw_topology = match ::hwloc2::Topology::new() {
Some(hw_topology) => Arc::new(Mutex::new(hw_topology)), Some(hw_topology) => Arc::new(Mutex::new(hw_topology)),
None => return Err(()), None => {
return Err(Error::Setup(
"Can't get hardware topology".to_owned(),
))
}
}; };
let cores; let cores;
{ {
@ -287,20 +270,36 @@ impl Fenrir {
.objects_with_type(&::hwloc2::ObjectType::Core) .objects_with_type(&::hwloc2::ObjectType::Core)
{ {
Ok(all_cores) => all_cores, Ok(all_cores) => all_cores,
Err(_) => return Err(()), Err(_) => {
return Err(Error::Setup("can't list cores".to_owned()))
}
}; };
cores = all_cores.len(); cores = all_cores.len();
if cores <= 0 || !topology_lock.support().cpu().set_thread() { if cores <= 0 || !topology_lock.support().cpu().set_thread() {
::tracing::error!("No support for CPU pinning"); ::tracing::error!("No support for CPU pinning");
return Err(()); return Err(Error::Setup("No cpu pinning support".to_owned()));
} }
} }
for core in 0..cores { for core in 0..cores {
::tracing::debug!("Spawning thread {}", core); ::tracing::debug!("Spawning thread {}", core);
let th_topology = hw_topology.clone(); let th_topology = hw_topology.clone();
let th_tokio_rt = tokio_rt.clone(); let th_tokio_rt = tokio_rt.clone();
let th_myself = self._myself.upgrade().unwrap();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>(); let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let mut worker = match Worker::new(
self.stop_working.subscribe(),
self.token_check.clone(),
self.cfg.listen.clone(),
work_recv,
)
.await
{
Ok(worker) => worker,
Err(e) => {
::tracing::error!("can't start worker");
return Err(Error::IO(e));
}
};
let join_handle = ::std::thread::spawn(move || { let join_handle = ::std::thread::spawn(move || {
// bind to a specific core // bind to a specific core
let th_pinning; let th_pinning;
@ -322,13 +321,10 @@ impl Fenrir {
return; return;
} }
} }
// finally run the main listener. make sure things stay on this // finally run the main worker.
// thread // make sure things stay on this thread
let tk_local = ::tokio::task::LocalSet::new(); let tk_local = ::tokio::task::LocalSet::new();
let _ = tk_local.block_on( let _ = tk_local.block_on(&th_tokio_rt, worker.work_loop());
&th_tokio_rt,
Self::work_loop_thread(th_myself, work_recv),
);
}); });
loop { loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) { let queues_lock = match Arc::get_mut(&mut self._thread_work) {
@ -348,219 +344,4 @@ impl Fenrir {
} }
Ok(()) Ok(())
} }
async fn work_loop_thread(
self: Arc<Self>,
work_recv: ::async_channel::Receiver<Work>,
) {
let mut stop_working = self.stop_working.subscribe();
loop {
let work = ::tokio::select! {
_done = stop_working.recv() => {
break;
}
maybe_work = work_recv.recv() => {
match maybe_work {
Ok(work) => work,
Err(_) => break,
}
}
};
match work {
Work::Recv(pkt) => {
self.recv(pkt).await;
}
}
}
}
/// Read and do stuff with the raw udp packet
async fn recv(&self, mut udp: RawUdp) {
if udp.packet.id.is_handshake() {
let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake,
Err(e) => {
::tracing::warn!("Handshake parsing: {}", e);
return;
}
};
let action =
match self._inner.recv_handshake(handshake, &mut udp.data[8..])
{
Ok(action) => action,
Err(err) => {
::tracing::debug!("Handshake recv error {}", err);
return;
}
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let tk_check = match self.token_check.load_full() {
Some(tk_check) => tk_check,
None => {
::tracing::error!(
"Handshake received, but no tocken_checker"
);
return;
}
};
use handshake::{
dirsync::{self, DirSync},
HandshakeData,
};
let req;
if let HandshakeData::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
{
req = r;
} else {
::tracing::error!("AuthInfo on non DS::Req");
return;
}
use dirsync::ReqInner;
let req_data = match req.data {
ReqInner::ClearText(req_data) => req_data,
_ => {
::tracing::error!(
"token_check: expected ClearText"
);
return;
}
};
let is_authenticated = match tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
{
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id = connection::ID::new_rand(&self.rand);
let srv_secret = enc::sym::Secret::new_rand(&self.rand);
let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len();
let mut raw_conn = Connection::new(
authinfo.hkdf,
req.cipher,
connection::Role::Server,
&self.rand,
);
raw_conn.id_send = IDSend(req_data.id);
// track connection
let auth_conn = {
let mut lock = self.connections.write().await;
lock.reserve_first(raw_conn)
};
let resp_data = dirsync::RespData {
client_nonce: req_data.nonce,
id: auth_conn.id_recv.0,
service_id: srv_conn_id,
service_key: srv_secret,
};
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
use dirsync::RespInner;
let resp = dirsync::Resp {
client_key_id: req_data.client_key_id,
data: RespInner::ClearText(resp_data),
};
let offset_to_encrypt = resp.encrypted_offset();
let encrypt_until =
offset_to_encrypt + resp.encrypted_length() + tag_len.0;
let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)),
);
use connection::{PacketData, ID};
let packet = Packet {
id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake),
};
let mut raw_out = Vec::<u8>::with_capacity(packet.len());
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn.cipher_send.encrypt(
aad,
&mut raw_out[offset_to_encrypt..encrypt_until],
) {
::tracing::error!("can't encrypt: {:?}", e);
return;
}
self.send_packet(raw_out, udp.src, udp.dst).await;
return;
}
HandshakeAction::ClientConnect(mut cci) => {
let ds_resp;
if let HandshakeData::DirSync(DirSync::Resp(resp)) =
cci.handshake.data
{
ds_resp = resp;
} else {
::tracing::error!("ClientConnect on non DS::Resp");
return;
}
// track connection
use handshake::dirsync;
let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{
resp_data = r_data;
} else {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
return;
}
// FIXME: conn tracking and arc counting
let conn = Arc::get_mut(&mut cci.connection).unwrap();
conn.id_send = IDSend(resp_data.id);
todo!();
}
_ => {}
};
}
// copy packet, spawn
todo!();
}
async fn send_packet(
&self,
data: Vec<u8>,
client: UdpClient,
server: UdpServer,
) {
let src_sock;
{
let sockets = self.sockets.lock();
src_sock = match sockets.find(server) {
Some(src_sock) => src_sock,
None => {
::tracing::error!(
"Can't send packet: Server changed listening ip!"
);
return;
}
};
}
src_sock.send_to(&data, client.0);
}
} }