TONS of bugfixing. Add tests. Client now connects
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
b682068dca
commit
866edc2d7d
|
@ -3,6 +3,8 @@
|
|||
use crate::enc::Random;
|
||||
use ::zeroize::Zeroize;
|
||||
|
||||
/// Anonymous user id
|
||||
pub const USERID_ANONYMOUS: UserID = UserID([0; UserID::len()]);
|
||||
/// User identifier. 16 bytes for easy uuid conversion
|
||||
#[derive(Debug, Copy, Clone, PartialEq)]
|
||||
pub struct UserID(pub [u8; 16]);
|
||||
|
@ -25,8 +27,8 @@ impl UserID {
|
|||
}
|
||||
}
|
||||
/// Anonymous user id
|
||||
pub fn new_anonymous() -> Self {
|
||||
UserID([0; 16])
|
||||
pub const fn new_anonymous() -> Self {
|
||||
USERID_ANONYMOUS
|
||||
}
|
||||
/// length of the User ID in bytes
|
||||
pub const fn len() -> usize {
|
||||
|
@ -98,6 +100,16 @@ impl TryFrom<&[u8]> for Domain {
|
|||
Ok(Domain(domain_string))
|
||||
}
|
||||
}
|
||||
impl From<String> for Domain {
|
||||
fn from(raw: String) -> Self {
|
||||
Self(raw)
|
||||
}
|
||||
}
|
||||
impl From<&str> for Domain {
|
||||
fn from(raw: &str) -> Self {
|
||||
Self(raw.to_owned())
|
||||
}
|
||||
}
|
||||
|
||||
impl Domain {
|
||||
/// length of the User ID in bytes
|
||||
|
|
|
@ -16,6 +16,23 @@ use ::std::{
|
|||
vec,
|
||||
};
|
||||
|
||||
/// Key used by a server during the handshake
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ServerKey {
|
||||
pub id: KeyID,
|
||||
pub priv_key: PrivKey,
|
||||
pub pub_key: PubKey,
|
||||
}
|
||||
|
||||
/// Authentication Server information and keys
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AuthServer {
|
||||
/// fqdn of the authentication server
|
||||
pub fqdn: crate::auth::Domain,
|
||||
/// list of key ids enabled for this domain
|
||||
pub keys: Vec<KeyID>,
|
||||
}
|
||||
|
||||
/// Main config for libFenrir
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Config {
|
||||
|
@ -34,8 +51,12 @@ pub struct Config {
|
|||
pub hkdfs: Vec<HkdfKind>,
|
||||
/// Supported Ciphers
|
||||
pub ciphers: Vec<CipherKind>,
|
||||
/// list of authentication servers
|
||||
/// clients will have this empty
|
||||
pub servers: Vec<AuthServer>,
|
||||
/// list of public/private keys
|
||||
pub keys: Vec<(KeyID, PrivKey, PubKey)>,
|
||||
/// clients should have this empty
|
||||
pub server_keys: Vec<ServerKey>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
|
@ -56,7 +77,8 @@ impl Default for Config {
|
|||
key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(),
|
||||
hkdfs: [HkdfKind::Sha3].to_vec(),
|
||||
ciphers: [CipherKind::XChaCha20Poly1305].to_vec(),
|
||||
keys: Vec::new(),
|
||||
servers: Vec::new(),
|
||||
server_keys: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -113,10 +113,14 @@ impl Req {
|
|||
+ self.exchange_key.kind().pub_len()
|
||||
}
|
||||
/// return the total length of the cleartext data
|
||||
pub fn encrypted_length(&self) -> usize {
|
||||
pub fn encrypted_length(
|
||||
&self,
|
||||
head_len: HeadLen,
|
||||
tag_len: TagLen,
|
||||
) -> usize {
|
||||
match &self.data {
|
||||
ReqInner::ClearText(data) => data.len(),
|
||||
_ => 0,
|
||||
ReqInner::ClearText(data) => data.len() + head_len.0 + tag_len.0,
|
||||
ReqInner::CipherText(length) => *length,
|
||||
}
|
||||
}
|
||||
/// actual length of the directory synchronized request
|
||||
|
@ -177,11 +181,16 @@ impl super::HandshakeParsing for Req {
|
|||
Some(cipher) => cipher,
|
||||
None => return Err(Error::Parsing),
|
||||
};
|
||||
let (exchange_key, len) = match ExchangePubKey::deserialize(&raw[5..]) {
|
||||
const CURR_SIZE: usize = KeyID::len()
|
||||
+ KeyExchangeKind::len()
|
||||
+ HkdfKind::len()
|
||||
+ CipherKind::len();
|
||||
let (exchange_key, len) =
|
||||
match ExchangePubKey::deserialize(&raw[CURR_SIZE..]) {
|
||||
Ok(exchange_key) => exchange_key,
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
let data = ReqInner::CipherText(raw.len() - (5 + len));
|
||||
let data = ReqInner::CipherText(raw.len() - (CURR_SIZE + len));
|
||||
Ok(HandshakeData::DirSync(DirSync::Req(Self {
|
||||
key_id,
|
||||
exchange,
|
||||
|
@ -436,7 +445,7 @@ impl super::HandshakeParsing for Resp {
|
|||
return Err(Error::NotEnoughData);
|
||||
}
|
||||
let client_key_id: KeyID =
|
||||
KeyID(u16::from_le_bytes(raw[0..2].try_into().unwrap()));
|
||||
KeyID(u16::from_le_bytes(raw[0..KeyID::len()].try_into().unwrap()));
|
||||
Ok(HandshakeData::DirSync(DirSync::Resp(Self {
|
||||
client_key_id,
|
||||
data: RespInner::CipherText(raw[KeyID::len()..].len()),
|
||||
|
@ -453,10 +462,16 @@ impl Resp {
|
|||
+ KeyID::len()
|
||||
}
|
||||
/// return the total length of the cleartext data
|
||||
pub fn encrypted_length(&self) -> usize {
|
||||
pub fn encrypted_length(
|
||||
&self,
|
||||
head_len: HeadLen,
|
||||
tag_len: TagLen,
|
||||
) -> usize {
|
||||
match &self.data {
|
||||
RespInner::ClearText(_data) => RespData::len(),
|
||||
_ => 0,
|
||||
RespInner::ClearText(_data) => {
|
||||
RespData::len() + head_len.0 + tag_len.0
|
||||
}
|
||||
RespInner::CipherText(len) => *len,
|
||||
}
|
||||
}
|
||||
/// Total length of the response handshake
|
||||
|
@ -471,8 +486,9 @@ impl Resp {
|
|||
_tag_len: TagLen,
|
||||
out: &mut [u8],
|
||||
) {
|
||||
out[0..2].copy_from_slice(&self.client_key_id.0.to_le_bytes());
|
||||
let start_data = 2 + head_len.0;
|
||||
out[0..KeyID::len()]
|
||||
.copy_from_slice(&self.client_key_id.0.to_le_bytes());
|
||||
let start_data = KeyID::len() + head_len.0;
|
||||
let end_data = start_data + self.data.len();
|
||||
self.data.serialize(&mut out[start_data..end_data]);
|
||||
}
|
||||
|
|
|
@ -37,6 +37,12 @@ pub enum Error {
|
|||
/// Too many client handshakes currently running
|
||||
#[error("Too many client handshakes")]
|
||||
TooManyClientHandshakes,
|
||||
/// generic internal error
|
||||
#[error("Internal tracking error")]
|
||||
InternalTracking,
|
||||
/// Handshake Timeout
|
||||
#[error("Handshake timeout")]
|
||||
Timeout,
|
||||
}
|
||||
|
||||
/// List of possible handshakes
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
//! Handhsake handling
|
||||
|
||||
use crate::{
|
||||
auth::ServiceID,
|
||||
auth::{Domain, ServiceID},
|
||||
connection::{
|
||||
self,
|
||||
handshake::{self, Error, Handshake},
|
||||
Connection, IDRecv,
|
||||
Connection, IDRecv, IDSend,
|
||||
},
|
||||
enc::{
|
||||
self,
|
||||
|
@ -16,16 +16,23 @@ use crate::{
|
|||
inner::ThreadTracker,
|
||||
};
|
||||
|
||||
use ::tokio::sync::oneshot;
|
||||
|
||||
pub(crate) struct HandshakeServer {
|
||||
pub id: KeyID,
|
||||
pub key: PrivKey,
|
||||
pub domains: Vec<Domain>,
|
||||
}
|
||||
|
||||
pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>;
|
||||
|
||||
pub(crate) struct HandshakeClient {
|
||||
pub service_id: ServiceID,
|
||||
pub service_conn_id: IDRecv,
|
||||
pub connection: Connection,
|
||||
pub timeout: Option<::tokio::task::JoinHandle<()>>,
|
||||
pub answer: oneshot::Sender<ConnectAnswer>,
|
||||
pub srv_key_id: KeyID,
|
||||
}
|
||||
|
||||
/// Tracks the keys used by the client and the handshake
|
||||
|
@ -73,7 +80,10 @@ impl HandshakeClientList {
|
|||
service_id: ServiceID,
|
||||
service_conn_id: IDRecv,
|
||||
connection: Connection,
|
||||
) -> Result<(KeyID, &mut HandshakeClient), ()> {
|
||||
answer: oneshot::Sender<ConnectAnswer>,
|
||||
srv_key_id: KeyID,
|
||||
) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender<ConnectAnswer>>
|
||||
{
|
||||
let maybe_free_key_idx =
|
||||
self.used.iter().enumerate().find_map(|(idx, bmap)| {
|
||||
match bmap.first_false_index() {
|
||||
|
@ -85,7 +95,7 @@ impl HandshakeClientList {
|
|||
Some((idx, false_idx)) => {
|
||||
let free_key_idx = idx * 1024 + false_idx;
|
||||
if free_key_idx > KeyID::MAX as usize {
|
||||
return Err(());
|
||||
return Err(answer);
|
||||
}
|
||||
self.used[idx].set(false_idx, true);
|
||||
free_key_idx
|
||||
|
@ -107,6 +117,8 @@ impl HandshakeClientList {
|
|||
service_conn_id,
|
||||
connection,
|
||||
timeout: None,
|
||||
answer,
|
||||
srv_key_id,
|
||||
});
|
||||
Ok((
|
||||
KeyID(free_key_idx as u16),
|
||||
|
@ -136,6 +148,10 @@ pub(crate) struct ClientConnectInfo {
|
|||
pub handshake: Handshake,
|
||||
/// Connection
|
||||
pub connection: Connection,
|
||||
/// where to wake up the waiting client
|
||||
pub answer: oneshot::Sender<ConnectAnswer>,
|
||||
/// server public key id that we used on the handshake
|
||||
pub srv_key_id: KeyID,
|
||||
}
|
||||
/// Intermediate actions to be taken while parsing the handshake
|
||||
#[derive(Debug)]
|
||||
|
@ -177,10 +193,42 @@ impl HandshakeTracker {
|
|||
hshake_cli: HandshakeClientList::new(),
|
||||
}
|
||||
}
|
||||
pub(crate) fn add_server(&mut self, id: KeyID, key: PrivKey) {
|
||||
self.keys_srv.push(HandshakeServer { id, key });
|
||||
self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0));
|
||||
pub(crate) fn add_server_key(
|
||||
&mut self,
|
||||
id: KeyID,
|
||||
key: PrivKey,
|
||||
) -> Result<(), ()> {
|
||||
if self.keys_srv.iter().find(|&k| k.id == id).is_some() {
|
||||
return Err(());
|
||||
}
|
||||
self.keys_srv.push(HandshakeServer {
|
||||
id,
|
||||
key,
|
||||
domains: Vec::new(),
|
||||
});
|
||||
self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0));
|
||||
Ok(())
|
||||
}
|
||||
pub(crate) fn add_server_domain(
|
||||
&mut self,
|
||||
domain: &Domain,
|
||||
key_ids: &[KeyID],
|
||||
) -> Result<(), ()> {
|
||||
// check that all the key ids are present
|
||||
for id in key_ids.iter() {
|
||||
if self.keys_srv.iter().find(|k| k.id == *id).is_none() {
|
||||
return Err(());
|
||||
}
|
||||
}
|
||||
// add the domain to those keys
|
||||
for id in key_ids.iter() {
|
||||
if let Some(srv) = self.keys_srv.iter_mut().find(|k| k.id == *id) {
|
||||
srv.domains.push(domain.clone());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn add_client(
|
||||
&mut self,
|
||||
priv_key: PrivKey,
|
||||
|
@ -188,20 +236,32 @@ impl HandshakeTracker {
|
|||
service_id: ServiceID,
|
||||
service_conn_id: IDRecv,
|
||||
connection: Connection,
|
||||
) -> Result<(KeyID, &mut HandshakeClient), ()> {
|
||||
answer: oneshot::Sender<ConnectAnswer>,
|
||||
srv_key_id: KeyID,
|
||||
) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender<ConnectAnswer>>
|
||||
{
|
||||
self.hshake_cli.add(
|
||||
priv_key,
|
||||
pub_key,
|
||||
service_id,
|
||||
service_conn_id,
|
||||
connection,
|
||||
answer,
|
||||
srv_key_id,
|
||||
)
|
||||
}
|
||||
pub(crate) fn remove_client(
|
||||
&mut self,
|
||||
key_id: KeyID,
|
||||
) -> Option<HandshakeClient> {
|
||||
self.hshake_cli.remove(key_id)
|
||||
}
|
||||
pub(crate) fn timeout_client(
|
||||
&mut self,
|
||||
key_id: KeyID,
|
||||
) -> Option<[IDRecv; 2]> {
|
||||
if let Some(hshake) = self.hshake_cli.remove(key_id) {
|
||||
let _ = hshake.answer.send(Err(Error::Timeout.into()));
|
||||
Some([hshake.connection.id_recv, hshake.service_conn_id])
|
||||
} else {
|
||||
None
|
||||
|
@ -257,9 +317,16 @@ impl HandshakeTracker {
|
|||
let cipher_recv = CipherRecv::new(req.cipher, secret_recv);
|
||||
use crate::enc::sym::AAD;
|
||||
let aad = AAD(&mut []); // no aad for now
|
||||
|
||||
let encrypt_from = req.encrypted_offset();
|
||||
let encrypt_to = encrypt_from
|
||||
+ req.encrypted_length(
|
||||
cipher_recv.nonce_len(),
|
||||
cipher_recv.tag_len(),
|
||||
);
|
||||
match cipher_recv.decrypt(
|
||||
aad,
|
||||
&mut handshake_raw[req.encrypted_offset()..],
|
||||
&mut handshake_raw[encrypt_from..encrypt_to],
|
||||
) {
|
||||
Ok(cleartext) => {
|
||||
req.data.deserialize_as_cleartext(cleartext)?;
|
||||
|
@ -292,9 +359,13 @@ impl HandshakeTracker {
|
|||
use crate::enc::sym::AAD;
|
||||
// no aad for now
|
||||
let aad = AAD(&mut []);
|
||||
let mut raw_data = &mut handshake_raw[resp
|
||||
.encrypted_offset()
|
||||
..(resp.encrypted_offset() + resp.encrypted_length())];
|
||||
let data_from = resp.encrypted_offset();
|
||||
let data_to = data_from
|
||||
+ resp.encrypted_length(
|
||||
cipher_recv.nonce_len(),
|
||||
cipher_recv.tag_len(),
|
||||
);
|
||||
let mut raw_data = &mut handshake_raw[data_from..data_to];
|
||||
match cipher_recv.decrypt(aad, &mut raw_data) {
|
||||
Ok(cleartext) => {
|
||||
resp.data.deserialize_as_cleartext(&cleartext)?;
|
||||
|
@ -314,6 +385,8 @@ impl HandshakeTracker {
|
|||
service_connection_id: hshake.service_conn_id,
|
||||
handshake,
|
||||
connection: hshake.connection,
|
||||
answer: hshake.answer,
|
||||
srv_key_id: hshake.srv_key_id,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
|
|
@ -113,7 +113,11 @@ pub(crate) struct ConnList {
|
|||
|
||||
impl ConnList {
|
||||
pub(crate) fn new(thread_id: ThreadTracker) -> Self {
|
||||
let bitmap_id = ::bitmaps::Bitmap::<1024>::new();
|
||||
let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new();
|
||||
if thread_id.id == 0 {
|
||||
// make sure we don't count the Handshake ID
|
||||
bitmap_id.set(0, true);
|
||||
}
|
||||
const INITIAL_CAP: usize = 128;
|
||||
let mut ret = Self {
|
||||
thread_id,
|
||||
|
@ -199,13 +203,6 @@ impl ConnList {
|
|||
}
|
||||
}
|
||||
|
||||
use ::std::collections::HashMap;
|
||||
|
||||
enum MapEntry {
|
||||
Present(IDSend),
|
||||
Reserved,
|
||||
}
|
||||
|
||||
/// return wether we already have a connection, we are waiting for one, or you
|
||||
/// can start one
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
@ -218,6 +215,12 @@ pub(crate) enum Reservation {
|
|||
Reserved,
|
||||
}
|
||||
|
||||
enum MapEntry {
|
||||
Present(IDSend),
|
||||
Reserved,
|
||||
}
|
||||
use ::std::collections::HashMap;
|
||||
|
||||
/// Link the public key of the authentication server to a connection id
|
||||
/// so that we can reuse that connection to ask for more authentications
|
||||
///
|
||||
|
@ -229,16 +232,16 @@ pub(crate) enum Reservation {
|
|||
/// * wait for the connection to finish
|
||||
/// * remove all those reservations, exept the one key that actually succeded
|
||||
/// While searching, we return a connection ID if just one key is a match
|
||||
// TODO: can we shard this per-core by hashing the pubkey? or domain? or...???
|
||||
// This needs a mutex and it will be our goeal to avoid any synchronization
|
||||
pub(crate) struct AuthServerConnections {
|
||||
conn_map: HashMap<PubKey, MapEntry>,
|
||||
next_reservation: u64,
|
||||
}
|
||||
|
||||
impl AuthServerConnections {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
conn_map: HashMap::with_capacity(32),
|
||||
next_reservation: 0,
|
||||
}
|
||||
}
|
||||
/// add an ID to the reserved spot,
|
||||
|
|
|
@ -1,40 +1,10 @@
|
|||
//! Socket related types and functions
|
||||
|
||||
use ::std::{net::SocketAddr, sync::Arc, vec::Vec};
|
||||
use ::std::net::SocketAddr;
|
||||
use ::tokio::{net::UdpSocket, task::JoinHandle};
|
||||
|
||||
/// Pair to easily track the socket and its async listening handle
|
||||
pub type SocketTracker =
|
||||
(Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
|
||||
|
||||
/// async free socket list
|
||||
pub(crate) struct SocketList {
|
||||
pub list: Vec<SocketTracker>,
|
||||
}
|
||||
impl SocketList {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self { list: Vec::new() }
|
||||
}
|
||||
pub(crate) fn rm_all(&mut self) -> Self {
|
||||
let mut old_list = Vec::new();
|
||||
::core::mem::swap(&mut self.list, &mut old_list);
|
||||
Self { list: old_list }
|
||||
}
|
||||
pub(crate) async fn add_socket(
|
||||
&mut self,
|
||||
socket: Arc<UdpSocket>,
|
||||
handle: JoinHandle<::std::io::Result<()>>,
|
||||
) {
|
||||
let arc_handle = Arc::new(handle);
|
||||
self.list.push((socket, arc_handle));
|
||||
}
|
||||
/// This method assumes no other `add_sockets` are being run
|
||||
pub(crate) async fn stop_all(self) {
|
||||
for (_socket, mut handle) in self.list.into_iter() {
|
||||
let _ = Arc::get_mut(&mut handle).unwrap().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
pub type SocketTracker = (SocketAddr, JoinHandle<::std::io::Result<()>>);
|
||||
|
||||
/// Strong typedef for a client socket address
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
|
@ -53,7 +23,7 @@ fn enable_sock_opt(
|
|||
unsafe {
|
||||
#[allow(trivial_casts)]
|
||||
let val = &value as *const _ as *const ::libc::c_void;
|
||||
let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t;
|
||||
let size = ::core::mem::size_of_val(&value) as ::libc::socklen_t;
|
||||
// always clear the error bit before doing a new syscall
|
||||
let _ = ::std::io::Error::last_os_error();
|
||||
let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size);
|
||||
|
@ -64,23 +34,107 @@ fn enable_sock_opt(
|
|||
Ok(())
|
||||
}
|
||||
/// Add an async udp listener
|
||||
pub async fn bind_udp(sock: SocketAddr) -> ::std::io::Result<UdpSocket> {
|
||||
let socket = UdpSocket::bind(sock).await?;
|
||||
pub async fn bind_udp(addr: SocketAddr) -> ::std::io::Result<UdpSocket> {
|
||||
// I know, kind of a mess. but I really wanted SO_REUSE{ADDR,PORT} and
|
||||
// no-fragmenting stuff.
|
||||
// I also did not want to load another library for this.
|
||||
// feel free to simplify,
|
||||
// especially if we can avoid libc and other libraries
|
||||
// we currently use libc because it's a dependency of many other deps
|
||||
|
||||
use ::std::os::fd::AsRawFd;
|
||||
let fd = socket.as_raw_fd();
|
||||
// can be useful later on for reloads
|
||||
enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?;
|
||||
enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?;
|
||||
let fd: ::std::os::fd::RawFd = {
|
||||
let domain = if addr.is_ipv6() {
|
||||
::libc::AF_INET6
|
||||
} else {
|
||||
::libc::AF_INET
|
||||
};
|
||||
#[allow(unsafe_code)]
|
||||
let tmp = unsafe { ::libc::socket(domain, ::libc::SOCK_DGRAM, 0) };
|
||||
let lasterr = ::std::io::Error::last_os_error();
|
||||
if tmp == -1 {
|
||||
return Err(lasterr);
|
||||
}
|
||||
tmp.into()
|
||||
};
|
||||
|
||||
if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1) {
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
::libc::close(fd);
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1) {
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
::libc::close(fd);
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
// We will do path MTU discovery by ourselves,
|
||||
// always set the "don't fragment" bit
|
||||
if sock.is_ipv6() {
|
||||
enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?;
|
||||
let res = if addr.is_ipv6() {
|
||||
enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)
|
||||
} else {
|
||||
// FIXME: linux only
|
||||
enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)?;
|
||||
enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)
|
||||
};
|
||||
if let Err(e) = res {
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
::libc::close(fd);
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
// manually convert rust SockAddr to C sockaddr
|
||||
#[allow(unsafe_code, trivial_casts, trivial_numeric_casts)]
|
||||
{
|
||||
let bind_ret = match addr {
|
||||
SocketAddr::V4(s4) => {
|
||||
let ip4: u32 = (*s4.ip()).into();
|
||||
let bind_addr = ::libc::sockaddr_in {
|
||||
sin_family: ::libc::AF_INET as u16,
|
||||
sin_port: s4.port().to_be(),
|
||||
sin_addr: ::libc::in_addr { s_addr: ip4 },
|
||||
sin_zero: [0; 8],
|
||||
};
|
||||
unsafe {
|
||||
let c_addr =
|
||||
&bind_addr as *const _ as *const ::libc::sockaddr;
|
||||
::libc::bind(fd, c_addr, 16)
|
||||
}
|
||||
}
|
||||
SocketAddr::V6(s6) => {
|
||||
let ip6: [u8; 16] = (*s6.ip()).octets();
|
||||
let bind_addr = ::libc::sockaddr_in6 {
|
||||
sin6_family: ::libc::AF_INET6 as u16,
|
||||
sin6_port: s6.port().to_be(),
|
||||
sin6_flowinfo: 0,
|
||||
sin6_addr: ::libc::in6_addr { s6_addr: ip6 },
|
||||
sin6_scope_id: 0,
|
||||
};
|
||||
unsafe {
|
||||
let c_addr =
|
||||
&bind_addr as *const _ as *const ::libc::sockaddr;
|
||||
::libc::bind(fd, c_addr, 24)
|
||||
}
|
||||
}
|
||||
};
|
||||
let lasterr = ::std::io::Error::last_os_error();
|
||||
if bind_ret != 0 {
|
||||
unsafe {
|
||||
::libc::close(fd);
|
||||
}
|
||||
return Err(lasterr);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(socket)
|
||||
use ::std::os::fd::FromRawFd;
|
||||
#[allow(unsafe_code)]
|
||||
let std_sock = unsafe { ::std::net::UdpSocket::from_raw_fd(fd) };
|
||||
std_sock.set_nonblocking(true)?;
|
||||
::tracing::debug!("Listening udp sock: {}", std_sock.local_addr().unwrap());
|
||||
|
||||
Ok(UdpSocket::from_std(std_sock)?)
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ pub enum KeyExchangeKind {
|
|||
}
|
||||
impl KeyExchangeKind {
|
||||
/// The serialize length of the field
|
||||
pub fn len() -> usize {
|
||||
pub const fn len() -> usize {
|
||||
1
|
||||
}
|
||||
/// Build a new keypair for key exchange
|
||||
|
|
|
@ -4,6 +4,8 @@ pub mod asym;
|
|||
mod errors;
|
||||
pub mod hkdf;
|
||||
pub mod sym;
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use errors::Error;
|
||||
|
||||
|
|
|
@ -25,12 +25,11 @@ pub enum CipherKind {
|
|||
|
||||
impl CipherKind {
|
||||
/// length of the serialized id for the cipher kind field
|
||||
pub fn len() -> usize {
|
||||
pub const fn len() -> usize {
|
||||
1
|
||||
}
|
||||
/// required length of the nonce
|
||||
pub fn nonce_len(&self) -> HeadLen {
|
||||
// TODO: how the hell do I take this from ::chacha20poly1305?
|
||||
HeadLen(Nonce::len())
|
||||
}
|
||||
/// required length of the key
|
||||
|
@ -92,10 +91,7 @@ impl Cipher {
|
|||
}
|
||||
fn nonce_len(&self) -> HeadLen {
|
||||
match self {
|
||||
Cipher::XChaCha20Poly1305(_) => {
|
||||
// TODO: how the hell do I take this from ::chacha20poly1305?
|
||||
HeadLen(::ring::aead::CHACHA20_POLY1305.nonce_len())
|
||||
}
|
||||
Cipher::XChaCha20Poly1305(_) => HeadLen(Nonce::len()),
|
||||
}
|
||||
}
|
||||
fn tag_len(&self) -> TagLen {
|
||||
|
@ -117,10 +113,13 @@ impl Cipher {
|
|||
aead::generic_array::GenericArray, AeadInPlace,
|
||||
};
|
||||
let final_len: usize = {
|
||||
// FIXME: check min data length
|
||||
let (nonce_bytes, data_and_tag) = raw_data.split_at_mut(13);
|
||||
if raw_data.len() <= self.overhead() {
|
||||
return Err(Error::NotEnoughData(raw_data.len()));
|
||||
}
|
||||
let (nonce_bytes, data_and_tag) =
|
||||
raw_data.split_at_mut(Nonce::len());
|
||||
let (data_notag, tag_bytes) = data_and_tag.split_at_mut(
|
||||
data_and_tag.len() + 1
|
||||
data_and_tag.len()
|
||||
- ::ring::aead::CHACHA20_POLY1305.tag_len(),
|
||||
);
|
||||
let nonce = GenericArray::from_slice(nonce_bytes);
|
||||
|
@ -172,10 +171,7 @@ impl Cipher {
|
|||
&mut data[Nonce::len()..data_len_notag],
|
||||
) {
|
||||
Ok(tag) => {
|
||||
data[data_len_notag..]
|
||||
// add tag
|
||||
//data.get_tag_slice()
|
||||
.copy_from_slice(tag.as_slice());
|
||||
data[data_len_notag..].copy_from_slice(tag.as_slice());
|
||||
Ok(())
|
||||
}
|
||||
Err(_) => Err(Error::Encrypt),
|
||||
|
@ -205,6 +201,10 @@ impl CipherRecv {
|
|||
pub fn nonce_len(&self) -> HeadLen {
|
||||
self.0.nonce_len()
|
||||
}
|
||||
/// Get the length of the nonce for this cipher
|
||||
pub fn tag_len(&self) -> TagLen {
|
||||
self.0.tag_len()
|
||||
}
|
||||
/// Decrypt a paket. Nonce and Tag are taken from the packet,
|
||||
/// while you need to provide AAD (Additional Authenticated Data)
|
||||
pub fn decrypt<'a>(
|
||||
|
@ -285,7 +285,7 @@ struct NonceNum {
|
|||
#[repr(C)]
|
||||
pub union Nonce {
|
||||
num: NonceNum,
|
||||
raw: [u8; 12],
|
||||
raw: [u8; Self::len()],
|
||||
}
|
||||
|
||||
impl ::core::fmt::Debug for Nonce {
|
||||
|
@ -303,13 +303,17 @@ impl ::core::fmt::Debug for Nonce {
|
|||
impl Nonce {
|
||||
/// Generate a new random Nonce
|
||||
pub fn new(rand: &Random) -> Self {
|
||||
let mut raw = [0; 12];
|
||||
let mut raw = [0; Self::len()];
|
||||
rand.fill(&mut raw);
|
||||
Self { raw }
|
||||
}
|
||||
/// Length of this nonce in bytes
|
||||
pub const fn len() -> usize {
|
||||
return 12;
|
||||
// FIXME: was:12. xchacha20poly1305 requires 24.
|
||||
// but we should change keys much earlier than that, and our
|
||||
// nonces are not random, but sequential.
|
||||
// we should change keys every 2^30 bytes to be sure (stream max window)
|
||||
return 24;
|
||||
}
|
||||
/// Get reference to the nonce bytes
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
|
@ -319,7 +323,7 @@ impl Nonce {
|
|||
}
|
||||
}
|
||||
/// Create Nonce from array
|
||||
pub fn from_slice(raw: [u8; 12]) -> Self {
|
||||
pub fn from_slice(raw: [u8; Self::len()]) -> Self {
|
||||
Self { raw }
|
||||
}
|
||||
/// Go to the next nonce
|
||||
|
@ -336,6 +340,7 @@ impl Nonce {
|
|||
}
|
||||
|
||||
/// Synchronize the mutex acess with a nonce for multithread safety
|
||||
// TODO: remove mutex, not needed anymore
|
||||
#[derive(Debug)]
|
||||
pub struct NonceSync {
|
||||
nonce: ::std::sync::Mutex<Nonce>,
|
||||
|
|
|
@ -0,0 +1,135 @@
|
|||
use crate::{
|
||||
auth,
|
||||
connection::{handshake::*, ID},
|
||||
enc::{self, asym::KeyID},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_simple_encrypt_decrypt() {
|
||||
let rand = enc::Random::new();
|
||||
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
|
||||
let secret = enc::Secret::new_rand(&rand);
|
||||
let secret2 = secret.clone();
|
||||
|
||||
let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
|
||||
let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2);
|
||||
|
||||
let mut data = Vec::new();
|
||||
let tot_len = cipher_recv.nonce_len().0 + 1234 + cipher_recv.tag_len().0;
|
||||
data.resize(tot_len, 0);
|
||||
rand.fill(&mut data);
|
||||
data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]);
|
||||
let last = data.len() - cipher_recv.tag_len().0;
|
||||
data[last..].copy_from_slice(&[0; 16]);
|
||||
let orig = data.clone();
|
||||
let raw_aad: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
|
||||
let aad = enc::sym::AAD(&raw_aad[..]);
|
||||
let aad2 = enc::sym::AAD(&raw_aad[..]);
|
||||
if cipher_send.encrypt(aad, &mut data).is_err() {
|
||||
assert!(false, "Encrypt failed");
|
||||
}
|
||||
if cipher_recv.decrypt(aad2, &mut data).is_err() {
|
||||
assert!(false, "Decrypt failed");
|
||||
}
|
||||
data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]);
|
||||
let last = data.len() - cipher_recv.tag_len().0;
|
||||
data[last..].copy_from_slice(&[0; 16]);
|
||||
assert!(orig == data, "DIFFERENT!\n{:?}\n{:?}\n", orig, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encrypt_decrypt() {
|
||||
let rand = enc::Random::new();
|
||||
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
|
||||
let secret = enc::Secret::new_rand(&rand);
|
||||
let secret2 = secret.clone();
|
||||
|
||||
let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
|
||||
let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2);
|
||||
let nonce_len = cipher_recv.nonce_len();
|
||||
let tag_len = cipher_recv.tag_len();
|
||||
|
||||
let service_key = enc::Secret::new_rand(&rand);
|
||||
|
||||
let data = dirsync::RespInner::ClearText(dirsync::RespData {
|
||||
client_nonce: dirsync::Nonce::new(&rand),
|
||||
id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()),
|
||||
service_connection_id: ID::ID(
|
||||
::core::num::NonZeroU64::new(434343).unwrap(),
|
||||
),
|
||||
service_key,
|
||||
});
|
||||
|
||||
let resp = dirsync::Resp {
|
||||
client_key_id: KeyID(4444),
|
||||
data,
|
||||
};
|
||||
let encrypt_from = resp.encrypted_offset();
|
||||
let encrypt_to = encrypt_from + resp.encrypted_length(nonce_len, tag_len);
|
||||
|
||||
let h_resp =
|
||||
Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Resp(resp)));
|
||||
|
||||
let mut bytes = Vec::<u8>::with_capacity(
|
||||
h_resp.len(cipher.nonce_len(), cipher.tag_len()),
|
||||
);
|
||||
bytes.resize(h_resp.len(cipher.nonce_len(), cipher.tag_len()), 0);
|
||||
h_resp.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes);
|
||||
|
||||
let raw_aad: [u8; 7] = [0, 1, 2, 3, 4, 5, 6];
|
||||
let aad = enc::sym::AAD(&raw_aad[..]);
|
||||
let aad2 = enc::sym::AAD(&raw_aad[..]);
|
||||
|
||||
let pre_encrypt = bytes.clone();
|
||||
// encrypt
|
||||
if cipher_send
|
||||
.encrypt(aad, &mut bytes[encrypt_from..encrypt_to])
|
||||
.is_err()
|
||||
{
|
||||
assert!(false, "Encrypt failed");
|
||||
}
|
||||
if cipher_recv
|
||||
.decrypt(aad2, &mut bytes[encrypt_from..encrypt_to])
|
||||
.is_err()
|
||||
{
|
||||
assert!(false, "Decrypt failed");
|
||||
}
|
||||
// make sure Nonce and Tag are 0
|
||||
bytes[encrypt_from..(encrypt_from + nonce_len.0)].copy_from_slice(&[0; 24]);
|
||||
let tag_from = encrypt_to - tag_len.0;
|
||||
bytes[tag_from..(tag_from + tag_len.0)].copy_from_slice(&[0; 16]);
|
||||
assert!(
|
||||
pre_encrypt == bytes,
|
||||
"{}|{}=\n{:?}\n{:?}",
|
||||
encrypt_from,
|
||||
encrypt_to,
|
||||
pre_encrypt,
|
||||
bytes
|
||||
);
|
||||
|
||||
// decrypt
|
||||
|
||||
let mut deserialized = match Handshake::deserialize(&bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
assert!(false, "{}", e.to_string());
|
||||
return;
|
||||
}
|
||||
};
|
||||
// reparse
|
||||
if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) =
|
||||
&mut deserialized.data
|
||||
{
|
||||
let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0;
|
||||
if let Err(e) = r_a.data.deserialize_as_cleartext(
|
||||
&bytes[enc_start..(bytes.len() - cipher.tag_len().0)],
|
||||
) {
|
||||
assert!(false, "DirSync Resp Inner serialize: {}", e.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
assert!(
|
||||
deserialized == h_resp,
|
||||
"DirSync Resp (de)serialization not working",
|
||||
);
|
||||
}
|
|
@ -37,7 +37,7 @@ pub(crate) struct RawUdp {
|
|||
}
|
||||
|
||||
pub(crate) struct ConnectInfo {
|
||||
pub answer: oneshot::Sender<Result<(PubKey, IDSend), crate::Error>>,
|
||||
pub answer: oneshot::Sender<handshake::tracker::ConnectAnswer>,
|
||||
pub resolved: dnssec::Record,
|
||||
pub service_id: ServiceID,
|
||||
pub domain: Domain,
|
||||
|
@ -57,14 +57,15 @@ pub(crate) enum WorkAnswer {
|
|||
}
|
||||
|
||||
/// Actual worker implementation.
|
||||
pub(crate) struct Worker {
|
||||
#[allow(missing_debug_implementations)]
|
||||
pub struct Worker {
|
||||
cfg: Config,
|
||||
thread_id: ThreadTracker,
|
||||
// PERF: rand uses syscalls. how to do that async?
|
||||
rand: Random,
|
||||
stop_working: crate::StopWorkingRecvCh,
|
||||
token_check: Option<Arc<Mutex<TokenChecker>>>,
|
||||
sockets: Vec<UdpSocket>,
|
||||
sockets: Vec<Arc<UdpSocket>>,
|
||||
queue: ::async_channel::Receiver<Work>,
|
||||
queue_timeouts_recv: mpsc::UnboundedReceiver<Work>,
|
||||
queue_timeouts_send: mpsc::UnboundedSender<Work>,
|
||||
|
@ -73,64 +74,18 @@ pub(crate) struct Worker {
|
|||
handshakes: HandshakeTracker,
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Send for Worker {}
|
||||
|
||||
impl Worker {
|
||||
pub(crate) async fn new_and_loop(
|
||||
cfg: Config,
|
||||
thread_id: ThreadTracker,
|
||||
stop_working: crate::StopWorkingRecvCh,
|
||||
token_check: Option<Arc<Mutex<TokenChecker>>>,
|
||||
socket_addrs: Vec<::std::net::SocketAddr>,
|
||||
queue: ::async_channel::Receiver<Work>,
|
||||
) -> ::std::io::Result<()> {
|
||||
// TODO: get a channel to send back information, and send the error
|
||||
let mut worker = Self::new(
|
||||
cfg,
|
||||
thread_id,
|
||||
stop_working,
|
||||
token_check,
|
||||
socket_addrs,
|
||||
queue,
|
||||
)
|
||||
.await?;
|
||||
worker.work_loop().await;
|
||||
Ok(())
|
||||
}
|
||||
pub(crate) async fn new(
|
||||
mut cfg: Config,
|
||||
thread_id: ThreadTracker,
|
||||
stop_working: crate::StopWorkingRecvCh,
|
||||
token_check: Option<Arc<Mutex<TokenChecker>>>,
|
||||
socket_addrs: Vec<::std::net::SocketAddr>,
|
||||
sockets: Vec<Arc<UdpSocket>>,
|
||||
queue: ::async_channel::Receiver<Work>,
|
||||
) -> ::std::io::Result<Self> {
|
||||
// bind all sockets again so that we can easily
|
||||
// send without sharing resources
|
||||
// in the future we will want to have a thread-local listener too,
|
||||
// but before that we need ebpf to pin a connection to a thread
|
||||
// directly from the kernel
|
||||
let mut sock_set = ::tokio::task::JoinSet::new();
|
||||
socket_addrs.into_iter().for_each(|s_addr| {
|
||||
sock_set.spawn(async move {
|
||||
let socket =
|
||||
connection::socket::bind_udp(s_addr.clone()).await?;
|
||||
Ok(socket)
|
||||
});
|
||||
});
|
||||
// make sure we either add all of them, or none
|
||||
let mut sockets = Vec::with_capacity(cfg.listen.len());
|
||||
while let Some(join_res) = sock_set.join_next().await {
|
||||
match join_res {
|
||||
Ok(s_res) => match s_res {
|
||||
Ok(sock) => sockets.push(sock),
|
||||
Err(e) => {
|
||||
::tracing::error!("Can't rebind socket");
|
||||
return Err(e);
|
||||
}
|
||||
},
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
let (queue_timeouts_send, queue_timeouts_recv) =
|
||||
mpsc::unbounded_channel();
|
||||
let mut handshakes = HandshakeTracker::new(
|
||||
|
@ -138,11 +93,24 @@ impl Worker {
|
|||
cfg.ciphers.clone(),
|
||||
cfg.key_exchanges.clone(),
|
||||
);
|
||||
let mut keys = Vec::new();
|
||||
let mut server_keys = Vec::new();
|
||||
// make sure the keys are no longer in the config
|
||||
::core::mem::swap(&mut keys, &mut cfg.keys);
|
||||
for k in keys.into_iter() {
|
||||
handshakes.add_server(k.0, k.1);
|
||||
::core::mem::swap(&mut server_keys, &mut cfg.server_keys);
|
||||
for k in server_keys.into_iter() {
|
||||
if handshakes.add_server_key(k.id, k.priv_key).is_err() {
|
||||
return Err(::std::io::Error::new(
|
||||
::std::io::ErrorKind::AlreadyExists,
|
||||
"You can't use the same KeyID for multiple keys",
|
||||
));
|
||||
}
|
||||
}
|
||||
for srv in cfg.servers.iter() {
|
||||
if handshakes.add_server_domain(&srv.fqdn, &srv.keys).is_err() {
|
||||
return Err(::std::io::Error::new(
|
||||
::std::io::ErrorKind::NotFound,
|
||||
"Specified a KeyID that we don't have",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
|
@ -160,12 +128,15 @@ impl Worker {
|
|||
handshakes,
|
||||
})
|
||||
}
|
||||
pub(crate) async fn work_loop(&mut self) {
|
||||
/// Continuously loop and process work as needed
|
||||
pub async fn work_loop(&mut self) {
|
||||
'mainloop: loop {
|
||||
let work = ::tokio::select! {
|
||||
tell_stopped = self.stop_working.recv() => {
|
||||
let _ = tell_stopped.unwrap().send(
|
||||
if let Ok(stop_ch) = tell_stopped {
|
||||
let _ = stop_ch.send(
|
||||
crate::StopWorking::WorkerStopped).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
maybe_timeout = self.queue.recv() => {
|
||||
|
@ -302,6 +273,7 @@ impl Worker {
|
|||
}
|
||||
};
|
||||
let hkdf;
|
||||
|
||||
if let PubKey::Exchange(srv_pub) = key.1 {
|
||||
let secret =
|
||||
match priv_key.key_exchange(exchange, srv_pub) {
|
||||
|
@ -341,11 +313,13 @@ impl Worker {
|
|||
conn_info.service_id,
|
||||
service_conn_id,
|
||||
conn,
|
||||
conn_info.answer,
|
||||
key.0,
|
||||
) {
|
||||
Ok((client_key_id, hshake)) => (client_key_id, hshake),
|
||||
Err(_) => {
|
||||
Err(answer) => {
|
||||
::tracing::warn!("Too many client handshakes");
|
||||
let _ = conn_info.answer.send(Err(
|
||||
let _ = answer.send(Err(
|
||||
handshake::Error::TooManyClientHandshakes
|
||||
.into(),
|
||||
));
|
||||
|
@ -363,7 +337,7 @@ impl Worker {
|
|||
let req_data = dirsync::ReqData {
|
||||
nonce: dirsync::Nonce::new(&self.rand),
|
||||
client_key_id,
|
||||
id: auth_recv_id.0,
|
||||
id: auth_recv_id.0, //FIXME: is zero
|
||||
auth: auth_info,
|
||||
};
|
||||
let req = dirsync::Req {
|
||||
|
@ -374,28 +348,50 @@ impl Worker {
|
|||
exchange_key: pub_key,
|
||||
data: dirsync::ReqInner::ClearText(req_data),
|
||||
};
|
||||
let mut raw = Vec::<u8>::with_capacity(req.len());
|
||||
req.serialize(
|
||||
let encrypt_start = ID::len() + req.encrypted_offset();
|
||||
let encrypt_end = encrypt_start
|
||||
+ req.encrypted_length(
|
||||
cipher_selected.nonce_len(),
|
||||
cipher_selected.tag_len(),
|
||||
);
|
||||
let h_req = Handshake::new(HandshakeData::DirSync(
|
||||
DirSync::Req(req),
|
||||
));
|
||||
use connection::{PacketData, ID};
|
||||
let packet = Packet {
|
||||
id: ID::Handshake,
|
||||
data: PacketData::Handshake(h_req),
|
||||
};
|
||||
|
||||
let tot_len = packet.len(
|
||||
cipher_selected.nonce_len(),
|
||||
cipher_selected.tag_len(),
|
||||
);
|
||||
let mut raw = Vec::<u8>::with_capacity(tot_len);
|
||||
raw.resize(tot_len, 0);
|
||||
packet.serialize(
|
||||
cipher_selected.nonce_len(),
|
||||
cipher_selected.tag_len(),
|
||||
&mut raw[..],
|
||||
);
|
||||
// encrypt
|
||||
let encrypt_start = req.encrypted_offset();
|
||||
let encrypt_end = encrypt_start + req.encrypted_length();
|
||||
if let Err(e) = hshake.connection.cipher_send.encrypt(
|
||||
sym::AAD(&[]),
|
||||
&mut raw[encrypt_start..encrypt_end],
|
||||
) {
|
||||
::tracing::error!("Can't encrypt DirSync Request");
|
||||
let _ = conn_info.answer.send(Err(e.into()));
|
||||
if let Some(client) =
|
||||
self.handshakes.remove_client(client_key_id)
|
||||
{
|
||||
let _ = client.answer.send(Err(e.into()));
|
||||
};
|
||||
continue 'mainloop;
|
||||
}
|
||||
|
||||
// send always from the first socket
|
||||
// FIXME: select based on routing table
|
||||
let sender = self.sockets[0].local_addr().unwrap();
|
||||
let dest = UdpServer(addr.as_sockaddr().unwrap());
|
||||
let dest = UdpClient(addr.as_sockaddr().unwrap());
|
||||
|
||||
// start the timeout right before sending the packet
|
||||
hshake.timeout = Some(::tokio::task::spawn_local(
|
||||
|
@ -406,7 +402,7 @@ impl Worker {
|
|||
));
|
||||
|
||||
// send packet
|
||||
self.send_packet(raw, UdpClient(sender), dest).await;
|
||||
self.send_packet(raw, dest, UdpServer(sender)).await;
|
||||
|
||||
continue 'mainloop;
|
||||
}
|
||||
|
@ -435,17 +431,19 @@ impl Worker {
|
|||
/// Read and do stuff with the raw udp packet
|
||||
async fn recv(&mut self, mut udp: RawUdp) {
|
||||
if udp.packet.id.is_handshake() {
|
||||
let handshake = match Handshake::deserialize(&udp.data[8..]) {
|
||||
let handshake = match Handshake::deserialize(
|
||||
&udp.data[connection::ID::len()..],
|
||||
) {
|
||||
Ok(handshake) => handshake,
|
||||
Err(e) => {
|
||||
::tracing::warn!("Handshake parsing: {}", e);
|
||||
::tracing::debug!("Handshake parsing: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let action = match self
|
||||
.handshakes
|
||||
.recv_handshake(handshake, &mut udp.data[8..])
|
||||
{
|
||||
let action = match self.handshakes.recv_handshake(
|
||||
handshake,
|
||||
&mut udp.data[connection::ID::len()..],
|
||||
) {
|
||||
Ok(action) => action,
|
||||
Err(err) => {
|
||||
::tracing::debug!("Handshake recv error {}", err);
|
||||
|
@ -454,16 +452,6 @@ impl Worker {
|
|||
};
|
||||
match action {
|
||||
HandshakeAction::AuthNeeded(authinfo) => {
|
||||
let token_check = match self.token_check.as_ref() {
|
||||
Some(token_check) => token_check,
|
||||
None => {
|
||||
::tracing::error!(
|
||||
"Authentication requested but \
|
||||
we have no token checker"
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let req;
|
||||
if let HandshakeData::DirSync(DirSync::Req(r)) =
|
||||
authinfo.handshake.data
|
||||
|
@ -477,15 +465,24 @@ impl Worker {
|
|||
let req_data = match req.data {
|
||||
ReqInner::ClearText(req_data) => req_data,
|
||||
_ => {
|
||||
::tracing::error!(
|
||||
"token_check: expected ClearText"
|
||||
);
|
||||
::tracing::error!("AuthNeeded: expected ClearText");
|
||||
assert!(false, "AuthNeeded: unreachable");
|
||||
return;
|
||||
}
|
||||
};
|
||||
// FIXME: This part can take a while,
|
||||
// we should just spawn it probably
|
||||
let is_authenticated = {
|
||||
let maybe_auth_check = {
|
||||
match &self.token_check {
|
||||
None => {
|
||||
if req_data.auth.user == auth::USERID_ANONYMOUS
|
||||
{
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
Some(token_check) => {
|
||||
let tk_check = token_check.lock().await;
|
||||
tk_check(
|
||||
req_data.auth.user,
|
||||
|
@ -494,8 +491,10 @@ impl Worker {
|
|||
req_data.auth.domain,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
};
|
||||
let is_authenticated = match is_authenticated {
|
||||
let is_authenticated = match maybe_auth_check {
|
||||
Ok(is_authenticated) => is_authenticated,
|
||||
Err(_) => {
|
||||
::tracing::error!("error in token auth");
|
||||
|
@ -545,9 +544,9 @@ impl Worker {
|
|||
client_key_id: req_data.client_key_id,
|
||||
data: RespInner::ClearText(resp_data),
|
||||
};
|
||||
let offset_to_encrypt = resp.encrypted_offset();
|
||||
let encrypt_from = ID::len() + resp.encrypted_offset();
|
||||
let encrypt_until =
|
||||
offset_to_encrypt + resp.encrypted_length() + tag_len.0;
|
||||
encrypt_from + resp.encrypted_length(head_len, tag_len);
|
||||
let resp_handshake = Handshake::new(
|
||||
HandshakeData::DirSync(DirSync::Resp(resp)),
|
||||
);
|
||||
|
@ -556,14 +555,15 @@ impl Worker {
|
|||
id: ID::new_handshake(),
|
||||
data: PacketData::Handshake(resp_handshake),
|
||||
};
|
||||
let mut raw_out =
|
||||
Vec::<u8>::with_capacity(packet.len(head_len, tag_len));
|
||||
let tot_len = packet.len(head_len, tag_len);
|
||||
let mut raw_out = Vec::<u8>::with_capacity(tot_len);
|
||||
raw_out.resize(tot_len, 0);
|
||||
packet.serialize(head_len, tag_len, &mut raw_out);
|
||||
|
||||
if let Err(e) = auth_conn.cipher_send.encrypt(
|
||||
aad,
|
||||
&mut raw_out[offset_to_encrypt..encrypt_until],
|
||||
) {
|
||||
if let Err(e) = auth_conn
|
||||
.cipher_send
|
||||
.encrypt(aad, &mut raw_out[encrypt_from..encrypt_until])
|
||||
{
|
||||
::tracing::error!("can't encrypt: {:?}", e);
|
||||
return;
|
||||
}
|
||||
|
@ -588,28 +588,27 @@ impl Worker {
|
|||
::tracing::error!(
|
||||
"ClientConnect on non DS::Resp::ClearText"
|
||||
);
|
||||
return;
|
||||
unreachable!();
|
||||
}
|
||||
let auth_srv_conn = IDSend(resp_data.id);
|
||||
let mut conn = cci.connection;
|
||||
conn.id_send = IDSend(resp_data.id);
|
||||
conn.id_send = auth_srv_conn;
|
||||
let id_recv = conn.id_recv;
|
||||
let cipher = conn.cipher_recv.kind();
|
||||
// track the connection to the authentication server
|
||||
if self.connections.track(conn.into()).is_err() {
|
||||
::tracing::error!("Could not track new connection");
|
||||
self.connections.remove(id_recv);
|
||||
let _ = cci.answer.send(Err(
|
||||
handshake::Error::InternalTracking.into(),
|
||||
));
|
||||
return;
|
||||
}
|
||||
if cci.service_id == auth::SERVICEID_AUTH {
|
||||
// the user asked a single connection
|
||||
// to the authentication server, without any additional
|
||||
// service. No more connections to setup
|
||||
return;
|
||||
}
|
||||
if cci.service_id != auth::SERVICEID_AUTH {
|
||||
// create and track the connection to the service
|
||||
// SECURITY: xor with secrets
|
||||
//FIXME: the Secret should be XORed with the client stored
|
||||
// secret (if any)
|
||||
//FIXME: the Secret should be XORed with the client
|
||||
// stored secret (if any)
|
||||
let hkdf = Hkdf::new(
|
||||
HkdfKind::Sha3,
|
||||
cci.service_id.as_bytes(),
|
||||
|
@ -624,7 +623,11 @@ impl Worker {
|
|||
service_connection.id_recv = cci.service_connection_id;
|
||||
service_connection.id_send =
|
||||
IDSend(resp_data.service_connection_id);
|
||||
let _ = self.connections.track(service_connection.into());
|
||||
let _ =
|
||||
self.connections.track(service_connection.into());
|
||||
}
|
||||
let _ =
|
||||
cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn)));
|
||||
}
|
||||
HandshakeAction::Nothing => {}
|
||||
};
|
||||
|
@ -644,11 +647,12 @@ impl Worker {
|
|||
Some(src_sock) => src_sock,
|
||||
None => {
|
||||
::tracing::error!(
|
||||
"Can't send packet: Server changed listening ip!"
|
||||
"Can't send packet: Server changed listening ip{}!",
|
||||
server.0
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let _ = src_sock.send_to(&data, client.0).await;
|
||||
let res = src_sock.send_to(&data, client.0).await;
|
||||
}
|
||||
}
|
||||
|
|
234
src/lib.rs
234
src/lib.rs
|
@ -30,7 +30,7 @@ use crate::{
|
|||
auth::{Domain, ServiceID, TokenChecker},
|
||||
connection::{
|
||||
handshake,
|
||||
socket::{SocketList, UdpClient, UdpServer},
|
||||
socket::{SocketTracker, UdpClient, UdpServer},
|
||||
AuthServerConnections, Packet,
|
||||
},
|
||||
inner::{
|
||||
|
@ -86,7 +86,7 @@ pub struct Fenrir {
|
|||
/// library Configuration
|
||||
cfg: Config,
|
||||
/// listening udp sockets
|
||||
sockets: SocketList,
|
||||
sockets: Vec<SocketTracker>,
|
||||
/// DNSSEC resolver, with failovers
|
||||
dnssec: dnssec::Dnssec,
|
||||
/// Broadcast channel to tell workers to stop working
|
||||
|
@ -100,9 +100,6 @@ pub struct Fenrir {
|
|||
// manner
|
||||
_thread_pool: Vec<::std::thread::JoinHandle<()>>,
|
||||
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
|
||||
// This can be different from cfg.listen since using port 0 will result
|
||||
// in a random port assigned by the operative system
|
||||
_listen_addrs: Vec<::std::net::SocketAddr>,
|
||||
}
|
||||
// TODO: graceful vs immediate stop
|
||||
|
||||
|
@ -127,16 +124,23 @@ impl Fenrir {
|
|||
}
|
||||
fn stop_sync(
|
||||
&mut self,
|
||||
) -> Option<(::tokio::sync::mpsc::Receiver<StopWorking>, usize, usize)>
|
||||
{
|
||||
let listeners_num = self.sockets.list.len();
|
||||
) -> Option<(
|
||||
::tokio::sync::mpsc::Receiver<StopWorking>,
|
||||
Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
|
||||
usize,
|
||||
)> {
|
||||
let workers_num = self._thread_work.len();
|
||||
if self.sockets.list.len() > 0 || self._thread_work.len() > 0 {
|
||||
if self.sockets.len() > 0 || self._thread_work.len() > 0 {
|
||||
let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4);
|
||||
let _ = self.stop_working.send(ch_send);
|
||||
let _ = self.sockets.rm_all();
|
||||
let mut old_listeners = Vec::with_capacity(self.sockets.len());
|
||||
::core::mem::swap(&mut old_listeners, &mut self.sockets);
|
||||
self._thread_pool.clear();
|
||||
Some((ch_recv, listeners_num, workers_num))
|
||||
let listeners = old_listeners
|
||||
.into_iter()
|
||||
.map(|(_, joinable)| joinable)
|
||||
.collect();
|
||||
Some((ch_recv, listeners, workers_num))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -144,9 +148,10 @@ impl Fenrir {
|
|||
async fn stop_wait(
|
||||
&mut self,
|
||||
mut ch: ::tokio::sync::mpsc::Receiver<StopWorking>,
|
||||
mut listeners_num: usize,
|
||||
listeners: Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>,
|
||||
mut workers_num: usize,
|
||||
) {
|
||||
let mut listeners_num = listeners.len();
|
||||
while listeners_num > 0 && workers_num > 0 {
|
||||
match ch.recv().await {
|
||||
Some(stopped) => match stopped {
|
||||
|
@ -158,6 +163,11 @@ impl Fenrir {
|
|||
_ => break,
|
||||
}
|
||||
}
|
||||
for l in listeners.into_iter() {
|
||||
if let Err(e) = l.await {
|
||||
::tracing::error!("Unclean shutdown of listener: {:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Create a new Fenrir endpoint
|
||||
/// spawn threads pinned to cpus in our own way with tokio's runtime
|
||||
|
@ -167,22 +177,32 @@ impl Fenrir {
|
|||
) -> Result<Self, Error> {
|
||||
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
||||
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
||||
// bind sockets early so we can change "port 0" (aka: random)
|
||||
// in the config
|
||||
let binded_sockets = Self::bind_sockets(&config).await?;
|
||||
let socket_addrs = binded_sockets
|
||||
.iter()
|
||||
.map(|s| s.local_addr().unwrap())
|
||||
.collect();
|
||||
let cfg = {
|
||||
let mut tmp = config.clone();
|
||||
tmp.listen = socket_addrs;
|
||||
tmp
|
||||
};
|
||||
let mut endpoint = Self {
|
||||
cfg: config.clone(),
|
||||
sockets: SocketList::new(),
|
||||
cfg,
|
||||
sockets: Vec::with_capacity(config.listen.len()),
|
||||
dnssec,
|
||||
stop_working: sender,
|
||||
token_check: None,
|
||||
conn_auth_srv: Mutex::new(AuthServerConnections::new()),
|
||||
_thread_pool: Vec::new(),
|
||||
_thread_work: Arc::new(Vec::new()),
|
||||
_listen_addrs: Vec::with_capacity(config.listen.len()),
|
||||
};
|
||||
endpoint.start_work_threads_pinned(tokio_rt).await?;
|
||||
match endpoint.add_sockets().await {
|
||||
Ok(addrs) => endpoint._listen_addrs = addrs,
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
endpoint
|
||||
.start_work_threads_pinned(tokio_rt, binded_sockets.clone())
|
||||
.await?;
|
||||
endpoint.run_listeners(binded_sockets).await?;
|
||||
Ok(endpoint)
|
||||
}
|
||||
/// Create a new Fenrir endpoint
|
||||
|
@ -192,41 +212,39 @@ impl Fenrir {
|
|||
/// * make sure that the threads are pinned on the cpu
|
||||
pub async fn with_workers(
|
||||
config: &Config,
|
||||
) -> Result<
|
||||
(
|
||||
Self,
|
||||
Vec<impl futures::Future<Output = Result<(), std::io::Error>>>,
|
||||
),
|
||||
Error,
|
||||
> {
|
||||
) -> Result<(Self, Vec<Worker>), Error> {
|
||||
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
|
||||
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
||||
let cfg = config.clone();
|
||||
let sockets = SocketList::new();
|
||||
let conn_auth_srv = Mutex::new(AuthServerConnections::new());
|
||||
let thread_pool = Vec::new();
|
||||
let thread_work = Arc::new(Vec::new());
|
||||
let listen_addrs = Vec::with_capacity(config.listen.len());
|
||||
// bind sockets early so we can change "port 0" (aka: random)
|
||||
// in the config
|
||||
let binded_sockets = Self::bind_sockets(&config).await?;
|
||||
let socket_addrs = binded_sockets
|
||||
.iter()
|
||||
.map(|s| s.local_addr().unwrap())
|
||||
.collect();
|
||||
let cfg = {
|
||||
let mut tmp = config.clone();
|
||||
tmp.listen = socket_addrs;
|
||||
tmp
|
||||
};
|
||||
let mut endpoint = Self {
|
||||
cfg,
|
||||
sockets,
|
||||
sockets: Vec::with_capacity(config.listen.len()),
|
||||
dnssec,
|
||||
stop_working: stop_working.clone(),
|
||||
token_check: None,
|
||||
conn_auth_srv,
|
||||
_thread_pool: thread_pool,
|
||||
_thread_work: thread_work,
|
||||
_listen_addrs: listen_addrs,
|
||||
conn_auth_srv: Mutex::new(AuthServerConnections::new()),
|
||||
_thread_pool: Vec::new(),
|
||||
_thread_work: Arc::new(Vec::new()),
|
||||
};
|
||||
let worker_num = config.threads.unwrap().get();
|
||||
let mut workers = Vec::with_capacity(worker_num);
|
||||
for _ in 0..worker_num {
|
||||
workers.push(endpoint.start_single_worker().await?);
|
||||
}
|
||||
match endpoint.add_sockets().await {
|
||||
Ok(addrs) => endpoint._listen_addrs = addrs,
|
||||
Err(e) => return Err(e.into()),
|
||||
workers.push(
|
||||
endpoint.start_single_worker(binded_sockets.clone()).await?,
|
||||
);
|
||||
}
|
||||
endpoint.run_listeners(binded_sockets).await?;
|
||||
Ok((endpoint, workers))
|
||||
}
|
||||
/// Returns the list of the actual addresses we are listening on
|
||||
|
@ -234,57 +252,56 @@ impl Fenrir {
|
|||
/// if you specified UDP port 0 a random one has been assigned to you
|
||||
/// by the operating system.
|
||||
pub fn addresses(&self) -> Vec<::std::net::SocketAddr> {
|
||||
self._listen_addrs.clone()
|
||||
self.sockets.iter().map(|(s, _)| s.clone()).collect()
|
||||
}
|
||||
|
||||
// only call **after** starting all threads
|
||||
/// Add all UDP sockets found in config
|
||||
/// and start listening for packets
|
||||
async fn add_sockets(
|
||||
&mut self,
|
||||
) -> ::std::io::Result<Vec<::std::net::SocketAddr>> {
|
||||
// only call **before** starting all threads
|
||||
/// bind all UDP sockets found in config
|
||||
async fn bind_sockets(cfg: &Config) -> Result<Vec<Arc<UdpSocket>>, Error> {
|
||||
// try to bind multiple sockets in parallel
|
||||
let mut sock_set = ::tokio::task::JoinSet::new();
|
||||
self.cfg.listen.iter().for_each(|s_addr| {
|
||||
cfg.listen.iter().for_each(|s_addr| {
|
||||
let socket_address = s_addr.clone();
|
||||
let stop_working = self.stop_working.subscribe();
|
||||
let th_work = self._thread_work.clone();
|
||||
sock_set.spawn(async move {
|
||||
let s = connection::socket::bind_udp(socket_address).await?;
|
||||
let arc_s = Arc::new(s);
|
||||
let join = ::tokio::spawn(Self::listen_udp(
|
||||
stop_working,
|
||||
th_work,
|
||||
arc_s.clone(),
|
||||
));
|
||||
Ok((arc_s, join))
|
||||
connection::socket::bind_udp(socket_address).await
|
||||
});
|
||||
});
|
||||
|
||||
// make sure we either add all of them, or none
|
||||
let mut all_socks = Vec::with_capacity(self.cfg.listen.len());
|
||||
// make sure we either return all of them, or none
|
||||
let mut all_socks = Vec::with_capacity(cfg.listen.len());
|
||||
while let Some(join_res) = sock_set.join_next().await {
|
||||
match join_res {
|
||||
Ok(s_res) => match s_res {
|
||||
Ok(s) => {
|
||||
all_socks.push(s);
|
||||
all_socks.push(Arc::new(s));
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
return Err(e.into());
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
return Err(Error::Setup(e.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
let mut ret = Vec::with_capacity(self.cfg.listen.len());
|
||||
for (arc_s, join) in all_socks.into_iter() {
|
||||
ret.push(arc_s.local_addr().unwrap());
|
||||
self.sockets.add_socket(arc_s, join).await;
|
||||
}
|
||||
Ok(ret)
|
||||
assert!(all_socks.len() == cfg.listen.len(), "missing socks");
|
||||
Ok(all_socks)
|
||||
}
|
||||
// only call **after** starting all threads
|
||||
/// spawn all listeners
|
||||
async fn run_listeners(
|
||||
&mut self,
|
||||
socks: Vec<Arc<UdpSocket>>,
|
||||
) -> Result<(), Error> {
|
||||
for sock in socks.into_iter() {
|
||||
let sockaddr = sock.local_addr().unwrap();
|
||||
let stop_working = self.stop_working.subscribe();
|
||||
let th_work = self._thread_work.clone();
|
||||
let joinable = ::tokio::spawn(async move {
|
||||
Self::listen_udp(stop_working, th_work, sock.clone()).await
|
||||
});
|
||||
self.sockets.push((sockaddr, joinable));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run a dedicated loop to read packets on the listening socket
|
||||
|
@ -301,12 +318,15 @@ impl Fenrir {
|
|||
let (bytes, sock_sender) = ::tokio::select! {
|
||||
tell_stopped = stop_working.recv() => {
|
||||
drop(socket);
|
||||
let _ = tell_stopped.unwrap()
|
||||
if let Ok(stop_ch) = tell_stopped {
|
||||
let _ = stop_ch
|
||||
.send(StopWorking::ListenerStopped).await;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
result = socket.recv_from(&mut buffer) => {
|
||||
result?
|
||||
let (bytes, from) = result?;
|
||||
(bytes, UdpClient(from))
|
||||
}
|
||||
};
|
||||
let data: Vec<u8> = buffer[..bytes].to_vec();
|
||||
|
@ -324,17 +344,15 @@ impl Fenrir {
|
|||
use connection::packet::ConnectionID;
|
||||
match packet.id {
|
||||
ConnectionID::Handshake => {
|
||||
let send_port = sock_sender.port() as u64;
|
||||
((send_port % queues_num) - 1) as usize
|
||||
}
|
||||
ConnectionID::ID(id) => {
|
||||
((id.get() % queues_num) - 1) as usize
|
||||
let send_port = sock_sender.0.port() as u64;
|
||||
(send_port % queues_num) as usize
|
||||
}
|
||||
ConnectionID::ID(id) => (id.get() % queues_num) as usize,
|
||||
}
|
||||
};
|
||||
let _ = work_queues[thread_idx]
|
||||
.send(Work::Recv(RawUdp {
|
||||
src: UdpClient(sock_sender),
|
||||
src: sock_sender,
|
||||
dst: sock_receiver,
|
||||
packet,
|
||||
data,
|
||||
|
@ -431,7 +449,7 @@ impl Fenrir {
|
|||
.unwrap();
|
||||
|
||||
// and tell that thread to connect somewhere
|
||||
let (send, recv) = ::tokio::sync::oneshot::channel();
|
||||
let (send, mut recv) = ::tokio::sync::oneshot::channel();
|
||||
let _ = self._thread_work[thread_idx]
|
||||
.send(Work::Connect(ConnectInfo {
|
||||
answer: send,
|
||||
|
@ -450,10 +468,15 @@ impl Fenrir {
|
|||
conn_auth_lock.remove_reserved(&resolved);
|
||||
Err(e)
|
||||
}
|
||||
Ok((pubkey, id_send)) => {
|
||||
Ok((key_id, id_send)) => {
|
||||
let key = resolved
|
||||
.public_keys
|
||||
.iter()
|
||||
.find(|k| k.0 == key_id)
|
||||
.unwrap();
|
||||
let mut conn_auth_lock =
|
||||
self.conn_auth_srv.lock().await;
|
||||
conn_auth_lock.add(&pubkey, id_send, &resolved);
|
||||
conn_auth_lock.add(&key.1, id_send, &resolved);
|
||||
|
||||
//FIXME: user needs to somehow track the connection
|
||||
Ok(())
|
||||
|
@ -472,13 +495,11 @@ impl Fenrir {
|
|||
}
|
||||
}
|
||||
|
||||
// needs to be called before add_sockets
|
||||
// needs to be called before run_listeners
|
||||
async fn start_single_worker(
|
||||
&mut self,
|
||||
) -> ::std::result::Result<
|
||||
impl futures::Future<Output = Result<(), std::io::Error>>,
|
||||
Error,
|
||||
> {
|
||||
socks: Vec<Arc<UdpSocket>>,
|
||||
) -> ::std::result::Result<Worker, Error> {
|
||||
let thread_idx = self._thread_work.len() as u16;
|
||||
let max_threads = self.cfg.threads.unwrap().get() as u16;
|
||||
if thread_idx >= max_threads {
|
||||
|
@ -496,17 +517,18 @@ impl Fenrir {
|
|||
total: max_threads,
|
||||
};
|
||||
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
||||
let worker = Worker::new_and_loop(
|
||||
let worker = Worker::new(
|
||||
self.cfg.clone(),
|
||||
thread_id,
|
||||
self.stop_working.subscribe(),
|
||||
self.token_check.clone(),
|
||||
self.cfg.listen.clone(),
|
||||
socks,
|
||||
work_recv,
|
||||
);
|
||||
)
|
||||
.await?;
|
||||
// don't keep around private keys too much
|
||||
if (thread_idx + 1) == max_threads {
|
||||
self.cfg.keys.clear();
|
||||
self.cfg.server_keys.clear();
|
||||
}
|
||||
loop {
|
||||
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
|
||||
|
@ -533,6 +555,7 @@ impl Fenrir {
|
|||
async fn start_work_threads_pinned(
|
||||
&mut self,
|
||||
tokio_rt: Arc<::tokio::runtime::Runtime>,
|
||||
sockets: Vec<Arc<UdpSocket>>,
|
||||
) -> ::std::result::Result<(), Error> {
|
||||
use ::std::sync::Mutex;
|
||||
let hw_topology = match ::hwloc2::Topology::new() {
|
||||
|
@ -568,7 +591,7 @@ impl Fenrir {
|
|||
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
||||
let th_stop_working = self.stop_working.subscribe();
|
||||
let th_token_check = self.token_check.clone();
|
||||
let th_socket_addrs = self.cfg.listen.clone();
|
||||
let th_sockets = sockets.clone();
|
||||
let thread_id = ThreadTracker {
|
||||
total: cores as u16,
|
||||
id: 1 + (core as u16),
|
||||
|
@ -598,17 +621,22 @@ impl Fenrir {
|
|||
// finally run the main worker.
|
||||
// make sure things stay on this thread
|
||||
let tk_local = ::tokio::task::LocalSet::new();
|
||||
let _ = tk_local.block_on(
|
||||
&th_tokio_rt,
|
||||
Worker::new_and_loop(
|
||||
let _ = tk_local.block_on(&th_tokio_rt, async move {
|
||||
let mut worker = match Worker::new(
|
||||
th_config,
|
||||
thread_id,
|
||||
th_stop_working,
|
||||
th_token_check,
|
||||
th_socket_addrs,
|
||||
th_sockets,
|
||||
work_recv,
|
||||
),
|
||||
);
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(worker) => worker,
|
||||
Err(_) => return,
|
||||
};
|
||||
worker.work_loop().await
|
||||
});
|
||||
});
|
||||
loop {
|
||||
let queues_lock = match Arc::get_mut(&mut self._thread_work) {
|
||||
|
@ -627,7 +655,7 @@ impl Fenrir {
|
|||
self._thread_pool.push(join_handle);
|
||||
}
|
||||
// don't keep around private keys too much
|
||||
self.cfg.keys.clear();
|
||||
self.cfg.server_keys.clear();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
59
src/tests.rs
59
src/tests.rs
|
@ -21,23 +21,46 @@ async fn test_connection_dirsync() {
|
|||
cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap());
|
||||
cfg
|
||||
};
|
||||
let test_domain: Domain = "example.com".into();
|
||||
let cfg_server = {
|
||||
let mut cfg = cfg_client.clone();
|
||||
cfg.keys = [(KeyID(42), priv_exchange_key, pub_exchange_key)].to_vec();
|
||||
cfg.server_keys = [config::ServerKey {
|
||||
id: KeyID(42),
|
||||
priv_key: priv_exchange_key,
|
||||
pub_key: pub_exchange_key,
|
||||
}]
|
||||
.to_vec();
|
||||
cfg.servers = [config::AuthServer {
|
||||
fqdn: test_domain.clone(),
|
||||
keys: [KeyID(42)].to_vec(),
|
||||
}]
|
||||
.to_vec();
|
||||
cfg
|
||||
};
|
||||
|
||||
let (server, mut srv_workers) =
|
||||
Fenrir::with_workers(&cfg_server).await.unwrap();
|
||||
|
||||
let srv_worker = srv_workers.pop().unwrap();
|
||||
let local_thread = ::tokio::task::LocalSet::new();
|
||||
local_thread.spawn_local(async move { srv_worker.await });
|
||||
|
||||
let (client, mut cli_workers) =
|
||||
Fenrir::with_workers(&cfg_client).await.unwrap();
|
||||
let cli_worker = cli_workers.pop().unwrap();
|
||||
local_thread.spawn_local(async move { cli_worker.await });
|
||||
let mut srv_worker = srv_workers.pop().unwrap();
|
||||
let mut cli_worker = cli_workers.pop().unwrap();
|
||||
|
||||
::std::thread::spawn(move || {
|
||||
let rt = ::tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let local_thread = ::tokio::task::LocalSet::new();
|
||||
local_thread.spawn_local(async move {
|
||||
srv_worker.work_loop().await;
|
||||
});
|
||||
|
||||
local_thread.spawn_local(async move {
|
||||
::tokio::time::sleep(::std::time::Duration::from_millis(100)).await;
|
||||
cli_worker.work_loop().await;
|
||||
});
|
||||
rt.block_on(local_thread);
|
||||
});
|
||||
|
||||
use crate::{
|
||||
connection::handshake::HandshakeID,
|
||||
|
@ -63,17 +86,17 @@ async fn test_connection_dirsync() {
|
|||
ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
|
||||
};
|
||||
|
||||
server.graceful_stop().await;
|
||||
client.graceful_stop().await;
|
||||
return;
|
||||
::tokio::time::sleep(::std::time::Duration::from_millis(500)).await;
|
||||
match client
|
||||
.connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
assert!(false, "Err on client connection: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
let _ = client
|
||||
.connect_resolved(
|
||||
dnssec_record,
|
||||
&Domain("example.com".to_owned()),
|
||||
auth::SERVICEID_AUTH,
|
||||
)
|
||||
.await;
|
||||
server.graceful_stop().await;
|
||||
client.graceful_stop().await;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue