//! Worker thread implementation

use crate::{
    auth::{ServiceID, TokenChecker},
    connection::{
        self,
        handshake::{
            dirsync::{self, DirSync},
            Handshake, HandshakeData,
        },
        socket::{UdpClient, UdpServer},
        ConnList, Connection, IDSend, Packet,
    },
    dnssec,
    enc::{asym::PubKey, hkdf::HkdfSha3, sym::Secret},
    inner::{HandshakeAction, HandshakeTracker, ThreadTracker},
};
use ::std::{rc::Rc, sync::Arc, vec::Vec};
use ::tokio::{
    net::UdpSocket,
    sync::{oneshot, Mutex},
};

/// A raw UDP packet: source and destination addresses,
/// raw data and the parsed packet header
pub(crate) struct RawUdp {
    pub src: UdpClient,
    pub dst: UdpServer,
    pub data: Vec<u8>,
    pub packet: Packet,
}

/// Result of a connection attempt, reported back to the caller
pub(crate) enum ConnectionResult {
    Failed(crate::Error),
    Established((PubKey, IDSend)),
}

/// Commands that other threads can send to a worker
pub(crate) enum Work {
    /// ask the thread to report to the main thread the total number of
    /// connections present
    CountConnections(oneshot::Sender<usize>),
    /// connect to the server described by the DNSSEC record,
    /// for the given service
    Connect((oneshot::Sender<ConnectionResult>, dnssec::Record, ServiceID)),
    /// handle a raw UDP packet received from one of our sockets
    Recv(RawUdp),
}
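
// How another thread can drive a worker through its `Work` queue
// (a sketch: `work_tx` is the `async_channel::Sender<Work>` paired
// with the worker's `queue`, and is owned by the caller):
//
// let (reply_tx, reply_rx) = oneshot::channel();
// work_tx.send(Work::CountConnections(reply_tx)).await?;
// let total_connections: usize = reply_rx.await?;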

/// Answers from the worker back to the main thread. Currently unused.
pub(crate) enum WorkAnswer {
    UNUSED,
}

/// Actual worker implementation.
///
/// This worker must be cpu-pinned: it keeps single-threaded
/// (`Rc`-based) state and is not meant to move between threads.
pub(crate) struct Worker {
    thread_id: ThreadTracker,
    // PERF: rand uses syscalls. how to do that async?
    rand: ::ring::rand::SystemRandom,
    stop_working: ::tokio::sync::broadcast::Receiver<bool>,
    token_check: Option<Arc<Mutex<TokenChecker>>>,
    sockets: Vec<UdpSocket>,
    queue: ::async_channel::Receiver<Work>,
    thread_channels: Vec<::async_channel::Sender<Work>>,
    connections: ConnList,
    handshakes: HandshakeTracker,
}

impl Worker {
    /// Build a new worker and run its event loop until told to stop.
    pub(crate) async fn new_and_loop(
        thread_id: ThreadTracker,
        stop_working: ::tokio::sync::broadcast::Receiver<bool>,
        token_check: Option<Arc<Mutex<TokenChecker>>>,
        socket_addrs: Vec<::std::net::SocketAddr>,
        queue: ::async_channel::Receiver<Work>,
    ) -> ::std::io::Result<()> {
        // TODO: get a channel to send back information, and send the error
        let mut worker = Self::new(
            thread_id,
            stop_working,
            token_check,
            socket_addrs,
            queue,
        )
        .await?;
        worker.work_loop().await;
        Ok(())
    }
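
    // A sketch of how a caller might give each worker its own
    // cpu-pinned thread. `Worker` holds `Rc`s (it is not `Send`),
    // so a single-threaded runtime per worker fits naturally.
    // The pinning itself (e.g. via the `core_affinity` crate) and
    // the channel setup are up to the caller:
    //
    // ::std::thread::spawn(move || {
    //     let rt = ::tokio::runtime::Builder::new_current_thread()
    //         .enable_all()
    //         .build()
    //         .unwrap();
    //     rt.block_on(Worker::new_and_loop(
    //         thread_id, stop_rx, None, socket_addrs, work_rx,
    //     ))
    // });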

    /// Build a new worker, (re)binding all of its UDP sockets.
    pub(crate) async fn new(
        thread_id: ThreadTracker,
        stop_working: ::tokio::sync::broadcast::Receiver<bool>,
        token_check: Option<Arc<Mutex<TokenChecker>>>,
        socket_addrs: Vec<::std::net::SocketAddr>,
        queue: ::async_channel::Receiver<Work>,
    ) -> ::std::io::Result<Self> {
        // Bind all sockets again so that we can send without
        // sharing resources between threads.
        // In the future we will want a thread-local listener too,
        // but before that we need eBPF to pin a connection to a
        // thread directly from the kernel.
        let socket_binding =
            socket_addrs.into_iter().map(|s_addr| async move {
                // first `?`: the task join, second `?`: the bind itself
                let socket =
                    ::tokio::spawn(connection::socket::bind_udp(s_addr))
                        .await??;
                Ok(socket)
            });
        let sockets_bind_res =
            ::futures::future::join_all(socket_binding).await;
        let sockets: Result<Vec<UdpSocket>, ::std::io::Error> =
            sockets_bind_res
                .into_iter()
                .map(|s_res| match s_res {
                    Ok(s) => Ok(s),
                    Err(e) => {
                        ::tracing::error!("Worker can't bind on socket: {}", e);
                        Err(e)
                    }
                })
                .collect();
        // fail if we could not bind any one of the sockets
        let sockets = sockets?;

        Ok(Self {
            thread_id,
            rand: ::ring::rand::SystemRandom::new(),
            stop_working,
            token_check,
            sockets,
            queue,
            thread_channels: Vec::new(),
            connections: ConnList::new(thread_id),
            handshakes: HandshakeTracker::new(thread_id),
        })
    }
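
    // Note: every worker binds the same addresses, so the binder
    // presumably allows address reuse (SO_REUSEPORT or equivalent).
    // A rough sketch of such a binder with the `socket2` crate
    // (illustrative only; the real implementation lives in
    // `connection::socket::bind_udp`):
    //
    // async fn bind_udp(addr: ::std::net::SocketAddr)
    //     -> ::std::io::Result<::tokio::net::UdpSocket> {
    //     let sock = ::socket2::Socket::new(
    //         ::socket2::Domain::for_address(addr),
    //         ::socket2::Type::DGRAM,
    //         Some(::socket2::Protocol::UDP),
    //     )?;
    //     sock.set_reuse_port(true)?;
    //     sock.set_nonblocking(true)?;
    //     sock.bind(&addr.into())?;
    //     ::tokio::net::UdpSocket::from_std(sock.into())
    // }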

    /// Process incoming work from the queue until we are told
    /// to stop or the queue is closed.
    pub(crate) async fn work_loop(&mut self) {
        loop {
            let work = ::tokio::select! {
                _done = self.stop_working.recv() => {
                    break;
                }
                maybe_work = self.queue.recv() => {
                    match maybe_work {
                        Ok(work) => work,
                        // all senders dropped: no more work will come
                        Err(_) => break,
                    }
                }
            };
            match work {
                Work::CountConnections(sender) => {
                    let conn_num = self.connections.len();
                    let _ = sender.send(conn_num);
                }
                Work::Connect((send_res, dnssec_record, service_id)) => {
                    todo!()
                }
                //TODO: reconf message to add channels
                Work::Recv(pkt) => {
                    self.recv(pkt).await;
                }
            }
        }
    }
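
    // Shutdown sketch, from the owning thread's point of view
    // (`stop_tx` is the broadcast sender paired with each worker's
    // `stop_working` receiver):
    //
    // let _ = stop_tx.send(true); // every worker exits work_loop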

    /// Read the raw UDP packet and process it.
    /// Currently only handshake packets are handled.
    async fn recv(&mut self, mut udp: RawUdp) {
        if udp.packet.id.is_handshake() {
            // skip the connection ID (first 8 bytes),
            // then parse the handshake
            let handshake = match Handshake::deserialize(&udp.data[8..]) {
                Ok(handshake) => handshake,
                Err(e) => {
                    ::tracing::warn!("Handshake parsing: {}", e);
                    return;
                }
            };
            let action = match self
                .handshakes
                .recv_handshake(handshake, &mut udp.data[8..])
            {
                Ok(action) => action,
                Err(err) => {
                    ::tracing::debug!("Handshake recv error {}", err);
                    return;
                }
            };
            match action {
                HandshakeAction::AuthNeeded(authinfo) => {
                    // the client's request must be checked against
                    // the user-provided token verifier
                    let token_check = match self.token_check.as_ref() {
                        Some(token_check) => token_check,
                        None => {
                            ::tracing::error!(
                                "Authentication requested but \
                                 we have no token checker"
                            );
                            return;
                        }
                    };
                    use connection::{PacketData, ID};
                    use dirsync::ReqInner;
                    let req = match authinfo.handshake.data {
                        HandshakeData::DirSync(DirSync::Req(r)) => r,
                        _ => {
                            ::tracing::error!("AuthInfo on non DS::Req");
                            return;
                        }
                    };
                    let req_data = match req.data {
                        ReqInner::ClearText(req_data) => req_data,
                        _ => {
                            ::tracing::error!(
                                "token_check: expected ClearText"
                            );
                            return;
                        }
                    };
                    // FIXME: This part can take a while,
                    // we should just spawn it probably
                    let is_authenticated = {
                        let tk_check = token_check.lock().await;
                        tk_check(
                            req_data.auth.user,
                            req_data.auth.token,
                            req_data.auth.service_id,
                            req_data.auth.domain,
                        )
                        .await
                    };
                    let is_authenticated = match is_authenticated {
                        Ok(is_authenticated) => is_authenticated,
                        Err(_) => {
                            ::tracing::error!("error in token auth");
                            // TODO: retry?
                            return;
                        }
                    };
                    if !is_authenticated {
                        ::tracing::warn!(
                            "Wrong authentication for user {:?}",
                            req_data.auth.user
                        );
                        // TODO: error response
                        return;
                    }
                    // Client has correctly authenticated
                    // TODO: contact the service, get the key and
                    // connection ID
                    let srv_conn_id = ID::new_rand(&self.rand);
                    let srv_secret = Secret::new_rand(&self.rand);
                    let head_len = req.cipher.nonce_len();
                    let tag_len = req.cipher.tag_len();

                    let mut raw_conn = Connection::new(
                        authinfo.hkdf,
                        req.cipher,
                        connection::Role::Server,
                        &self.rand,
                    );
                    raw_conn.id_send = IDSend(req_data.id);
                    // track the connection; the reserved id_recv is
                    // the ID the client will use to reach us
                    let auth_conn = self.connections.reserve_first(raw_conn);

                    let resp_data = dirsync::RespData {
                        client_nonce: req_data.nonce,
                        id: auth_conn.id_recv.0,
                        service_connection_id: srv_conn_id,
                        service_key: srv_secret,
                    };
                    use crate::enc::sym::AAD;
                    // no additional authenticated data for now
                    let aad = AAD(&mut []);

                    use dirsync::RespInner;
                    let resp = dirsync::Resp {
                        client_key_id: req_data.client_key_id,
                        data: RespInner::ClearText(resp_data),
                    };
                    let offset_to_encrypt = resp.encrypted_offset();
                    let encrypt_until =
                        offset_to_encrypt + resp.encrypted_length() + tag_len.0;
                    let resp_handshake = Handshake::new(
                        HandshakeData::DirSync(DirSync::Resp(resp)),
                    );
                    let packet = Packet {
                        id: ID::new_handshake(),
                        data: PacketData::Handshake(resp_handshake),
                    };
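                    // Rough shape of the buffer serialized below
                    // (a sketch; the authoritative layout is whatever
                    // `Packet::serialize` writes):
                    //
                    //   [ handshake ID | resp header | resp data | AEAD tag ]
                    //                                 ^-- encrypted in place,
                    //                                     from offset_to_encrypt
                    //                                     to encrypt_until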
                    let mut raw_out = Vec::<u8>::with_capacity(packet.len());
                    packet.serialize(head_len, tag_len, &mut raw_out);

                    if let Err(e) = auth_conn.cipher_send.encrypt(
                        aad,
                        &mut raw_out[offset_to_encrypt..encrypt_until],
                    ) {
                        ::tracing::error!("can't encrypt: {:?}", e);
                        return;
                    }
                    self.send_packet(raw_out, udp.src, udp.dst).await;
                    return;
                }
                HandshakeAction::ClientConnect(mut cci) => {
                    // we are the client: the server answered our
                    // directory-synchronized request
                    let ds_resp = match cci.handshake.data {
                        HandshakeData::DirSync(DirSync::Resp(resp)) => resp,
                        _ => {
                            ::tracing::error!("ClientConnect on non DS::Resp");
                            return;
                        }
                    };
                    let resp_data = match ds_resp.data {
                        dirsync::RespInner::ClearText(r_data) => r_data,
                        _ => {
                            ::tracing::error!(
                                "ClientConnect on non DS::Resp::ClearText"
                            );
                            return;
                        }
                    };
                    {
                        // nothing else holds this Rc yet
                        let conn = Rc::get_mut(&mut cci.connection).unwrap();
                        conn.id_send = IDSend(resp_data.id);
                    }
                    // track the connection to the authentication server
                    if self.connections.track(cci.connection.clone()).is_err() {
                        self.connections.delete(cci.connection.id_recv);
                    }
                    if cci.connection.id_recv.0
                        == resp_data.service_connection_id
                    {
                        // the user asked for a single connection
                        // to the authentication server, without any
                        // additional service. No more connections to set up
                        return;
                    }
                    // create and track the connection to the service
                    // SECURITY:
                    // FIXME: the Secret should be XORed with the client
                    // stored secret (if any)
                    let hkdf = HkdfSha3::new(
                        cci.service_id.as_bytes(),
                        resp_data.service_key,
                    );
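                    // The FIXME above might end up as a simple
                    // byte-wise XOR before building the HKDF
                    // (sketch; assumes `Secret` exposes mutable bytes
                    // and a client-stored `stored_secret` exists):
                    //
                    // for (b, s) in service_key.iter_mut()
                    //     .zip(stored_secret.iter())
                    // {
                    //     *b ^= s;
                    // }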
                    let mut service_connection = Connection::new(
                        hkdf,
                        cci.connection.cipher_recv.kind(),
                        connection::Role::Client,
                        &self.rand,
                    );
                    service_connection.id_recv = cci.service_connection_id;
                    service_connection.id_send =
                        IDSend(resp_data.service_connection_id);
                    let _ = self.connections.track(service_connection.into());
                    return;
                }
                _ => {}
            };
        }
        // not a handshake: copy the packet and spawn its handling.
        // Not implemented yet
        todo!();
    }

    /// Send a serialized packet from the socket the client contacted.
    async fn send_packet(
        &self,
        data: Vec<u8>,
        client: UdpClient,
        server: UdpServer,
    ) {
        // answer from the same local address the client sent to
        let src_sock = match self
            .sockets
            .iter()
            .find(|&s| s.local_addr().unwrap() == server.0)
        {
            Some(src_sock) => src_sock,
            None => {
                ::tracing::error!(
                    "Can't send packet: Server changed listening ip!"
                );
                return;
            }
        };
        let _ = src_sock.send_to(&data, client.0).await;
    }
}