TONS of bugfixing. Add tests. Client now connects

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-06-17 11:33:47 +02:00
parent b682068dca
commit 866edc2d7d
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
15 changed files with 739 additions and 355 deletions

1
TODO Normal file
View File

@ -0,0 +1 @@
* Wrapping for everything that wraps (sigh)

View File

@ -3,6 +3,8 @@
use crate::enc::Random;
use ::zeroize::Zeroize;
/// Anonymous user id
pub const USERID_ANONYMOUS: UserID = UserID([0; UserID::len()]);
/// User identifier. 16 bytes for easy uuid conversion
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct UserID(pub [u8; 16]);
@ -25,8 +27,8 @@ impl UserID {
}
}
/// Anonymous user id
pub fn new_anonymous() -> Self {
UserID([0; 16])
pub const fn new_anonymous() -> Self {
USERID_ANONYMOUS
}
/// length of the User ID in bytes
pub const fn len() -> usize {
@ -98,6 +100,16 @@ impl TryFrom<&[u8]> for Domain {
Ok(Domain(domain_string))
}
}
impl From<String> for Domain {
fn from(raw: String) -> Self {
Self(raw)
}
}
impl From<&str> for Domain {
fn from(raw: &str) -> Self {
Self(raw.to_owned())
}
}
impl Domain {
/// length of the User ID in bytes

View File

@ -16,6 +16,23 @@ use ::std::{
vec,
};
/// Key used by a server during the handshake
#[derive(Clone, Debug)]
pub struct ServerKey {
pub id: KeyID,
pub priv_key: PrivKey,
pub pub_key: PubKey,
}
/// Authentication Server information and keys
#[derive(Clone, Debug)]
pub struct AuthServer {
/// fqdn of the authentication server
pub fqdn: crate::auth::Domain,
/// list of key ids enabled for this domain
pub keys: Vec<KeyID>,
}
/// Main config for libFenrir
#[derive(Clone, Debug)]
pub struct Config {
@ -34,8 +51,12 @@ pub struct Config {
pub hkdfs: Vec<HkdfKind>,
/// Supported Ciphers
pub ciphers: Vec<CipherKind>,
/// list of authentication servers
/// clients will have this empty
pub servers: Vec<AuthServer>,
/// list of public/private keys
pub keys: Vec<(KeyID, PrivKey, PubKey)>,
/// clients should have this empty
pub server_keys: Vec<ServerKey>,
}
impl Default for Config {
@ -56,7 +77,8 @@ impl Default for Config {
key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(),
hkdfs: [HkdfKind::Sha3].to_vec(),
ciphers: [CipherKind::XChaCha20Poly1305].to_vec(),
keys: Vec::new(),
servers: Vec::new(),
server_keys: Vec::new(),
}
}
}

View File

@ -113,10 +113,14 @@ impl Req {
+ self.exchange_key.kind().pub_len()
}
/// return the total length of the cleartext data
pub fn encrypted_length(&self) -> usize {
pub fn encrypted_length(
&self,
head_len: HeadLen,
tag_len: TagLen,
) -> usize {
match &self.data {
ReqInner::ClearText(data) => data.len(),
_ => 0,
ReqInner::ClearText(data) => data.len() + head_len.0 + tag_len.0,
ReqInner::CipherText(length) => *length,
}
}
/// actual length of the directory synchronized request
@ -177,11 +181,16 @@ impl super::HandshakeParsing for Req {
Some(cipher) => cipher,
None => return Err(Error::Parsing),
};
let (exchange_key, len) = match ExchangePubKey::deserialize(&raw[5..]) {
Ok(exchange_key) => exchange_key,
Err(e) => return Err(e.into()),
};
let data = ReqInner::CipherText(raw.len() - (5 + len));
const CURR_SIZE: usize = KeyID::len()
+ KeyExchangeKind::len()
+ HkdfKind::len()
+ CipherKind::len();
let (exchange_key, len) =
match ExchangePubKey::deserialize(&raw[CURR_SIZE..]) {
Ok(exchange_key) => exchange_key,
Err(e) => return Err(e.into()),
};
let data = ReqInner::CipherText(raw.len() - (CURR_SIZE + len));
Ok(HandshakeData::DirSync(DirSync::Req(Self {
key_id,
exchange,
@ -436,7 +445,7 @@ impl super::HandshakeParsing for Resp {
return Err(Error::NotEnoughData);
}
let client_key_id: KeyID =
KeyID(u16::from_le_bytes(raw[0..2].try_into().unwrap()));
KeyID(u16::from_le_bytes(raw[0..KeyID::len()].try_into().unwrap()));
Ok(HandshakeData::DirSync(DirSync::Resp(Self {
client_key_id,
data: RespInner::CipherText(raw[KeyID::len()..].len()),
@ -453,10 +462,16 @@ impl Resp {
+ KeyID::len()
}
/// return the total length of the cleartext data
pub fn encrypted_length(&self) -> usize {
pub fn encrypted_length(
&self,
head_len: HeadLen,
tag_len: TagLen,
) -> usize {
match &self.data {
RespInner::ClearText(_data) => RespData::len(),
_ => 0,
RespInner::ClearText(_data) => {
RespData::len() + head_len.0 + tag_len.0
}
RespInner::CipherText(len) => *len,
}
}
/// Total length of the response handshake
@ -471,8 +486,9 @@ impl Resp {
_tag_len: TagLen,
out: &mut [u8],
) {
out[0..2].copy_from_slice(&self.client_key_id.0.to_le_bytes());
let start_data = 2 + head_len.0;
out[0..KeyID::len()]
.copy_from_slice(&self.client_key_id.0.to_le_bytes());
let start_data = KeyID::len() + head_len.0;
let end_data = start_data + self.data.len();
self.data.serialize(&mut out[start_data..end_data]);
}

View File

@ -37,6 +37,12 @@ pub enum Error {
/// Too many client handshakes currently running
#[error("Too many client handshakes")]
TooManyClientHandshakes,
/// generic internal error
#[error("Internal tracking error")]
InternalTracking,
/// Handshake Timeout
#[error("Handshake timeout")]
Timeout,
}
/// List of possible handshakes

View File

@ -1,11 +1,11 @@
//! Handhsake handling
use crate::{
auth::ServiceID,
auth::{Domain, ServiceID},
connection::{
self,
handshake::{self, Error, Handshake},
Connection, IDRecv,
Connection, IDRecv, IDSend,
},
enc::{
self,
@ -16,16 +16,23 @@ use crate::{
inner::ThreadTracker,
};
use ::tokio::sync::oneshot;
pub(crate) struct HandshakeServer {
pub id: KeyID,
pub key: PrivKey,
pub domains: Vec<Domain>,
}
pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>;
pub(crate) struct HandshakeClient {
pub service_id: ServiceID,
pub service_conn_id: IDRecv,
pub connection: Connection,
pub timeout: Option<::tokio::task::JoinHandle<()>>,
pub answer: oneshot::Sender<ConnectAnswer>,
pub srv_key_id: KeyID,
}
/// Tracks the keys used by the client and the handshake
@ -73,7 +80,10 @@ impl HandshakeClientList {
service_id: ServiceID,
service_conn_id: IDRecv,
connection: Connection,
) -> Result<(KeyID, &mut HandshakeClient), ()> {
answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID,
) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender<ConnectAnswer>>
{
let maybe_free_key_idx =
self.used.iter().enumerate().find_map(|(idx, bmap)| {
match bmap.first_false_index() {
@ -85,7 +95,7 @@ impl HandshakeClientList {
Some((idx, false_idx)) => {
let free_key_idx = idx * 1024 + false_idx;
if free_key_idx > KeyID::MAX as usize {
return Err(());
return Err(answer);
}
self.used[idx].set(false_idx, true);
free_key_idx
@ -107,6 +117,8 @@ impl HandshakeClientList {
service_conn_id,
connection,
timeout: None,
answer,
srv_key_id,
});
Ok((
KeyID(free_key_idx as u16),
@ -136,6 +148,10 @@ pub(crate) struct ClientConnectInfo {
pub handshake: Handshake,
/// Connection
pub connection: Connection,
/// where to wake up the waiting client
pub answer: oneshot::Sender<ConnectAnswer>,
/// server public key id that we used on the handshake
pub srv_key_id: KeyID,
}
/// Intermediate actions to be taken while parsing the handshake
#[derive(Debug)]
@ -177,10 +193,42 @@ impl HandshakeTracker {
hshake_cli: HandshakeClientList::new(),
}
}
pub(crate) fn add_server(&mut self, id: KeyID, key: PrivKey) {
self.keys_srv.push(HandshakeServer { id, key });
pub(crate) fn add_server_key(
&mut self,
id: KeyID,
key: PrivKey,
) -> Result<(), ()> {
if self.keys_srv.iter().find(|&k| k.id == id).is_some() {
return Err(());
}
self.keys_srv.push(HandshakeServer {
id,
key,
domains: Vec::new(),
});
self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0));
Ok(())
}
pub(crate) fn add_server_domain(
&mut self,
domain: &Domain,
key_ids: &[KeyID],
) -> Result<(), ()> {
// check that all the key ids are present
for id in key_ids.iter() {
if self.keys_srv.iter().find(|k| k.id == *id).is_none() {
return Err(());
}
}
// add the domain to those keys
for id in key_ids.iter() {
if let Some(srv) = self.keys_srv.iter_mut().find(|k| k.id == *id) {
srv.domains.push(domain.clone());
}
}
Ok(())
}
pub(crate) fn add_client(
&mut self,
priv_key: PrivKey,
@ -188,20 +236,32 @@ impl HandshakeTracker {
service_id: ServiceID,
service_conn_id: IDRecv,
connection: Connection,
) -> Result<(KeyID, &mut HandshakeClient), ()> {
answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID,
) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender<ConnectAnswer>>
{
self.hshake_cli.add(
priv_key,
pub_key,
service_id,
service_conn_id,
connection,
answer,
srv_key_id,
)
}
pub(crate) fn remove_client(
&mut self,
key_id: KeyID,
) -> Option<HandshakeClient> {
self.hshake_cli.remove(key_id)
}
pub(crate) fn timeout_client(
&mut self,
key_id: KeyID,
) -> Option<[IDRecv; 2]> {
if let Some(hshake) = self.hshake_cli.remove(key_id) {
let _ = hshake.answer.send(Err(Error::Timeout.into()));
Some([hshake.connection.id_recv, hshake.service_conn_id])
} else {
None
@ -257,9 +317,16 @@ impl HandshakeTracker {
let cipher_recv = CipherRecv::new(req.cipher, secret_recv);
use crate::enc::sym::AAD;
let aad = AAD(&mut []); // no aad for now
let encrypt_from = req.encrypted_offset();
let encrypt_to = encrypt_from
+ req.encrypted_length(
cipher_recv.nonce_len(),
cipher_recv.tag_len(),
);
match cipher_recv.decrypt(
aad,
&mut handshake_raw[req.encrypted_offset()..],
&mut handshake_raw[encrypt_from..encrypt_to],
) {
Ok(cleartext) => {
req.data.deserialize_as_cleartext(cleartext)?;
@ -292,9 +359,13 @@ impl HandshakeTracker {
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
let mut raw_data = &mut handshake_raw[resp
.encrypted_offset()
..(resp.encrypted_offset() + resp.encrypted_length())];
let data_from = resp.encrypted_offset();
let data_to = data_from
+ resp.encrypted_length(
cipher_recv.nonce_len(),
cipher_recv.tag_len(),
);
let mut raw_data = &mut handshake_raw[data_from..data_to];
match cipher_recv.decrypt(aad, &mut raw_data) {
Ok(cleartext) => {
resp.data.deserialize_as_cleartext(&cleartext)?;
@ -314,6 +385,8 @@ impl HandshakeTracker {
service_connection_id: hshake.service_conn_id,
handshake,
connection: hshake.connection,
answer: hshake.answer,
srv_key_id: hshake.srv_key_id,
},
));
}

View File

@ -113,7 +113,11 @@ pub(crate) struct ConnList {
impl ConnList {
pub(crate) fn new(thread_id: ThreadTracker) -> Self {
let bitmap_id = ::bitmaps::Bitmap::<1024>::new();
let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new();
if thread_id.id == 0 {
// make sure we don't count the Handshake ID
bitmap_id.set(0, true);
}
const INITIAL_CAP: usize = 128;
let mut ret = Self {
thread_id,
@ -199,13 +203,6 @@ impl ConnList {
}
}
use ::std::collections::HashMap;
enum MapEntry {
Present(IDSend),
Reserved,
}
/// return wether we already have a connection, we are waiting for one, or you
/// can start one
#[derive(Debug, Clone, Copy)]
@ -218,6 +215,12 @@ pub(crate) enum Reservation {
Reserved,
}
enum MapEntry {
Present(IDSend),
Reserved,
}
use ::std::collections::HashMap;
/// Link the public key of the authentication server to a connection id
/// so that we can reuse that connection to ask for more authentications
///
@ -229,16 +232,16 @@ pub(crate) enum Reservation {
/// * wait for the connection to finish
/// * remove all those reservations, exept the one key that actually succeded
/// While searching, we return a connection ID if just one key is a match
// TODO: can we shard this per-core by hashing the pubkey? or domain? or...???
// This needs a mutex and it will be our goeal to avoid any synchronization
pub(crate) struct AuthServerConnections {
conn_map: HashMap<PubKey, MapEntry>,
next_reservation: u64,
}
impl AuthServerConnections {
pub(crate) fn new() -> Self {
Self {
conn_map: HashMap::with_capacity(32),
next_reservation: 0,
}
}
/// add an ID to the reserved spot,

View File

@ -1,40 +1,10 @@
//! Socket related types and functions
use ::std::{net::SocketAddr, sync::Arc, vec::Vec};
use ::std::net::SocketAddr;
use ::tokio::{net::UdpSocket, task::JoinHandle};
/// Pair to easily track the socket and its async listening handle
pub type SocketTracker =
(Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
/// async free socket list
pub(crate) struct SocketList {
pub list: Vec<SocketTracker>,
}
impl SocketList {
pub(crate) fn new() -> Self {
Self { list: Vec::new() }
}
pub(crate) fn rm_all(&mut self) -> Self {
let mut old_list = Vec::new();
::core::mem::swap(&mut self.list, &mut old_list);
Self { list: old_list }
}
pub(crate) async fn add_socket(
&mut self,
socket: Arc<UdpSocket>,
handle: JoinHandle<::std::io::Result<()>>,
) {
let arc_handle = Arc::new(handle);
self.list.push((socket, arc_handle));
}
/// This method assumes no other `add_sockets` are being run
pub(crate) async fn stop_all(self) {
for (_socket, mut handle) in self.list.into_iter() {
let _ = Arc::get_mut(&mut handle).unwrap().await;
}
}
}
pub type SocketTracker = (SocketAddr, JoinHandle<::std::io::Result<()>>);
/// Strong typedef for a client socket address
#[derive(Debug, Copy, Clone)]
@ -53,7 +23,7 @@ fn enable_sock_opt(
unsafe {
#[allow(trivial_casts)]
let val = &value as *const _ as *const ::libc::c_void;
let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t;
let size = ::core::mem::size_of_val(&value) as ::libc::socklen_t;
// always clear the error bit before doing a new syscall
let _ = ::std::io::Error::last_os_error();
let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size);
@ -64,23 +34,107 @@ fn enable_sock_opt(
Ok(())
}
/// Add an async udp listener
pub async fn bind_udp(sock: SocketAddr) -> ::std::io::Result<UdpSocket> {
let socket = UdpSocket::bind(sock).await?;
pub async fn bind_udp(addr: SocketAddr) -> ::std::io::Result<UdpSocket> {
// I know, kind of a mess. but I really wanted SO_REUSE{ADDR,PORT} and
// no-fragmenting stuff.
// I also did not want to load another library for this.
// feel free to simplify,
// especially if we can avoid libc and other libraries
// we currently use libc because it's a dependency of many other deps
use ::std::os::fd::AsRawFd;
let fd = socket.as_raw_fd();
// can be useful later on for reloads
enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?;
enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?;
let fd: ::std::os::fd::RawFd = {
let domain = if addr.is_ipv6() {
::libc::AF_INET6
} else {
::libc::AF_INET
};
#[allow(unsafe_code)]
let tmp = unsafe { ::libc::socket(domain, ::libc::SOCK_DGRAM, 0) };
let lasterr = ::std::io::Error::last_os_error();
if tmp == -1 {
return Err(lasterr);
}
tmp.into()
};
if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1) {
#[allow(unsafe_code)]
unsafe {
::libc::close(fd);
}
return Err(e);
}
if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1) {
#[allow(unsafe_code)]
unsafe {
::libc::close(fd);
}
return Err(e);
}
// We will do path MTU discovery by ourselves,
// always set the "don't fragment" bit
if sock.is_ipv6() {
enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?;
let res = if addr.is_ipv6() {
enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)
} else {
// FIXME: linux only
enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)?;
enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)
};
if let Err(e) = res {
#[allow(unsafe_code)]
unsafe {
::libc::close(fd);
}
return Err(e);
}
// manually convert rust SockAddr to C sockaddr
#[allow(unsafe_code, trivial_casts, trivial_numeric_casts)]
{
let bind_ret = match addr {
SocketAddr::V4(s4) => {
let ip4: u32 = (*s4.ip()).into();
let bind_addr = ::libc::sockaddr_in {
sin_family: ::libc::AF_INET as u16,
sin_port: s4.port().to_be(),
sin_addr: ::libc::in_addr { s_addr: ip4 },
sin_zero: [0; 8],
};
unsafe {
let c_addr =
&bind_addr as *const _ as *const ::libc::sockaddr;
::libc::bind(fd, c_addr, 16)
}
}
SocketAddr::V6(s6) => {
let ip6: [u8; 16] = (*s6.ip()).octets();
let bind_addr = ::libc::sockaddr_in6 {
sin6_family: ::libc::AF_INET6 as u16,
sin6_port: s6.port().to_be(),
sin6_flowinfo: 0,
sin6_addr: ::libc::in6_addr { s6_addr: ip6 },
sin6_scope_id: 0,
};
unsafe {
let c_addr =
&bind_addr as *const _ as *const ::libc::sockaddr;
::libc::bind(fd, c_addr, 24)
}
}
};
let lasterr = ::std::io::Error::last_os_error();
if bind_ret != 0 {
unsafe {
::libc::close(fd);
}
return Err(lasterr);
}
}
Ok(socket)
use ::std::os::fd::FromRawFd;
#[allow(unsafe_code)]
let std_sock = unsafe { ::std::net::UdpSocket::from_raw_fd(fd) };
std_sock.set_nonblocking(true)?;
::tracing::debug!("Listening udp sock: {}", std_sock.local_addr().unwrap());
Ok(UdpSocket::from_std(std_sock)?)
}

View File

@ -152,7 +152,7 @@ pub enum KeyExchangeKind {
}
impl KeyExchangeKind {
/// The serialize length of the field
pub fn len() -> usize {
pub const fn len() -> usize {
1
}
/// Build a new keypair for key exchange

View File

@ -4,6 +4,8 @@ pub mod asym;
mod errors;
pub mod hkdf;
pub mod sym;
#[cfg(test)]
mod tests;
pub use errors::Error;

View File

@ -25,12 +25,11 @@ pub enum CipherKind {
impl CipherKind {
/// length of the serialized id for the cipher kind field
pub fn len() -> usize {
pub const fn len() -> usize {
1
}
/// required length of the nonce
pub fn nonce_len(&self) -> HeadLen {
// TODO: how the hell do I take this from ::chacha20poly1305?
HeadLen(Nonce::len())
}
/// required length of the key
@ -92,10 +91,7 @@ impl Cipher {
}
fn nonce_len(&self) -> HeadLen {
match self {
Cipher::XChaCha20Poly1305(_) => {
// TODO: how the hell do I take this from ::chacha20poly1305?
HeadLen(::ring::aead::CHACHA20_POLY1305.nonce_len())
}
Cipher::XChaCha20Poly1305(_) => HeadLen(Nonce::len()),
}
}
fn tag_len(&self) -> TagLen {
@ -117,10 +113,13 @@ impl Cipher {
aead::generic_array::GenericArray, AeadInPlace,
};
let final_len: usize = {
// FIXME: check min data length
let (nonce_bytes, data_and_tag) = raw_data.split_at_mut(13);
if raw_data.len() <= self.overhead() {
return Err(Error::NotEnoughData(raw_data.len()));
}
let (nonce_bytes, data_and_tag) =
raw_data.split_at_mut(Nonce::len());
let (data_notag, tag_bytes) = data_and_tag.split_at_mut(
data_and_tag.len() + 1
data_and_tag.len()
- ::ring::aead::CHACHA20_POLY1305.tag_len(),
);
let nonce = GenericArray::from_slice(nonce_bytes);
@ -172,10 +171,7 @@ impl Cipher {
&mut data[Nonce::len()..data_len_notag],
) {
Ok(tag) => {
data[data_len_notag..]
// add tag
//data.get_tag_slice()
.copy_from_slice(tag.as_slice());
data[data_len_notag..].copy_from_slice(tag.as_slice());
Ok(())
}
Err(_) => Err(Error::Encrypt),
@ -205,6 +201,10 @@ impl CipherRecv {
pub fn nonce_len(&self) -> HeadLen {
self.0.nonce_len()
}
/// Get the length of the nonce for this cipher
pub fn tag_len(&self) -> TagLen {
self.0.tag_len()
}
/// Decrypt a paket. Nonce and Tag are taken from the packet,
/// while you need to provide AAD (Additional Authenticated Data)
pub fn decrypt<'a>(
@ -285,7 +285,7 @@ struct NonceNum {
#[repr(C)]
pub union Nonce {
num: NonceNum,
raw: [u8; 12],
raw: [u8; Self::len()],
}
impl ::core::fmt::Debug for Nonce {
@ -303,13 +303,17 @@ impl ::core::fmt::Debug for Nonce {
impl Nonce {
/// Generate a new random Nonce
pub fn new(rand: &Random) -> Self {
let mut raw = [0; 12];
let mut raw = [0; Self::len()];
rand.fill(&mut raw);
Self { raw }
}
/// Length of this nonce in bytes
pub const fn len() -> usize {
return 12;
// FIXME: was:12. xchacha20poly1305 requires 24.
// but we should change keys much earlier than that, and our
// nonces are not random, but sequential.
// we should change keys every 2^30 bytes to be sure (stream max window)
return 24;
}
/// Get reference to the nonce bytes
pub fn as_bytes(&self) -> &[u8] {
@ -319,7 +323,7 @@ impl Nonce {
}
}
/// Create Nonce from array
pub fn from_slice(raw: [u8; 12]) -> Self {
pub fn from_slice(raw: [u8; Self::len()]) -> Self {
Self { raw }
}
/// Go to the next nonce
@ -336,6 +340,7 @@ impl Nonce {
}
/// Synchronize the mutex acess with a nonce for multithread safety
// TODO: remove mutex, not needed anymore
#[derive(Debug)]
pub struct NonceSync {
nonce: ::std::sync::Mutex<Nonce>,

135
src/enc/tests.rs Normal file
View File

@ -0,0 +1,135 @@
use crate::{
auth,
connection::{handshake::*, ID},
enc::{self, asym::KeyID},
};
#[test]
fn test_simple_encrypt_decrypt() {
let rand = enc::Random::new();
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
let secret = enc::Secret::new_rand(&rand);
let secret2 = secret.clone();
let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2);
let mut data = Vec::new();
let tot_len = cipher_recv.nonce_len().0 + 1234 + cipher_recv.tag_len().0;
data.resize(tot_len, 0);
rand.fill(&mut data);
data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]);
let last = data.len() - cipher_recv.tag_len().0;
data[last..].copy_from_slice(&[0; 16]);
let orig = data.clone();
let raw_aad: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
let aad = enc::sym::AAD(&raw_aad[..]);
let aad2 = enc::sym::AAD(&raw_aad[..]);
if cipher_send.encrypt(aad, &mut data).is_err() {
assert!(false, "Encrypt failed");
}
if cipher_recv.decrypt(aad2, &mut data).is_err() {
assert!(false, "Decrypt failed");
}
data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]);
let last = data.len() - cipher_recv.tag_len().0;
data[last..].copy_from_slice(&[0; 16]);
assert!(orig == data, "DIFFERENT!\n{:?}\n{:?}\n", orig, data);
}
#[test]
fn test_encrypt_decrypt() {
let rand = enc::Random::new();
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
let secret = enc::Secret::new_rand(&rand);
let secret2 = secret.clone();
let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2);
let nonce_len = cipher_recv.nonce_len();
let tag_len = cipher_recv.tag_len();
let service_key = enc::Secret::new_rand(&rand);
let data = dirsync::RespInner::ClearText(dirsync::RespData {
client_nonce: dirsync::Nonce::new(&rand),
id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()),
service_connection_id: ID::ID(
::core::num::NonZeroU64::new(434343).unwrap(),
),
service_key,
});
let resp = dirsync::Resp {
client_key_id: KeyID(4444),
data,
};
let encrypt_from = resp.encrypted_offset();
let encrypt_to = encrypt_from + resp.encrypted_length(nonce_len, tag_len);
let h_resp =
Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Resp(resp)));
let mut bytes = Vec::<u8>::with_capacity(
h_resp.len(cipher.nonce_len(), cipher.tag_len()),
);
bytes.resize(h_resp.len(cipher.nonce_len(), cipher.tag_len()), 0);
h_resp.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes);
let raw_aad: [u8; 7] = [0, 1, 2, 3, 4, 5, 6];
let aad = enc::sym::AAD(&raw_aad[..]);
let aad2 = enc::sym::AAD(&raw_aad[..]);
let pre_encrypt = bytes.clone();
// encrypt
if cipher_send
.encrypt(aad, &mut bytes[encrypt_from..encrypt_to])
.is_err()
{
assert!(false, "Encrypt failed");
}
if cipher_recv
.decrypt(aad2, &mut bytes[encrypt_from..encrypt_to])
.is_err()
{
assert!(false, "Decrypt failed");
}
// make sure Nonce and Tag are 0
bytes[encrypt_from..(encrypt_from + nonce_len.0)].copy_from_slice(&[0; 24]);
let tag_from = encrypt_to - tag_len.0;
bytes[tag_from..(tag_from + tag_len.0)].copy_from_slice(&[0; 16]);
assert!(
pre_encrypt == bytes,
"{}|{}=\n{:?}\n{:?}",
encrypt_from,
encrypt_to,
pre_encrypt,
bytes
);
// decrypt
let mut deserialized = match Handshake::deserialize(&bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
assert!(false, "{}", e.to_string());
return;
}
};
// reparse
if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) =
&mut deserialized.data
{
let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0;
if let Err(e) = r_a.data.deserialize_as_cleartext(
&bytes[enc_start..(bytes.len() - cipher.tag_len().0)],
) {
assert!(false, "DirSync Resp Inner serialize: {}", e.to_string());
}
};
assert!(
deserialized == h_resp,
"DirSync Resp (de)serialization not working",
);
}

View File

@ -37,7 +37,7 @@ pub(crate) struct RawUdp {
}
pub(crate) struct ConnectInfo {
pub answer: oneshot::Sender<Result<(PubKey, IDSend), crate::Error>>,
pub answer: oneshot::Sender<handshake::tracker::ConnectAnswer>,
pub resolved: dnssec::Record,
pub service_id: ServiceID,
pub domain: Domain,
@ -57,14 +57,15 @@ pub(crate) enum WorkAnswer {
}
/// Actual worker implementation.
pub(crate) struct Worker {
#[allow(missing_debug_implementations)]
pub struct Worker {
cfg: Config,
thread_id: ThreadTracker,
// PERF: rand uses syscalls. how to do that async?
rand: Random,
stop_working: crate::StopWorkingRecvCh,
token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<UdpSocket>,
sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>,
queue_timeouts_recv: mpsc::UnboundedReceiver<Work>,
queue_timeouts_send: mpsc::UnboundedSender<Work>,
@ -73,64 +74,18 @@ pub(crate) struct Worker {
handshakes: HandshakeTracker,
}
#[allow(unsafe_code)]
unsafe impl Send for Worker {}
impl Worker {
pub(crate) async fn new_and_loop(
cfg: Config,
thread_id: ThreadTracker,
stop_working: crate::StopWorkingRecvCh,
token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>,
queue: ::async_channel::Receiver<Work>,
) -> ::std::io::Result<()> {
// TODO: get a channel to send back information, and send the error
let mut worker = Self::new(
cfg,
thread_id,
stop_working,
token_check,
socket_addrs,
queue,
)
.await?;
worker.work_loop().await;
Ok(())
}
pub(crate) async fn new(
mut cfg: Config,
thread_id: ThreadTracker,
stop_working: crate::StopWorkingRecvCh,
token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>,
sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>,
) -> ::std::io::Result<Self> {
// bind all sockets again so that we can easily
// send without sharing resources
// in the future we will want to have a thread-local listener too,
// but before that we need ebpf to pin a connection to a thread
// directly from the kernel
let mut sock_set = ::tokio::task::JoinSet::new();
socket_addrs.into_iter().for_each(|s_addr| {
sock_set.spawn(async move {
let socket =
connection::socket::bind_udp(s_addr.clone()).await?;
Ok(socket)
});
});
// make sure we either add all of them, or none
let mut sockets = Vec::with_capacity(cfg.listen.len());
while let Some(join_res) = sock_set.join_next().await {
match join_res {
Ok(s_res) => match s_res {
Ok(sock) => sockets.push(sock),
Err(e) => {
::tracing::error!("Can't rebind socket");
return Err(e);
}
},
Err(e) => return Err(e.into()),
}
}
let (queue_timeouts_send, queue_timeouts_recv) =
mpsc::unbounded_channel();
let mut handshakes = HandshakeTracker::new(
@ -138,11 +93,24 @@ impl Worker {
cfg.ciphers.clone(),
cfg.key_exchanges.clone(),
);
let mut keys = Vec::new();
let mut server_keys = Vec::new();
// make sure the keys are no longer in the config
::core::mem::swap(&mut keys, &mut cfg.keys);
for k in keys.into_iter() {
handshakes.add_server(k.0, k.1);
::core::mem::swap(&mut server_keys, &mut cfg.server_keys);
for k in server_keys.into_iter() {
if handshakes.add_server_key(k.id, k.priv_key).is_err() {
return Err(::std::io::Error::new(
::std::io::ErrorKind::AlreadyExists,
"You can't use the same KeyID for multiple keys",
));
}
}
for srv in cfg.servers.iter() {
if handshakes.add_server_domain(&srv.fqdn, &srv.keys).is_err() {
return Err(::std::io::Error::new(
::std::io::ErrorKind::NotFound,
"Specified a KeyID that we don't have",
));
}
}
Ok(Self {
@ -160,12 +128,15 @@ impl Worker {
handshakes,
})
}
pub(crate) async fn work_loop(&mut self) {
/// Continuously loop and process work as needed
pub async fn work_loop(&mut self) {
'mainloop: loop {
let work = ::tokio::select! {
tell_stopped = self.stop_working.recv() => {
let _ = tell_stopped.unwrap().send(
if let Ok(stop_ch) = tell_stopped {
let _ = stop_ch.send(
crate::StopWorking::WorkerStopped).await;
}
break;
}
maybe_timeout = self.queue.recv() => {
@ -302,6 +273,7 @@ impl Worker {
}
};
let hkdf;
if let PubKey::Exchange(srv_pub) = key.1 {
let secret =
match priv_key.key_exchange(exchange, srv_pub) {
@ -341,11 +313,13 @@ impl Worker {
conn_info.service_id,
service_conn_id,
conn,
conn_info.answer,
key.0,
) {
Ok((client_key_id, hshake)) => (client_key_id, hshake),
Err(_) => {
Err(answer) => {
::tracing::warn!("Too many client handshakes");
let _ = conn_info.answer.send(Err(
let _ = answer.send(Err(
handshake::Error::TooManyClientHandshakes
.into(),
));
@ -363,7 +337,7 @@ impl Worker {
let req_data = dirsync::ReqData {
nonce: dirsync::Nonce::new(&self.rand),
client_key_id,
id: auth_recv_id.0,
id: auth_recv_id.0, //FIXME: is zero
auth: auth_info,
};
let req = dirsync::Req {
@ -374,28 +348,50 @@ impl Worker {
exchange_key: pub_key,
data: dirsync::ReqInner::ClearText(req_data),
};
let mut raw = Vec::<u8>::with_capacity(req.len());
req.serialize(
let encrypt_start = ID::len() + req.encrypted_offset();
let encrypt_end = encrypt_start
+ req.encrypted_length(
cipher_selected.nonce_len(),
cipher_selected.tag_len(),
);
let h_req = Handshake::new(HandshakeData::DirSync(
DirSync::Req(req),
));
use connection::{PacketData, ID};
let packet = Packet {
id: ID::Handshake,
data: PacketData::Handshake(h_req),
};
let tot_len = packet.len(
cipher_selected.nonce_len(),
cipher_selected.tag_len(),
);
let mut raw = Vec::<u8>::with_capacity(tot_len);
raw.resize(tot_len, 0);
packet.serialize(
cipher_selected.nonce_len(),
cipher_selected.tag_len(),
&mut raw[..],
);
// encrypt
let encrypt_start = req.encrypted_offset();
let encrypt_end = encrypt_start + req.encrypted_length();
if let Err(e) = hshake.connection.cipher_send.encrypt(
sym::AAD(&[]),
&mut raw[encrypt_start..encrypt_end],
) {
::tracing::error!("Can't encrypt DirSync Request");
let _ = conn_info.answer.send(Err(e.into()));
if let Some(client) =
self.handshakes.remove_client(client_key_id)
{
let _ = client.answer.send(Err(e.into()));
};
continue 'mainloop;
}
// send always from the first socket
// FIXME: select based on routing table
let sender = self.sockets[0].local_addr().unwrap();
let dest = UdpServer(addr.as_sockaddr().unwrap());
let dest = UdpClient(addr.as_sockaddr().unwrap());
// start the timeout right before sending the packet
hshake.timeout = Some(::tokio::task::spawn_local(
@ -406,7 +402,7 @@ impl Worker {
));
// send packet
self.send_packet(raw, UdpClient(sender), dest).await;
self.send_packet(raw, dest, UdpServer(sender)).await;
continue 'mainloop;
}
@ -435,17 +431,19 @@ impl Worker {
/// Read and do stuff with the raw udp packet
async fn recv(&mut self, mut udp: RawUdp) {
if udp.packet.id.is_handshake() {
let handshake = match Handshake::deserialize(&udp.data[8..]) {
let handshake = match Handshake::deserialize(
&udp.data[connection::ID::len()..],
) {
Ok(handshake) => handshake,
Err(e) => {
::tracing::warn!("Handshake parsing: {}", e);
::tracing::debug!("Handshake parsing: {}", e);
return;
}
};
let action = match self
.handshakes
.recv_handshake(handshake, &mut udp.data[8..])
{
let action = match self.handshakes.recv_handshake(
handshake,
&mut udp.data[connection::ID::len()..],
) {
Ok(action) => action,
Err(err) => {
::tracing::debug!("Handshake recv error {}", err);
@ -454,16 +452,6 @@ impl Worker {
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let token_check = match self.token_check.as_ref() {
Some(token_check) => token_check,
None => {
::tracing::error!(
"Authentication requested but \
we have no token checker"
);
return;
}
};
let req;
if let HandshakeData::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
@ -477,25 +465,36 @@ impl Worker {
let req_data = match req.data {
ReqInner::ClearText(req_data) => req_data,
_ => {
::tracing::error!(
"token_check: expected ClearText"
);
::tracing::error!("AuthNeeded: expected ClearText");
assert!(false, "AuthNeeded: unreachable");
return;
}
};
// FIXME: This part can take a while,
// we should just spawn it probably
let is_authenticated = {
let tk_check = token_check.lock().await;
tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
let maybe_auth_check = {
match &self.token_check {
None => {
if req_data.auth.user == auth::USERID_ANONYMOUS
{
Ok(true)
} else {
Ok(false)
}
}
Some(token_check) => {
let tk_check = token_check.lock().await;
tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
}
}
};
let is_authenticated = match is_authenticated {
let is_authenticated = match maybe_auth_check {
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
@ -545,9 +544,9 @@ impl Worker {
client_key_id: req_data.client_key_id,
data: RespInner::ClearText(resp_data),
};
let offset_to_encrypt = resp.encrypted_offset();
let encrypt_from = ID::len() + resp.encrypted_offset();
let encrypt_until =
offset_to_encrypt + resp.encrypted_length() + tag_len.0;
encrypt_from + resp.encrypted_length(head_len, tag_len);
let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)),
);
@ -556,14 +555,15 @@ impl Worker {
id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake),
};
let mut raw_out =
Vec::<u8>::with_capacity(packet.len(head_len, tag_len));
let tot_len = packet.len(head_len, tag_len);
let mut raw_out = Vec::<u8>::with_capacity(tot_len);
raw_out.resize(tot_len, 0);
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn.cipher_send.encrypt(
aad,
&mut raw_out[offset_to_encrypt..encrypt_until],
) {
if let Err(e) = auth_conn
.cipher_send
.encrypt(aad, &mut raw_out[encrypt_from..encrypt_until])
{
::tracing::error!("can't encrypt: {:?}", e);
return;
}
@ -588,43 +588,46 @@ impl Worker {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
return;
unreachable!();
}
let auth_srv_conn = IDSend(resp_data.id);
let mut conn = cci.connection;
conn.id_send = IDSend(resp_data.id);
conn.id_send = auth_srv_conn;
let id_recv = conn.id_recv;
let cipher = conn.cipher_recv.kind();
// track the connection to the authentication server
if self.connections.track(conn.into()).is_err() {
::tracing::error!("Could not track new connection");
self.connections.remove(id_recv);
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
}
if cci.service_id == auth::SERVICEID_AUTH {
// the user asked a single connection
// to the authentication server, without any additional
// service. No more connections to setup
return;
if cci.service_id != auth::SERVICEID_AUTH {
// create and track the connection to the service
// SECURITY: xor with secrets
//FIXME: the Secret should be XORed with the client
// stored secret (if any)
let hkdf = Hkdf::new(
HkdfKind::Sha3,
cci.service_id.as_bytes(),
resp_data.service_key,
);
let mut service_connection = Connection::new(
hkdf,
cipher,
connection::Role::Client,
&self.rand,
);
service_connection.id_recv = cci.service_connection_id;
service_connection.id_send =
IDSend(resp_data.service_connection_id);
let _ =
self.connections.track(service_connection.into());
}
// create and track the connection to the service
// SECURITY: xor with secrets
//FIXME: the Secret should be XORed with the client stored
// secret (if any)
let hkdf = Hkdf::new(
HkdfKind::Sha3,
cci.service_id.as_bytes(),
resp_data.service_key,
);
let mut service_connection = Connection::new(
hkdf,
cipher,
connection::Role::Client,
&self.rand,
);
service_connection.id_recv = cci.service_connection_id;
service_connection.id_send =
IDSend(resp_data.service_connection_id);
let _ = self.connections.track(service_connection.into());
let _ =
cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn)));
}
HandshakeAction::Nothing => {}
};
@ -644,11 +647,12 @@ impl Worker {
Some(src_sock) => src_sock,
None => {
::tracing::error!(
"Can't send packet: Server changed listening ip!"
"Can't send packet: Server changed listening ip{}!",
server.0
);
return;
}
};
let _ = src_sock.send_to(&data, client.0).await;
let res = src_sock.send_to(&data, client.0).await;
}
}

View File

@ -30,7 +30,7 @@ use crate::{
auth::{Domain, ServiceID, TokenChecker},
connection::{
handshake,
socket::{SocketList, UdpClient, UdpServer},
socket::{SocketTracker, UdpClient, UdpServer},
AuthServerConnections, Packet,
},
inner::{
@ -86,7 +86,7 @@ pub struct Fenrir {
/// library Configuration
cfg: Config,
/// listening udp sockets
sockets: SocketList,
sockets: Vec<SocketTracker>,
/// DNSSEC resolver, with failovers
dnssec: dnssec::Dnssec,
/// Broadcast channel to tell workers to stop working
@ -100,9 +100,6 @@ pub struct Fenrir {
// manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
// This can be different from cfg.listen since using port 0 will result
// in a random port assigned by the operative system
_listen_addrs: Vec<::std::net::SocketAddr>,
}
// TODO: graceful vs immediate stop
@ -127,16 +124,23 @@ impl Fenrir {
}
fn stop_sync(
&mut self,
) -> Option<(::tokio::sync::mpsc::Receiver<StopWorking>, usize, usize)>
{
let listeners_num = self.sockets.list.len();
) -> Option<(
::tokio::sync::mpsc::Receiver<StopWorking>,
Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
usize,
)> {
let workers_num = self._thread_work.len();
if self.sockets.list.len() > 0 || self._thread_work.len() > 0 {
if self.sockets.len() > 0 || self._thread_work.len() > 0 {
let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4);
let _ = self.stop_working.send(ch_send);
let _ = self.sockets.rm_all();
let mut old_listeners = Vec::with_capacity(self.sockets.len());
::core::mem::swap(&mut old_listeners, &mut self.sockets);
self._thread_pool.clear();
Some((ch_recv, listeners_num, workers_num))
let listeners = old_listeners
.into_iter()
.map(|(_, joinable)| joinable)
.collect();
Some((ch_recv, listeners, workers_num))
} else {
None
}
@ -144,9 +148,10 @@ impl Fenrir {
async fn stop_wait(
&mut self,
mut ch: ::tokio::sync::mpsc::Receiver<StopWorking>,
mut listeners_num: usize,
listeners: Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
mut workers_num: usize,
) {
let mut listeners_num = listeners.len();
while listeners_num > 0 && workers_num > 0 {
match ch.recv().await {
Some(stopped) => match stopped {
@ -158,6 +163,11 @@ impl Fenrir {
_ => break,
}
}
for l in listeners.into_iter() {
if let Err(e) = l.await {
::tracing::error!("Unclean shutdown of listener: {:?}", e);
}
}
}
/// Create a new Fenrir endpoint
/// spawn threads pinned to cpus in our own way with tokio's runtime
@ -167,22 +177,32 @@ impl Fenrir {
) -> Result<Self, Error> {
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random)
// in the config
let binded_sockets = Self::bind_sockets(&config).await?;
let socket_addrs = binded_sockets
.iter()
.map(|s| s.local_addr().unwrap())
.collect();
let cfg = {
let mut tmp = config.clone();
tmp.listen = socket_addrs;
tmp
};
let mut endpoint = Self {
cfg: config.clone(),
sockets: SocketList::new(),
cfg,
sockets: Vec::with_capacity(config.listen.len()),
dnssec,
stop_working: sender,
token_check: None,
conn_auth_srv: Mutex::new(AuthServerConnections::new()),
_thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()),
_listen_addrs: Vec::with_capacity(config.listen.len()),
};
endpoint.start_work_threads_pinned(tokio_rt).await?;
match endpoint.add_sockets().await {
Ok(addrs) => endpoint._listen_addrs = addrs,
Err(e) => return Err(e.into()),
}
endpoint
.start_work_threads_pinned(tokio_rt, binded_sockets.clone())
.await?;
endpoint.run_listeners(binded_sockets).await?;
Ok(endpoint)
}
/// Create a new Fenrir endpoint
@ -192,41 +212,39 @@ impl Fenrir {
/// * make sure that the threads are pinned on the cpu
pub async fn with_workers(
config: &Config,
) -> Result<
(
Self,
Vec<impl futures::Future<Output = Result<(), std::io::Error>>>,
),
Error,
> {
) -> Result<(Self, Vec<Worker>), Error> {
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
let cfg = config.clone();
let sockets = SocketList::new();
let conn_auth_srv = Mutex::new(AuthServerConnections::new());
let thread_pool = Vec::new();
let thread_work = Arc::new(Vec::new());
let listen_addrs = Vec::with_capacity(config.listen.len());
// bind sockets early so we can change "port 0" (aka: random)
// in the config
let binded_sockets = Self::bind_sockets(&config).await?;
let socket_addrs = binded_sockets
.iter()
.map(|s| s.local_addr().unwrap())
.collect();
let cfg = {
let mut tmp = config.clone();
tmp.listen = socket_addrs;
tmp
};
let mut endpoint = Self {
cfg,
sockets,
sockets: Vec::with_capacity(config.listen.len()),
dnssec,
stop_working: stop_working.clone(),
token_check: None,
conn_auth_srv,
_thread_pool: thread_pool,
_thread_work: thread_work,
_listen_addrs: listen_addrs,
conn_auth_srv: Mutex::new(AuthServerConnections::new()),
_thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()),
};
let worker_num = config.threads.unwrap().get();
let mut workers = Vec::with_capacity(worker_num);
for _ in 0..worker_num {
workers.push(endpoint.start_single_worker().await?);
}
match endpoint.add_sockets().await {
Ok(addrs) => endpoint._listen_addrs = addrs,
Err(e) => return Err(e.into()),
workers.push(
endpoint.start_single_worker(binded_sockets.clone()).await?,
);
}
endpoint.run_listeners(binded_sockets).await?;
Ok((endpoint, workers))
}
/// Returns the list of the actual addresses we are listening on
@ -234,57 +252,56 @@ impl Fenrir {
/// if you specified UDP port 0 a random one has been assigned to you
/// by the operating system.
pub fn addresses(&self) -> Vec<::std::net::SocketAddr> {
self._listen_addrs.clone()
self.sockets.iter().map(|(s, _)| s.clone()).collect()
}
// only call **after** starting all threads
/// Add all UDP sockets found in config
/// and start listening for packets
async fn add_sockets(
&mut self,
) -> ::std::io::Result<Vec<::std::net::SocketAddr>> {
// only call **before** starting all threads
/// bind all UDP sockets found in config
async fn bind_sockets(cfg: &Config) -> Result<Vec<Arc<UdpSocket>>, Error> {
// try to bind multiple sockets in parallel
let mut sock_set = ::tokio::task::JoinSet::new();
self.cfg.listen.iter().for_each(|s_addr| {
cfg.listen.iter().for_each(|s_addr| {
let socket_address = s_addr.clone();
let stop_working = self.stop_working.subscribe();
let th_work = self._thread_work.clone();
sock_set.spawn(async move {
let s = connection::socket::bind_udp(socket_address).await?;
let arc_s = Arc::new(s);
let join = ::tokio::spawn(Self::listen_udp(
stop_working,
th_work,
arc_s.clone(),
));
Ok((arc_s, join))
connection::socket::bind_udp(socket_address).await
});
});
// make sure we either add all of them, or none
let mut all_socks = Vec::with_capacity(self.cfg.listen.len());
// make sure we either return all of them, or none
let mut all_socks = Vec::with_capacity(cfg.listen.len());
while let Some(join_res) = sock_set.join_next().await {
match join_res {
Ok(s_res) => match s_res {
Ok(s) => {
all_socks.push(s);
all_socks.push(Arc::new(s));
}
Err(e) => {
return Err(e);
return Err(e.into());
}
},
Err(e) => {
return Err(e.into());
return Err(Error::Setup(e.to_string()));
}
}
}
let mut ret = Vec::with_capacity(self.cfg.listen.len());
for (arc_s, join) in all_socks.into_iter() {
ret.push(arc_s.local_addr().unwrap());
self.sockets.add_socket(arc_s, join).await;
assert!(all_socks.len() == cfg.listen.len(), "missing socks");
Ok(all_socks)
}
// only call **after** starting all threads
/// spawn all listeners
async fn run_listeners(
&mut self,
socks: Vec<Arc<UdpSocket>>,
) -> Result<(), Error> {
for sock in socks.into_iter() {
let sockaddr = sock.local_addr().unwrap();
let stop_working = self.stop_working.subscribe();
let th_work = self._thread_work.clone();
let joinable = ::tokio::spawn(async move {
Self::listen_udp(stop_working, th_work, sock.clone()).await
});
self.sockets.push((sockaddr, joinable));
}
Ok(ret)
Ok(())
}
/// Run a dedicated loop to read packets on the listening socket
@ -301,12 +318,15 @@ impl Fenrir {
let (bytes, sock_sender) = ::tokio::select! {
tell_stopped = stop_working.recv() => {
drop(socket);
let _ = tell_stopped.unwrap()
.send(StopWorking::ListenerStopped).await;
if let Ok(stop_ch) = tell_stopped {
let _ = stop_ch
.send(StopWorking::ListenerStopped).await;
}
return Ok(());
}
result = socket.recv_from(&mut buffer) => {
result?
let (bytes, from) = result?;
(bytes, UdpClient(from))
}
};
let data: Vec<u8> = buffer[..bytes].to_vec();
@ -324,17 +344,15 @@ impl Fenrir {
use connection::packet::ConnectionID;
match packet.id {
ConnectionID::Handshake => {
let send_port = sock_sender.port() as u64;
((send_port % queues_num) - 1) as usize
}
ConnectionID::ID(id) => {
((id.get() % queues_num) - 1) as usize
let send_port = sock_sender.0.port() as u64;
(send_port % queues_num) as usize
}
ConnectionID::ID(id) => (id.get() % queues_num) as usize,
}
};
let _ = work_queues[thread_idx]
.send(Work::Recv(RawUdp {
src: UdpClient(sock_sender),
src: sock_sender,
dst: sock_receiver,
packet,
data,
@ -431,7 +449,7 @@ impl Fenrir {
.unwrap();
// and tell that thread to connect somewhere
let (send, recv) = ::tokio::sync::oneshot::channel();
let (send, mut recv) = ::tokio::sync::oneshot::channel();
let _ = self._thread_work[thread_idx]
.send(Work::Connect(ConnectInfo {
answer: send,
@ -450,10 +468,15 @@ impl Fenrir {
conn_auth_lock.remove_reserved(&resolved);
Err(e)
}
Ok((pubkey, id_send)) => {
Ok((key_id, id_send)) => {
let key = resolved
.public_keys
.iter()
.find(|k| k.0 == key_id)
.unwrap();
let mut conn_auth_lock =
self.conn_auth_srv.lock().await;
conn_auth_lock.add(&pubkey, id_send, &resolved);
conn_auth_lock.add(&key.1, id_send, &resolved);
//FIXME: user needs to somehow track the connection
Ok(())
@ -472,13 +495,11 @@ impl Fenrir {
}
}
// needs to be called before add_sockets
// needs to be called before run_listeners
async fn start_single_worker(
&mut self,
) -> ::std::result::Result<
impl futures::Future<Output = Result<(), std::io::Error>>,
Error,
> {
socks: Vec<Arc<UdpSocket>>,
) -> ::std::result::Result<Worker, Error> {
let thread_idx = self._thread_work.len() as u16;
let max_threads = self.cfg.threads.unwrap().get() as u16;
if thread_idx >= max_threads {
@ -496,17 +517,18 @@ impl Fenrir {
total: max_threads,
};
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let worker = Worker::new_and_loop(
let worker = Worker::new(
self.cfg.clone(),
thread_id,
self.stop_working.subscribe(),
self.token_check.clone(),
self.cfg.listen.clone(),
socks,
work_recv,
);
)
.await?;
// don't keep around private keys too much
if (thread_idx + 1) == max_threads {
self.cfg.keys.clear();
self.cfg.server_keys.clear();
}
loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
@ -533,6 +555,7 @@ impl Fenrir {
async fn start_work_threads_pinned(
&mut self,
tokio_rt: Arc<::tokio::runtime::Runtime>,
sockets: Vec<Arc<UdpSocket>>,
) -> ::std::result::Result<(), Error> {
use ::std::sync::Mutex;
let hw_topology = match ::hwloc2::Topology::new() {
@ -568,7 +591,7 @@ impl Fenrir {
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let th_stop_working = self.stop_working.subscribe();
let th_token_check = self.token_check.clone();
let th_socket_addrs = self.cfg.listen.clone();
let th_sockets = sockets.clone();
let thread_id = ThreadTracker {
total: cores as u16,
id: 1 + (core as u16),
@ -598,17 +621,22 @@ impl Fenrir {
// finally run the main worker.
// make sure things stay on this thread
let tk_local = ::tokio::task::LocalSet::new();
let _ = tk_local.block_on(
&th_tokio_rt,
Worker::new_and_loop(
let _ = tk_local.block_on(&th_tokio_rt, async move {
let mut worker = match Worker::new(
th_config,
thread_id,
th_stop_working,
th_token_check,
th_socket_addrs,
th_sockets,
work_recv,
),
);
)
.await
{
Ok(worker) => worker,
Err(_) => return,
};
worker.work_loop().await
});
});
loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
@ -627,7 +655,7 @@ impl Fenrir {
self._thread_pool.push(join_handle);
}
// don't keep around private keys too much
self.cfg.keys.clear();
self.cfg.server_keys.clear();
Ok(())
}
}

View File

@ -21,23 +21,46 @@ async fn test_connection_dirsync() {
cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap());
cfg
};
let test_domain: Domain = "example.com".into();
let cfg_server = {
let mut cfg = cfg_client.clone();
cfg.keys = [(KeyID(42), priv_exchange_key, pub_exchange_key)].to_vec();
cfg.server_keys = [config::ServerKey {
id: KeyID(42),
priv_key: priv_exchange_key,
pub_key: pub_exchange_key,
}]
.to_vec();
cfg.servers = [config::AuthServer {
fqdn: test_domain.clone(),
keys: [KeyID(42)].to_vec(),
}]
.to_vec();
cfg
};
let (server, mut srv_workers) =
Fenrir::with_workers(&cfg_server).await.unwrap();
let srv_worker = srv_workers.pop().unwrap();
let local_thread = ::tokio::task::LocalSet::new();
local_thread.spawn_local(async move { srv_worker.await });
let (client, mut cli_workers) =
Fenrir::with_workers(&cfg_client).await.unwrap();
let cli_worker = cli_workers.pop().unwrap();
local_thread.spawn_local(async move { cli_worker.await });
let mut srv_worker = srv_workers.pop().unwrap();
let mut cli_worker = cli_workers.pop().unwrap();
::std::thread::spawn(move || {
let rt = ::tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let local_thread = ::tokio::task::LocalSet::new();
local_thread.spawn_local(async move {
srv_worker.work_loop().await;
});
local_thread.spawn_local(async move {
::tokio::time::sleep(::std::time::Duration::from_millis(100)).await;
cli_worker.work_loop().await;
});
rt.block_on(local_thread);
});
use crate::{
connection::handshake::HandshakeID,
@ -63,17 +86,17 @@ async fn test_connection_dirsync() {
ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
};
server.graceful_stop().await;
client.graceful_stop().await;
return;
::tokio::time::sleep(::std::time::Duration::from_millis(500)).await;
match client
.connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH)
.await
{
Ok(()) => {}
Err(e) => {
assert!(false, "Err on client connection: {:?}", e);
}
}
let _ = client
.connect_resolved(
dnssec_record,
&Domain("example.com".to_owned()),
auth::SERVICEID_AUTH,
)
.await;
server.graceful_stop().await;
client.graceful_stop().await;
}