Connect boilerplate, cleanup

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-05-27 10:57:15 +02:00
parent e71167224c
commit 1259996201
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
11 changed files with 123 additions and 117 deletions

View File

@ -37,6 +37,7 @@
#})
clippy
cargo-watch
cargo-flamegraph
cargo-license
lld
rust-bin.stable."1.69.0".default

View File

@ -19,8 +19,6 @@ use crate::{
};
use ::arrayref::array_mut_ref;
use ::std::{collections::VecDeque, num::NonZeroU64, vec::Vec};
use trust_dns_client::rr::rdata::key::Protocol;
type Nonce = [u8; 16];
@ -304,7 +302,7 @@ impl RespInner {
pub fn len(&self) -> usize {
match self {
RespInner::CipherText(len) => *len,
RespInner::ClearText(d) => RespData::len(),
RespInner::ClearText(_) => RespData::len(),
}
}
/*

View File

@ -7,7 +7,7 @@ use crate::{
enc::sym::{HeadLen, TagLen},
};
use ::num_traits::FromPrimitive;
use ::std::{rc::Rc, sync::Arc};
use ::std::rc::Rc;
/// Handshake errors
#[derive(::thiserror::Error, Debug, Copy, Clone)]
@ -145,10 +145,6 @@ impl Handshake {
self.fenrir_version.serialize(&mut out[0]);
self.data.serialize(head_len, tag_len, &mut out[1..]);
}
pub(crate) fn work(&self, keys: &[HandshakeServer]) -> Result<(), Error> {
todo!()
}
}
trait HandshakeParsing {

View File

@ -4,7 +4,7 @@ pub mod handshake;
pub mod packet;
pub mod socket;
use ::std::{rc::Rc, sync::Arc, vec::Vec};
use ::std::{rc::Rc, vec::Vec};
pub use crate::connection::{
handshake::Handshake,
@ -110,7 +110,7 @@ pub(crate) struct ConnList {
impl ConnList {
pub(crate) fn new(thread_id: ThreadTracker) -> Self {
let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new();
let bitmap_id = ::bitmaps::Bitmap::<1024>::new();
const INITIAL_CAP: usize = 128;
let mut ret = Self {
thread_id,
@ -120,6 +120,13 @@ impl ConnList {
ret.connections.resize_with(INITIAL_CAP, || None);
ret
}
pub fn len(&self) -> usize {
let mut total: usize = 0;
for bitmap in self.ids_used.iter() {
total = total + bitmap.len()
}
total
}
/// Only *Reserve* a connection,
/// without actually tracking it in self.connections
pub(crate) fn reserve_first(

View File

@ -1,15 +1,12 @@
//! Socket related types and functions
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
use ::std::{
net::SocketAddr,
sync::Arc,
vec::{self, Vec},
};
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use ::arc_swap::ArcSwap;
use ::std::{net::SocketAddr, sync::Arc, vec::Vec};
use ::tokio::{net::UdpSocket, task::JoinHandle};
/// Pair to easily track the socket and its async listening handle
pub type SocketTracker = (Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
pub type SocketTracker =
(Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
/// async free socket list
pub(crate) struct SocketList {
@ -48,7 +45,7 @@ impl SocketList {
});
}
/// This method assumes no other `add_sockets` are being run
pub(crate) async fn stop_all(mut self) {
pub(crate) async fn stop_all(self) {
let mut arc_list = self.list.into_inner();
let list = loop {
match Arc::try_unwrap(arc_list) {
@ -63,7 +60,7 @@ impl SocketList {
}
};
for (_socket, mut handle) in list.into_iter() {
Arc::get_mut(&mut handle).unwrap().await;
let _ = Arc::get_mut(&mut handle).unwrap().await;
}
}
pub(crate) fn lock(&self) -> SocketListRef {

View File

@ -1,7 +1,6 @@
//! Asymmetric key handling and wrappers
use ::num_traits::FromPrimitive;
use ::std::vec::Vec;
use super::Error;
use crate::enc::sym::Secret;

View File

@ -51,13 +51,10 @@ impl HkdfSha3 {
/// Instantiate a new HKDF with Sha3-256
pub fn new(salt: &[u8], key: Secret) -> Self {
let hkdf = Hkdf::<Sha3_256>::new(Some(salt), key.as_ref());
#[allow(unsafe_code)]
unsafe {
Self {
inner: HkdfInner {
hkdf: ::core::mem::ManuallyDrop::new(hkdf),
},
}
Self {
inner: HkdfInner {
hkdf: ::core::mem::ManuallyDrop::new(hkdf),
},
}
}
/// Get a secret generated from the key and a given context

View File

@ -1,7 +1,6 @@
//! Symmetric cypher stuff
use super::Error;
use ::std::collections::VecDeque;
use ::zeroize::Zeroize;
/// Secret, used for keys.
@ -174,7 +173,7 @@ impl Cipher {
}
fn overhead(&self) -> usize {
match self {
Cipher::XChaCha20Poly1305(cipher) => {
Cipher::XChaCha20Poly1305(_) => {
let cipher = CipherKind::XChaCha20Poly1305;
cipher.nonce_len().0 + cipher.tag_len().0
}
@ -189,9 +188,7 @@ impl Cipher {
// FIXME: check minimum buffer size
match self {
Cipher::XChaCha20Poly1305(cipher) => {
use ::chacha20poly1305::{
aead::generic_array::GenericArray, AeadInPlace,
};
use ::chacha20poly1305::AeadInPlace;
let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len();
let data_len_notag = data.len() - tag_len;
// write nonce
@ -211,10 +208,9 @@ impl Cipher {
Ok(())
}
Err(_) => Err(Error::Encrypt),
};
}
}
}
todo!()
}
}
@ -253,35 +249,6 @@ impl CipherRecv {
}
}
/// Allocate some data, with additional indexes to track
/// where nonce and tags are
#[derive(Debug, Clone)]
pub struct Data {
data: Vec<u8>,
skip_start: usize,
skip_end: usize,
}
impl Data {
/// Get the slice where you will write the actual data
/// this will skip the actual nonce and AEAD tag and give you
/// only the space for the data
pub fn get_slice(&mut self) -> &mut [u8] {
&mut self.data[self.skip_start..self.skip_end]
}
fn get_tag_slice(&mut self) -> &mut [u8] {
let start = self.data.len() - self.skip_end;
&mut self.data[start..]
}
fn get_slice_full(&mut self) -> &mut [u8] {
&mut self.data
}
/// Consume the data and return the whole raw vector
pub fn get_raw(self) -> Vec<u8> {
self.data
}
}
/// Send only cipher
pub struct CipherSend {
nonce: NonceSync,
@ -308,14 +275,6 @@ impl CipherSend {
cipher: Cipher::new(kind, secret),
}
}
/// Allocate the memory for the data that will be encrypted
pub fn make_data(&self, length: usize) -> Data {
Data {
data: Vec::with_capacity(length + self.cipher.overhead()),
skip_start: self.cipher.nonce_len().0,
skip_end: self.cipher.tag_len().0,
}
}
/// Encrypt the given data
pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> {
let old_nonce = self.nonce.advance();
@ -380,10 +339,7 @@ impl Nonce {
use ring::rand::SecureRandom;
let mut raw = [0; 12];
rand.fill(&mut raw);
#[allow(unsafe_code)]
unsafe {
Self { raw }
}
Self { raw }
}
/// Length of this nonce in bytes
pub const fn len() -> usize {
@ -398,10 +354,7 @@ impl Nonce {
}
/// Create Nonce from array
pub fn from_slice(raw: [u8; 12]) -> Self {
#[allow(unsafe_code)]
unsafe {
Self { raw }
}
Self { raw }
}
/// Go to the next nonce
pub fn advance(&mut self) {

View File

@ -14,12 +14,11 @@ use crate::{
enc::{
self, asym,
hkdf::HkdfSha3,
sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen},
sym::{CipherKind, CipherRecv},
},
Error,
};
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
use ::std::{rc::Rc, sync::Arc, vec::Vec};
use ::std::{rc::Rc, vec::Vec};
/// Information needed to reply after the key exchange
#[derive(Debug, Clone)]
@ -98,10 +97,7 @@ impl HandshakeTracker {
mut handshake: Handshake,
handshake_raw: &mut [u8],
) -> Result<HandshakeAction, Error> {
use connection::handshake::{
dirsync::{self, DirSync},
HandshakeData,
};
use connection::handshake::{dirsync::DirSync, HandshakeData};
match handshake.data {
HandshakeData::DirSync(ref mut ds) => match ds {
DirSync::Req(ref mut req) => {

View File

@ -1,23 +1,25 @@
//! Worker thread implementation
use crate::{
auth::TokenChecker,
auth::{ServiceID, TokenChecker},
connection::{
self,
handshake::{
self,
dirsync::{self, DirSync},
Handshake, HandshakeClient, HandshakeData,
Handshake, HandshakeData,
},
socket::{UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet, ID,
ConnList, Connection, IDSend, Packet,
},
dnssec,
enc::{hkdf::HkdfSha3, sym::Secret},
inner::{HandshakeAction, HandshakeTracker, ThreadTracker},
};
use ::std::{rc::Rc, sync::Arc, vec::Vec};
/// This worker must be cpu-pinned
use ::tokio::{net::UdpSocket, sync::Mutex};
use std::net::SocketAddr;
use ::tokio::{
net::UdpSocket,
sync::{oneshot, Mutex},
};
/// Track a raw Udp packet
pub(crate) struct RawUdp {
@ -28,8 +30,15 @@ pub(crate) struct RawUdp {
}
pub(crate) enum Work {
/// ask the thread to report to the main thread the total number of
/// connections present
CountConnections(oneshot::Sender<usize>),
Connect((oneshot::Sender<u16>, dnssec::Record, ServiceID)),
Recv(RawUdp),
}
pub(crate) enum WorkAnswer {
UNUSED,
}
/// Actual worker implementation.
pub(crate) struct Worker {
@ -131,6 +140,13 @@ impl Worker {
}
};
match work {
Work::CountConnections(sender) => {
let conn_num = self.connections.len();
let _ = sender.send(conn_num);
}
Work::Connect((send_res, dnssec_record, service_id)) => {
todo!()
}
//TODO: reconf message to add channels
Work::Recv(pkt) => {
self.recv(pkt).await;
@ -285,7 +301,6 @@ impl Worker {
return;
}
// track connection
use handshake::dirsync;
let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{
@ -313,6 +328,7 @@ impl Worker {
return;
}
// create and track the connection to the service
// SECURITY:
//FIXME: the Secret should be XORed with the client stored
// secret (if any)
let hkdf = HkdfSha3::new(
@ -328,7 +344,7 @@ impl Worker {
service_connection.id_recv = cci.service_connection_id;
service_connection.id_send =
IDSend(resp_data.service_connection_id);
self.connections.track(service_connection.into());
let _ = self.connections.track(service_connection.into());
return;
}
_ => {}

View File

@ -20,12 +20,9 @@ pub mod dnssec;
pub mod enc;
mod inner;
use ::std::{
net::SocketAddr,
sync::{Arc, Weak},
vec::Vec,
};
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
use ::std::{sync::Arc, vec::Vec};
use ::tokio::net::UdpSocket;
use auth::ServiceID;
use crate::{
auth::TokenChecker,
@ -94,9 +91,7 @@ impl Drop for Fenrir {
impl Fenrir {
/// Create a new Fenrir endpoint
pub fn new(config: &Config) -> Result<Self, Error> {
let listen_num = config.listen.len();
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let endpoint = Fenrir {
cfg: config.clone(),
sockets: SocketList::new(),
@ -127,23 +122,23 @@ impl Fenrir {
/// asyncronous version for Drop
fn stop_sync(&mut self) {
let _ = self.stop_working.send(true);
let mut toempty_sockets = self.sockets.rm_all();
let toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(toempty_sockets.stop_all());
let _ = ::futures::executor::block_on(task);
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join());
let _ = old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None;
}
/// Stop all workers, listeners
pub async fn stop(&mut self) {
let _ = self.stop_working.send(true);
let mut toempty_sockets = self.sockets.rm_all();
let toempty_sockets = self.sockets.rm_all();
toempty_sockets.stop_all().await;
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
old_thread_pool.into_iter().map(|th| th.join());
let _ = old_thread_pool.into_iter().map(|th| th.join());
self.dnssec = None;
}
/// Add all UDP sockets found in config
@ -166,7 +161,7 @@ impl Fenrir {
self._thread_work.clone(),
arc_s.clone(),
));
self.sockets.add_socket(arc_s, join);
self.sockets.add_socket(arc_s, join).await;
}
Err(e) => {
return Err(e);
@ -218,18 +213,19 @@ impl Fenrir {
}
}
};
work_queues[thread_idx].send(Work::Recv(RawUdp {
src: UdpClient(sock_sender),
dst: sock_receiver,
packet,
data,
}));
let _ = work_queues[thread_idx]
.send(Work::Recv(RawUdp {
src: UdpClient(sock_sender),
dst: sock_receiver,
packet,
data,
}))
.await;
}
Ok(())
}
/// Get the raw TXT record of a Fenrir domain
pub async fn resolv_str(&self, domain: &str) -> Result<String, Error> {
pub async fn resolv_txt(&self, domain: &str) -> Result<String, Error> {
match &self.dnssec {
Some(dnssec) => Ok(dnssec.resolv(domain).await?),
None => Err(Error::NotInitialized),
@ -238,10 +234,60 @@ impl Fenrir {
/// Get the raw TXT record of a Fenrir domain
pub async fn resolv(&self, domain: &str) -> Result<dnssec::Record, Error> {
let record_str = self.resolv_str(domain).await?;
let record_str = self.resolv_txt(domain).await?;
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
}
/// Connect to a service
pub async fn connect(
&self,
domain: &str,
service: ServiceID,
) -> Result<(), Error> {
let resolved = self.resolv(domain).await?;
// find the thread with less connections
let th_num = self._thread_work.len();
let mut conn_count = Vec::<usize>::with_capacity(th_num);
let mut wait_res =
Vec::<::tokio::sync::oneshot::Receiver<usize>>::with_capacity(
th_num,
);
for th in self._thread_work.iter() {
let (send, recv) = ::tokio::sync::oneshot::channel();
wait_res.push(recv);
let _ = th.send(Work::CountConnections(send)).await;
}
for ch in wait_res.into_iter() {
if let Ok(conn_num) = ch.await {
conn_count.push(conn_num);
}
}
if conn_count.len() != th_num {
return Err(Error::IO(::std::io::Error::new(
::std::io::ErrorKind::NotConnected,
"can't connect to a thread",
)));
}
let thread_idx = conn_count
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.cmp(b))
.map(|(index, _)| index)
.unwrap();
// and tell that thread to connect somewhere
let (send, recv) = ::tokio::sync::oneshot::channel();
let _ = self._thread_work[thread_idx]
.send(Work::Connect((send, resolved, service)))
.await;
let _conn_res = recv.await;
todo!()
}
/// Start one working thread for each physical cpu
/// threads are pinned to each cpu core.
/// Work will be divided and rerouted so that there is no need to lock