Connect boilerplate, cleanup

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-05-27 10:57:15 +02:00
parent e71167224c
commit 1259996201
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
11 changed files with 123 additions and 117 deletions

View File

@ -37,6 +37,7 @@
#}) #})
clippy clippy
cargo-watch cargo-watch
cargo-flamegraph
cargo-license cargo-license
lld lld
rust-bin.stable."1.69.0".default rust-bin.stable."1.69.0".default

View File

@ -19,8 +19,6 @@ use crate::{
}; };
use ::arrayref::array_mut_ref; use ::arrayref::array_mut_ref;
use ::std::{collections::VecDeque, num::NonZeroU64, vec::Vec};
use trust_dns_client::rr::rdata::key::Protocol;
type Nonce = [u8; 16]; type Nonce = [u8; 16];
@ -304,7 +302,7 @@ impl RespInner {
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
match self { match self {
RespInner::CipherText(len) => *len, RespInner::CipherText(len) => *len,
RespInner::ClearText(d) => RespData::len(), RespInner::ClearText(_) => RespData::len(),
} }
} }
/* /*

View File

@ -7,7 +7,7 @@ use crate::{
enc::sym::{HeadLen, TagLen}, enc::sym::{HeadLen, TagLen},
}; };
use ::num_traits::FromPrimitive; use ::num_traits::FromPrimitive;
use ::std::{rc::Rc, sync::Arc}; use ::std::rc::Rc;
/// Handshake errors /// Handshake errors
#[derive(::thiserror::Error, Debug, Copy, Clone)] #[derive(::thiserror::Error, Debug, Copy, Clone)]
@ -145,10 +145,6 @@ impl Handshake {
self.fenrir_version.serialize(&mut out[0]); self.fenrir_version.serialize(&mut out[0]);
self.data.serialize(head_len, tag_len, &mut out[1..]); self.data.serialize(head_len, tag_len, &mut out[1..]);
} }
pub(crate) fn work(&self, keys: &[HandshakeServer]) -> Result<(), Error> {
todo!()
}
} }
trait HandshakeParsing { trait HandshakeParsing {

View File

@ -4,7 +4,7 @@ pub mod handshake;
pub mod packet; pub mod packet;
pub mod socket; pub mod socket;
use ::std::{rc::Rc, sync::Arc, vec::Vec}; use ::std::{rc::Rc, vec::Vec};
pub use crate::connection::{ pub use crate::connection::{
handshake::Handshake, handshake::Handshake,
@ -110,7 +110,7 @@ pub(crate) struct ConnList {
impl ConnList { impl ConnList {
pub(crate) fn new(thread_id: ThreadTracker) -> Self { pub(crate) fn new(thread_id: ThreadTracker) -> Self {
let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); let bitmap_id = ::bitmaps::Bitmap::<1024>::new();
const INITIAL_CAP: usize = 128; const INITIAL_CAP: usize = 128;
let mut ret = Self { let mut ret = Self {
thread_id, thread_id,
@ -120,6 +120,13 @@ impl ConnList {
ret.connections.resize_with(INITIAL_CAP, || None); ret.connections.resize_with(INITIAL_CAP, || None);
ret ret
} }
pub fn len(&self) -> usize {
let mut total: usize = 0;
for bitmap in self.ids_used.iter() {
total = total + bitmap.len()
}
total
}
/// Only *Reserve* a connection, /// Only *Reserve* a connection,
/// without actually tracking it in self.connections /// without actually tracking it in self.connections
pub(crate) fn reserve_first( pub(crate) fn reserve_first(

View File

@ -1,15 +1,12 @@
//! Socket related types and functions //! Socket related types and functions
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::arc_swap::ArcSwap;
use ::std::{ use ::std::{net::SocketAddr, sync::Arc, vec::Vec};
net::SocketAddr, use ::tokio::{net::UdpSocket, task::JoinHandle};
sync::Arc,
vec::{self, Vec},
};
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
/// Pair to easily track the socket and its async listening handle /// Pair to easily track the socket and its async listening handle
pub type SocketTracker = (Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>); pub type SocketTracker =
(Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
/// async free socket list /// async free socket list
pub(crate) struct SocketList { pub(crate) struct SocketList {
@ -48,7 +45,7 @@ impl SocketList {
}); });
} }
/// This method assumes no other `add_sockets` are being run /// This method assumes no other `add_sockets` are being run
pub(crate) async fn stop_all(mut self) { pub(crate) async fn stop_all(self) {
let mut arc_list = self.list.into_inner(); let mut arc_list = self.list.into_inner();
let list = loop { let list = loop {
match Arc::try_unwrap(arc_list) { match Arc::try_unwrap(arc_list) {
@ -63,7 +60,7 @@ impl SocketList {
} }
}; };
for (_socket, mut handle) in list.into_iter() { for (_socket, mut handle) in list.into_iter() {
Arc::get_mut(&mut handle).unwrap().await; let _ = Arc::get_mut(&mut handle).unwrap().await;
} }
} }
pub(crate) fn lock(&self) -> SocketListRef { pub(crate) fn lock(&self) -> SocketListRef {

View File

@ -1,7 +1,6 @@
//! Asymmetric key handling and wrappers //! Asymmetric key handling and wrappers
use ::num_traits::FromPrimitive; use ::num_traits::FromPrimitive;
use ::std::vec::Vec;
use super::Error; use super::Error;
use crate::enc::sym::Secret; use crate::enc::sym::Secret;

View File

@ -51,15 +51,12 @@ impl HkdfSha3 {
/// Instantiate a new HKDF with Sha3-256 /// Instantiate a new HKDF with Sha3-256
pub fn new(salt: &[u8], key: Secret) -> Self { pub fn new(salt: &[u8], key: Secret) -> Self {
let hkdf = Hkdf::<Sha3_256>::new(Some(salt), key.as_ref()); let hkdf = Hkdf::<Sha3_256>::new(Some(salt), key.as_ref());
#[allow(unsafe_code)]
unsafe {
Self { Self {
inner: HkdfInner { inner: HkdfInner {
hkdf: ::core::mem::ManuallyDrop::new(hkdf), hkdf: ::core::mem::ManuallyDrop::new(hkdf),
}, },
} }
} }
}
/// Get a secret generated from the key and a given context /// Get a secret generated from the key and a given context
pub fn get_secret(&self, context: &[u8]) -> Secret { pub fn get_secret(&self, context: &[u8]) -> Secret {
let mut out: [u8; 32] = [0; 32]; let mut out: [u8; 32] = [0; 32];

View File

@ -1,7 +1,6 @@
//! Symmetric cypher stuff //! Symmetric cypher stuff
use super::Error; use super::Error;
use ::std::collections::VecDeque;
use ::zeroize::Zeroize; use ::zeroize::Zeroize;
/// Secret, used for keys. /// Secret, used for keys.
@ -174,7 +173,7 @@ impl Cipher {
} }
fn overhead(&self) -> usize { fn overhead(&self) -> usize {
match self { match self {
Cipher::XChaCha20Poly1305(cipher) => { Cipher::XChaCha20Poly1305(_) => {
let cipher = CipherKind::XChaCha20Poly1305; let cipher = CipherKind::XChaCha20Poly1305;
cipher.nonce_len().0 + cipher.tag_len().0 cipher.nonce_len().0 + cipher.tag_len().0
} }
@ -189,9 +188,7 @@ impl Cipher {
// FIXME: check minimum buffer size // FIXME: check minimum buffer size
match self { match self {
Cipher::XChaCha20Poly1305(cipher) => { Cipher::XChaCha20Poly1305(cipher) => {
use ::chacha20poly1305::{ use ::chacha20poly1305::AeadInPlace;
aead::generic_array::GenericArray, AeadInPlace,
};
let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len(); let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len();
let data_len_notag = data.len() - tag_len; let data_len_notag = data.len() - tag_len;
// write nonce // write nonce
@ -211,10 +208,9 @@ impl Cipher {
Ok(()) Ok(())
} }
Err(_) => Err(Error::Encrypt), Err(_) => Err(Error::Encrypt),
};
} }
} }
todo!() }
} }
} }
@ -253,35 +249,6 @@ impl CipherRecv {
} }
} }
/// Allocate some data, with additional indexes to track
/// where nonce and tags are
#[derive(Debug, Clone)]
pub struct Data {
data: Vec<u8>,
skip_start: usize,
skip_end: usize,
}
impl Data {
/// Get the slice where you will write the actual data
/// this will skip the actual nonce and AEAD tag and give you
/// only the space for the data
pub fn get_slice(&mut self) -> &mut [u8] {
&mut self.data[self.skip_start..self.skip_end]
}
fn get_tag_slice(&mut self) -> &mut [u8] {
let start = self.data.len() - self.skip_end;
&mut self.data[start..]
}
fn get_slice_full(&mut self) -> &mut [u8] {
&mut self.data
}
/// Consume the data and return the whole raw vector
pub fn get_raw(self) -> Vec<u8> {
self.data
}
}
/// Send only cipher /// Send only cipher
pub struct CipherSend { pub struct CipherSend {
nonce: NonceSync, nonce: NonceSync,
@ -308,14 +275,6 @@ impl CipherSend {
cipher: Cipher::new(kind, secret), cipher: Cipher::new(kind, secret),
} }
} }
/// Allocate the memory for the data that will be encrypted
pub fn make_data(&self, length: usize) -> Data {
Data {
data: Vec::with_capacity(length + self.cipher.overhead()),
skip_start: self.cipher.nonce_len().0,
skip_end: self.cipher.tag_len().0,
}
}
/// Encrypt the given data /// Encrypt the given data
pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> { pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> {
let old_nonce = self.nonce.advance(); let old_nonce = self.nonce.advance();
@ -380,11 +339,8 @@ impl Nonce {
use ring::rand::SecureRandom; use ring::rand::SecureRandom;
let mut raw = [0; 12]; let mut raw = [0; 12];
rand.fill(&mut raw); rand.fill(&mut raw);
#[allow(unsafe_code)]
unsafe {
Self { raw } Self { raw }
} }
}
/// Length of this nonce in bytes /// Length of this nonce in bytes
pub const fn len() -> usize { pub const fn len() -> usize {
return 12; return 12;
@ -398,11 +354,8 @@ impl Nonce {
} }
/// Create Nonce from array /// Create Nonce from array
pub fn from_slice(raw: [u8; 12]) -> Self { pub fn from_slice(raw: [u8; 12]) -> Self {
#[allow(unsafe_code)]
unsafe {
Self { raw } Self { raw }
} }
}
/// Go to the next nonce /// Go to the next nonce
pub fn advance(&mut self) { pub fn advance(&mut self) {
#[allow(unsafe_code)] #[allow(unsafe_code)]

View File

@ -14,12 +14,11 @@ use crate::{
enc::{ enc::{
self, asym, self, asym,
hkdf::HkdfSha3, hkdf::HkdfSha3,
sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, sym::{CipherKind, CipherRecv},
}, },
Error, Error,
}; };
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::std::{rc::Rc, vec::Vec};
use ::std::{rc::Rc, sync::Arc, vec::Vec};
/// Information needed to reply after the key exchange /// Information needed to reply after the key exchange
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -98,10 +97,7 @@ impl HandshakeTracker {
mut handshake: Handshake, mut handshake: Handshake,
handshake_raw: &mut [u8], handshake_raw: &mut [u8],
) -> Result<HandshakeAction, Error> { ) -> Result<HandshakeAction, Error> {
use connection::handshake::{ use connection::handshake::{dirsync::DirSync, HandshakeData};
dirsync::{self, DirSync},
HandshakeData,
};
match handshake.data { match handshake.data {
HandshakeData::DirSync(ref mut ds) => match ds { HandshakeData::DirSync(ref mut ds) => match ds {
DirSync::Req(ref mut req) => { DirSync::Req(ref mut req) => {

View File

@ -1,23 +1,25 @@
//! Worker thread implementation //! Worker thread implementation
use crate::{ use crate::{
auth::TokenChecker, auth::{ServiceID, TokenChecker},
connection::{ connection::{
self, self,
handshake::{ handshake::{
self,
dirsync::{self, DirSync}, dirsync::{self, DirSync},
Handshake, HandshakeClient, HandshakeData, Handshake, HandshakeData,
}, },
socket::{UdpClient, UdpServer}, socket::{UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet, ID, ConnList, Connection, IDSend, Packet,
}, },
dnssec,
enc::{hkdf::HkdfSha3, sym::Secret}, enc::{hkdf::HkdfSha3, sym::Secret},
inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, inner::{HandshakeAction, HandshakeTracker, ThreadTracker},
}; };
use ::std::{rc::Rc, sync::Arc, vec::Vec}; use ::std::{rc::Rc, sync::Arc, vec::Vec};
/// This worker must be cpu-pinned /// This worker must be cpu-pinned
use ::tokio::{net::UdpSocket, sync::Mutex}; use ::tokio::{
use std::net::SocketAddr; net::UdpSocket,
sync::{oneshot, Mutex},
};
/// Track a raw Udp packet /// Track a raw Udp packet
pub(crate) struct RawUdp { pub(crate) struct RawUdp {
@ -28,8 +30,15 @@ pub(crate) struct RawUdp {
} }
pub(crate) enum Work { pub(crate) enum Work {
/// ask the thread to report to the main thread the total number of
/// connections present
CountConnections(oneshot::Sender<usize>),
Connect((oneshot::Sender<u16>, dnssec::Record, ServiceID)),
Recv(RawUdp), Recv(RawUdp),
} }
pub(crate) enum WorkAnswer {
UNUSED,
}
/// Actual worker implementation. /// Actual worker implementation.
pub(crate) struct Worker { pub(crate) struct Worker {
@ -131,6 +140,13 @@ impl Worker {
} }
}; };
match work { match work {
Work::CountConnections(sender) => {
let conn_num = self.connections.len();
let _ = sender.send(conn_num);
}
Work::Connect((send_res, dnssec_record, service_id)) => {
todo!()
}
//TODO: reconf message to add channels //TODO: reconf message to add channels
Work::Recv(pkt) => { Work::Recv(pkt) => {
self.recv(pkt).await; self.recv(pkt).await;
@ -285,7 +301,6 @@ impl Worker {
return; return;
} }
// track connection // track connection
use handshake::dirsync;
let resp_data; let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{ {
@ -313,6 +328,7 @@ impl Worker {
return; return;
} }
// create and track the connection to the service // create and track the connection to the service
// SECURITY:
//FIXME: the Secret should be XORed with the client stored //FIXME: the Secret should be XORed with the client stored
// secret (if any) // secret (if any)
let hkdf = HkdfSha3::new( let hkdf = HkdfSha3::new(
@ -328,7 +344,7 @@ impl Worker {
service_connection.id_recv = cci.service_connection_id; service_connection.id_recv = cci.service_connection_id;
service_connection.id_send = service_connection.id_send =
IDSend(resp_data.service_connection_id); IDSend(resp_data.service_connection_id);
self.connections.track(service_connection.into()); let _ = self.connections.track(service_connection.into());
return; return;
} }
_ => {} _ => {}

View File

@ -20,12 +20,9 @@ pub mod dnssec;
pub mod enc; pub mod enc;
mod inner; mod inner;
use ::std::{ use ::std::{sync::Arc, vec::Vec};
net::SocketAddr, use ::tokio::net::UdpSocket;
sync::{Arc, Weak}, use auth::ServiceID;
vec::Vec,
};
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use crate::{ use crate::{
auth::TokenChecker, auth::TokenChecker,
@ -94,9 +91,7 @@ impl Drop for Fenrir {
impl Fenrir { impl Fenrir {
/// Create a new Fenrir endpoint /// Create a new Fenrir endpoint
pub fn new(config: &Config) -> Result<Self, Error> { pub fn new(config: &Config) -> Result<Self, Error> {
let listen_num = config.listen.len();
let (sender, _) = ::tokio::sync::broadcast::channel(1); let (sender, _) = ::tokio::sync::broadcast::channel(1);
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let endpoint = Fenrir { let endpoint = Fenrir {
cfg: config.clone(), cfg: config.clone(),
sockets: SocketList::new(), sockets: SocketList::new(),
@ -127,23 +122,23 @@ impl Fenrir {
/// asyncronous version for Drop /// asyncronous version for Drop
fn stop_sync(&mut self) { fn stop_sync(&mut self) {
let _ = self.stop_working.send(true); let _ = self.stop_working.send(true);
let mut toempty_sockets = self.sockets.rm_all(); let toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let task = ::tokio::task::spawn(toempty_sockets.stop_all());
let _ = ::futures::executor::block_on(task); let _ = ::futures::executor::block_on(task);
let mut old_thread_pool = Vec::new(); let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join()); let _ = old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None; self.dnssec = None;
} }
/// Stop all workers, listeners /// Stop all workers, listeners
pub async fn stop(&mut self) { pub async fn stop(&mut self) {
let _ = self.stop_working.send(true); let _ = self.stop_working.send(true);
let mut toempty_sockets = self.sockets.rm_all(); let toempty_sockets = self.sockets.rm_all();
toempty_sockets.stop_all().await; toempty_sockets.stop_all().await;
let mut old_thread_pool = Vec::new(); let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join()); let _ = old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None; self.dnssec = None;
} }
/// Add all UDP sockets found in config /// Add all UDP sockets found in config
@ -166,7 +161,7 @@ impl Fenrir {
self._thread_work.clone(), self._thread_work.clone(),
arc_s.clone(), arc_s.clone(),
)); ));
self.sockets.add_socket(arc_s, join); self.sockets.add_socket(arc_s, join).await;
} }
Err(e) => { Err(e) => {
return Err(e); return Err(e);
@ -218,18 +213,19 @@ impl Fenrir {
} }
} }
}; };
work_queues[thread_idx].send(Work::Recv(RawUdp { let _ = work_queues[thread_idx]
.send(Work::Recv(RawUdp {
src: UdpClient(sock_sender), src: UdpClient(sock_sender),
dst: sock_receiver, dst: sock_receiver,
packet, packet,
data, data,
})); }))
.await;
} }
Ok(()) Ok(())
} }
/// Get the raw TXT record of a Fenrir domain /// Get the raw TXT record of a Fenrir domain
pub async fn resolv_str(&self, domain: &str) -> Result<String, Error> { pub async fn resolv_txt(&self, domain: &str) -> Result<String, Error> {
match &self.dnssec { match &self.dnssec {
Some(dnssec) => Ok(dnssec.resolv(domain).await?), Some(dnssec) => Ok(dnssec.resolv(domain).await?),
None => Err(Error::NotInitialized), None => Err(Error::NotInitialized),
@ -238,10 +234,60 @@ impl Fenrir {
/// Get the raw TXT record of a Fenrir domain /// Get the raw TXT record of a Fenrir domain
pub async fn resolv(&self, domain: &str) -> Result<dnssec::Record, Error> { pub async fn resolv(&self, domain: &str) -> Result<dnssec::Record, Error> {
let record_str = self.resolv_str(domain).await?; let record_str = self.resolv_txt(domain).await?;
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
} }
/// Connect to a service
pub async fn connect(
&self,
domain: &str,
service: ServiceID,
) -> Result<(), Error> {
let resolved = self.resolv(domain).await?;
// find the thread with less connections
let th_num = self._thread_work.len();
let mut conn_count = Vec::<usize>::with_capacity(th_num);
let mut wait_res =
Vec::<::tokio::sync::oneshot::Receiver<usize>>::with_capacity(
th_num,
);
for th in self._thread_work.iter() {
let (send, recv) = ::tokio::sync::oneshot::channel();
wait_res.push(recv);
let _ = th.send(Work::CountConnections(send)).await;
}
for ch in wait_res.into_iter() {
if let Ok(conn_num) = ch.await {
conn_count.push(conn_num);
}
}
if conn_count.len() != th_num {
return Err(Error::IO(::std::io::Error::new(
::std::io::ErrorKind::NotConnected,
"can't connect to a thread",
)));
}
let thread_idx = conn_count
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.cmp(b))
.map(|(index, _)| index)
.unwrap();
// and tell that thread to connect somewhere
let (send, recv) = ::tokio::sync::oneshot::channel();
let _ = self._thread_work[thread_idx]
.send(Work::Connect((send, resolved, service)))
.await;
let _conn_res = recv.await;
todo!()
}
/// Start one working thread for each physical cpu /// Start one working thread for each physical cpu
/// threads are pinned to each cpu core. /// threads are pinned to each cpu core.
/// Work will be divided and rerouted so that there is no need to lock /// Work will be divided and rerouted so that there is no need to lock