libFenrir/src/inner/worker.rs
Luca Fulchir a3430f1813
Initial connections: share auth.server connection
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-05-28 18:23:14 +02:00

383 lines
14 KiB
Rust

//! Worker thread implementation
use crate::{
auth::{ServiceID, TokenChecker},
connection::{
self,
handshake::{
dirsync::{self, DirSync},
Handshake, HandshakeData,
},
socket::{UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet,
},
dnssec,
enc::{asym::PubKey, hkdf::HkdfSha3, sym::Secret},
inner::{HandshakeAction, HandshakeTracker, ThreadTracker},
};
use ::std::{rc::Rc, sync::Arc, vec::Vec};
// NOTE: `Worker` must be cpu-pinned
use ::tokio::{
net::UdpSocket,
sync::{oneshot, Mutex},
};
/// Track a raw Udp packet
///
/// A datagram received by a worker, together with the endpoints it travelled
/// between and the [`Packet`] parsed from it.
pub(crate) struct RawUdp {
    /// Remote endpoint the datagram came from; replies are sent back here
    pub src: UdpClient,
    /// Local endpoint the datagram arrived on; replies are sent from the
    /// worker socket bound to this address
    pub dst: UdpServer,
    /// Raw datagram bytes
    pub data: Vec<u8>,
    /// Packet parsed from `data`. NOTE(review): at receive time this seems
    /// to carry only the packet ID (the handshake is re-parsed from `data`
    /// by the worker) — confirm
    pub packet: Packet,
}
/// Result of a [`Work::Connect`] request, sent back on the request's
/// oneshot channel
pub(crate) enum ConnectionResult {
    /// The connection could not be established
    Failed(crate::Error),
    /// The connection is up. Presumably carries the peer's public key and
    /// the ID to use when sending to it — TODO confirm
    Established((PubKey, IDSend)),
}
/// Requests that a [`Worker`] receives on its work queue
pub(crate) enum Work {
    /// ask the thread to report to the main thread the total number of
    /// connections present
    CountConnections(oneshot::Sender<usize>),
    /// Open a new connection. The tuple is (channel on which to report the
    /// [`ConnectionResult`], presumably the DNSSEC record of the destination,
    /// service to connect to).
    /// NOTE(review): not implemented yet — the worker loop hits `todo!()`
    Connect((oneshot::Sender<ConnectionResult>, dnssec::Record, ServiceID)),
    /// A received UDP packet that this worker must process
    Recv(RawUdp),
}
/// Answers sent back by a worker.
/// NOTE(review): placeholder — only a dummy variant exists and nothing
/// shown here constructs it
pub(crate) enum WorkAnswer {
    UNUSED,
}
/// Actual worker implementation.
///
/// Binds its own copy of the UDP sockets, keeps its own connection list and
/// handshake tracker, and processes [`Work`] items from its queue.
pub(crate) struct Worker {
    /// Thread tracker of this worker; also handed to the connection list and
    /// to the handshake tracker at construction
    thread_id: ThreadTracker,
    // PERF: rand uses syscalls. how to do that async?
    /// Random source for connection IDs, secrets and new connections
    rand: ::ring::rand::SystemRandom,
    /// Any message (or error) received here stops the work loop
    stop_working: ::tokio::sync::broadcast::Receiver<bool>,
    /// Checker for client tokens during authentication handshakes;
    /// `None` means this worker cannot authenticate clients
    token_check: Option<Arc<Mutex<TokenChecker>>>,
    /// Sockets bound by this worker, one per configured address; replies
    /// are sent from the one whose local address matches the destination
    sockets: Vec<UdpSocket>,
    /// Incoming work for this worker
    queue: ::async_channel::Receiver<Work>,
    /// Senders towards other workers' queues. NOTE(review): always empty for
    /// now (a reconfiguration message to fill it is still a TODO)
    thread_channels: Vec<::async_channel::Sender<Work>>,
    /// Connections tracked by this worker
    connections: ConnList,
    /// Handshakes in progress on this worker
    handshakes: HandshakeTracker,
}
impl Worker {
pub(crate) async fn new_and_loop(
thread_id: ThreadTracker,
stop_working: ::tokio::sync::broadcast::Receiver<bool>,
token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>,
queue: ::async_channel::Receiver<Work>,
) -> ::std::io::Result<()> {
// TODO: get a channel to send back information, and send the error
let mut worker = Self::new(
thread_id,
stop_working,
token_check,
socket_addrs,
queue,
)
.await?;
worker.work_loop().await;
Ok(())
}
pub(crate) async fn new(
thread_id: ThreadTracker,
stop_working: ::tokio::sync::broadcast::Receiver<bool>,
token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>,
queue: ::async_channel::Receiver<Work>,
) -> ::std::io::Result<Self> {
// bind all sockets again so that we can easily
// send without sharing resources
// in the future we will want to have a thread-local listener too,
// but before that we need ebpf to pin a connection to a thread
// directly from the kernel
let socket_binding =
socket_addrs.into_iter().map(|s_addr| async move {
let socket = ::tokio::spawn(connection::socket::bind_udp(
s_addr.clone(),
))
.await??;
Ok(socket)
});
let sockets_bind_res =
::futures::future::join_all(socket_binding).await;
let sockets: Result<Vec<UdpSocket>, ::std::io::Error> =
sockets_bind_res
.into_iter()
.map(|s_res| match s_res {
Ok(s) => Ok(s),
Err(e) => {
::tracing::error!("Worker can't bind on socket: {}", e);
Err(e)
}
})
.collect();
let sockets = match sockets {
Ok(sockets) => sockets,
Err(e) => {
return Err(e);
}
};
Ok(Self {
thread_id,
rand: ::ring::rand::SystemRandom::new(),
stop_working,
token_check,
sockets,
queue,
thread_channels: Vec::new(),
connections: ConnList::new(thread_id),
handshakes: HandshakeTracker::new(thread_id),
})
}
pub(crate) async fn work_loop(&mut self) {
loop {
let work = ::tokio::select! {
_done = self.stop_working.recv() => {
break;
}
maybe_work = self.queue.recv() => {
match maybe_work {
Ok(work) => work,
Err(_) => break,
}
}
};
match work {
Work::CountConnections(sender) => {
let conn_num = self.connections.len();
let _ = sender.send(conn_num);
}
Work::Connect((send_res, dnssec_record, service_id)) => {
todo!()
}
//TODO: reconf message to add channels
Work::Recv(pkt) => {
self.recv(pkt).await;
}
}
}
}
/// Read and do stuff with the raw udp packet
async fn recv(&mut self, mut udp: RawUdp) {
if udp.packet.id.is_handshake() {
let handshake = match Handshake::deserialize(&udp.data[8..]) {
Ok(handshake) => handshake,
Err(e) => {
::tracing::warn!("Handshake parsing: {}", e);
return;
}
};
let action = match self
.handshakes
.recv_handshake(handshake, &mut udp.data[8..])
{
Ok(action) => action,
Err(err) => {
::tracing::debug!("Handshake recv error {}", err);
return;
}
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let token_check = match self.token_check.as_ref() {
Some(token_check) => token_check,
None => {
::tracing::error!(
"Authentication requested but \
we have no token checker"
);
return;
}
};
let req;
if let HandshakeData::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
{
req = r;
} else {
::tracing::error!("AuthInfo on non DS::Req");
return;
}
use dirsync::ReqInner;
let req_data = match req.data {
ReqInner::ClearText(req_data) => req_data,
_ => {
::tracing::error!(
"token_check: expected ClearText"
);
return;
}
};
// FIXME: This part can take a while,
// we should just spawn it probably
let is_authenticated = {
let tk_check = token_check.lock().await;
tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
};
let is_authenticated = match is_authenticated {
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id = ID::new_rand(&self.rand);
let srv_secret = Secret::new_rand(&self.rand);
let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len();
let mut raw_conn = Connection::new(
authinfo.hkdf,
req.cipher,
connection::Role::Server,
&self.rand,
);
raw_conn.id_send = IDSend(req_data.id);
// track connection
let auth_conn = self.connections.reserve_first(raw_conn);
let resp_data = dirsync::RespData {
client_nonce: req_data.nonce,
id: auth_conn.id_recv.0,
service_connection_id: srv_conn_id,
service_key: srv_secret,
};
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
use dirsync::RespInner;
let resp = dirsync::Resp {
client_key_id: req_data.client_key_id,
data: RespInner::ClearText(resp_data),
};
let offset_to_encrypt = resp.encrypted_offset();
let encrypt_until =
offset_to_encrypt + resp.encrypted_length() + tag_len.0;
let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)),
);
use connection::{PacketData, ID};
let packet = Packet {
id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake),
};
let mut raw_out = Vec::<u8>::with_capacity(packet.len());
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn.cipher_send.encrypt(
aad,
&mut raw_out[offset_to_encrypt..encrypt_until],
) {
::tracing::error!("can't encrypt: {:?}", e);
return;
}
self.send_packet(raw_out, udp.src, udp.dst).await;
return;
}
HandshakeAction::ClientConnect(mut cci) => {
let ds_resp;
if let HandshakeData::DirSync(DirSync::Resp(resp)) =
cci.handshake.data
{
ds_resp = resp;
} else {
::tracing::error!("ClientConnect on non DS::Resp");
return;
}
// track connection
let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{
resp_data = r_data;
} else {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
return;
}
{
let conn = Rc::get_mut(&mut cci.connection).unwrap();
conn.id_send = IDSend(resp_data.id);
}
// track the connection to the authentication server
if self.connections.track(cci.connection.clone()).is_err() {
self.connections.delete(cci.connection.id_recv);
}
if cci.connection.id_recv.0
== resp_data.service_connection_id
{
// the user asked a single connection
// to the authentication server, without any additional
// service. No more connections to setup
return;
}
// create and track the connection to the service
// SECURITY:
//FIXME: the Secret should be XORed with the client stored
// secret (if any)
let hkdf = HkdfSha3::new(
cci.service_id.as_bytes(),
resp_data.service_key,
);
let mut service_connection = Connection::new(
hkdf,
cci.connection.cipher_recv.kind(),
connection::Role::Client,
&self.rand,
);
service_connection.id_recv = cci.service_connection_id;
service_connection.id_send =
IDSend(resp_data.service_connection_id);
let _ = self.connections.track(service_connection.into());
return;
}
_ => {}
};
}
// copy packet, spawn
todo!();
}
async fn send_packet(
&self,
data: Vec<u8>,
client: UdpClient,
server: UdpServer,
) {
let src_sock = match self
.sockets
.iter()
.find(|&s| s.local_addr().unwrap() == server.0)
{
Some(src_sock) => src_sock,
None => {
::tracing::error!(
"Can't send packet: Server changed listening ip!"
);
return;
}
};
let _ = src_sock.send_to(&data, client.0).await;
}
}