From 1259996201941f919faf0d3af32437fcbf5f3779 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sat, 27 May 2023 10:57:15 +0200 Subject: [PATCH] Connect boilerplate, cleanup Signed-off-by: Luca Fulchir --- flake.nix | 1 + src/connection/handshake/dirsync.rs | 4 +- src/connection/handshake/mod.rs | 6 +- src/connection/mod.rs | 11 +++- src/connection/socket.rs | 17 +++--- src/enc/asym.rs | 1 - src/enc/hkdf.rs | 11 ++-- src/enc/sym.rs | 57 ++---------------- src/inner/mod.rs | 10 +--- src/inner/worker.rs | 32 +++++++--- src/lib.rs | 90 ++++++++++++++++++++++------- 11 files changed, 123 insertions(+), 117 deletions(-) diff --git a/flake.nix b/flake.nix index 4c4a18a..fd0713c 100644 --- a/flake.nix +++ b/flake.nix @@ -37,6 +37,7 @@ #}) clippy cargo-watch + cargo-flamegraph cargo-license lld rust-bin.stable."1.69.0".default diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 72d6da9..c5ef626 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -19,8 +19,6 @@ use crate::{ }; use ::arrayref::array_mut_ref; -use ::std::{collections::VecDeque, num::NonZeroU64, vec::Vec}; -use trust_dns_client::rr::rdata::key::Protocol; type Nonce = [u8; 16]; @@ -304,7 +302,7 @@ impl RespInner { pub fn len(&self) -> usize { match self { RespInner::CipherText(len) => *len, - RespInner::ClearText(d) => RespData::len(), + RespInner::ClearText(_) => RespData::len(), } } /* diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 63de947..8897c1a 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -7,7 +7,7 @@ use crate::{ enc::sym::{HeadLen, TagLen}, }; use ::num_traits::FromPrimitive; -use ::std::{rc::Rc, sync::Arc}; +use ::std::rc::Rc; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -145,10 +145,6 @@ impl Handshake { self.fenrir_version.serialize(&mut out[0]); self.data.serialize(head_len, tag_len, &mut out[1..]); } - - pub(crate) fn work(&self, keys: &[HandshakeServer]) -> Result<(), Error> { - todo!() - } } trait HandshakeParsing { diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 2c85af7..9d8a4cb 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -4,7 +4,7 @@ pub mod handshake; pub mod packet; pub mod socket; -use ::std::{rc::Rc, sync::Arc, vec::Vec}; +use ::std::{rc::Rc, vec::Vec}; pub use crate::connection::{ handshake::Handshake, @@ -110,7 +110,7 @@ pub(crate) struct ConnList { impl ConnList { pub(crate) fn new(thread_id: ThreadTracker) -> Self { - let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + let bitmap_id = ::bitmaps::Bitmap::<1024>::new(); const INITIAL_CAP: usize = 128; let mut ret = Self { thread_id, @@ -120,6 +120,13 @@ impl ConnList { ret.connections.resize_with(INITIAL_CAP, || None); ret } + pub fn len(&self) -> usize { + let mut total: usize = 0; + for bitmap in self.ids_used.iter() { + total = total + bitmap.len() + } + total + } /// Only *Reserve* a connection, /// without actually tracking it in self.connections pub(crate) fn reserve_first( diff --git a/src/connection/socket.rs b/src/connection/socket.rs index 1cc570d..945dac6 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -1,15 +1,12 @@ //! Socket related types and functions -use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{ - net::SocketAddr, - sync::Arc, - vec::{self, Vec}, -}; -use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; +use ::arc_swap::ArcSwap; +use ::std::{net::SocketAddr, sync::Arc, vec::Vec}; +use ::tokio::{net::UdpSocket, task::JoinHandle}; /// Pair to easily track the socket and its async listening handle -pub type SocketTracker = (Arc, Arc>>); +pub type SocketTracker = + (Arc, Arc>>); /// async free socket list pub(crate) struct SocketList { @@ -48,7 +45,7 @@ impl SocketList { }); } /// This method assumes no other `add_sockets` are being run - pub(crate) async fn stop_all(mut self) { + pub(crate) async fn stop_all(self) { let mut arc_list = self.list.into_inner(); let list = loop { match Arc::try_unwrap(arc_list) { @@ -63,7 +60,7 @@ impl SocketList { } }; for (_socket, mut handle) in list.into_iter() { - Arc::get_mut(&mut handle).unwrap().await; + let _ = Arc::get_mut(&mut handle).unwrap().await; } } pub(crate) fn lock(&self) -> SocketListRef { diff --git a/src/enc/asym.rs b/src/enc/asym.rs index c705a5f..56027df 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -1,7 +1,6 @@ //! Asymmetric key handling and wrappers use ::num_traits::FromPrimitive; -use ::std::vec::Vec; use super::Error; use crate::enc::sym::Secret; diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index bb6ca59..15d7eca 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -51,13 +51,10 @@ impl HkdfSha3 { /// Instantiate a new HKDF with Sha3-256 pub fn new(salt: &[u8], key: Secret) -> Self { let hkdf = Hkdf::::new(Some(salt), key.as_ref()); - #[allow(unsafe_code)] - unsafe { - Self { - inner: HkdfInner { - hkdf: ::core::mem::ManuallyDrop::new(hkdf), - }, - } + Self { + inner: HkdfInner { + hkdf: ::core::mem::ManuallyDrop::new(hkdf), + }, } } /// Get a secret generated from the key and a given context diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 3de267c..e6f9e11 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -1,7 +1,6 @@ //! Symmetric cypher stuff use super::Error; -use ::std::collections::VecDeque; use ::zeroize::Zeroize; /// Secret, used for keys. @@ -174,7 +173,7 @@ impl Cipher { } fn overhead(&self) -> usize { match self { - Cipher::XChaCha20Poly1305(cipher) => { + Cipher::XChaCha20Poly1305(_) => { let cipher = CipherKind::XChaCha20Poly1305; cipher.nonce_len().0 + cipher.tag_len().0 } @@ -189,9 +188,7 @@ impl Cipher { // FIXME: check minimum buffer size match self { Cipher::XChaCha20Poly1305(cipher) => { - use ::chacha20poly1305::{ - aead::generic_array::GenericArray, AeadInPlace, - }; + use ::chacha20poly1305::AeadInPlace; let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len(); let data_len_notag = data.len() - tag_len; // write nonce @@ -211,10 +208,9 @@ impl Cipher { Ok(()) } Err(_) => Err(Error::Encrypt), - }; + } } } - todo!() } } @@ -253,35 +249,6 @@ impl CipherRecv { } } -/// Allocate some data, with additional indexes to track -/// where nonce and tags are -#[derive(Debug, Clone)] -pub struct Data { - data: Vec, - skip_start: usize, - skip_end: usize, -} - -impl Data { - /// Get the slice where you will write the actual data - /// this will skip the actual nonce and AEAD tag and give you - /// only the space for the data - pub fn get_slice(&mut self) -> &mut [u8] { - &mut self.data[self.skip_start..self.skip_end] - } - fn get_tag_slice(&mut self) -> &mut [u8] { - let start = self.data.len() - self.skip_end; - &mut self.data[start..] - } - fn get_slice_full(&mut self) -> &mut [u8] { - &mut self.data - } - /// Consume the data and return the whole raw vector - pub fn get_raw(self) -> Vec { - self.data - } -} - /// Send only cipher pub struct CipherSend { nonce: NonceSync, @@ -308,14 +275,6 @@ impl CipherSend { cipher: Cipher::new(kind, secret), } } - /// Allocate the memory for the data that will be encrypted - pub fn make_data(&self, length: usize) -> Data { - Data { - data: Vec::with_capacity(length + self.cipher.overhead()), - skip_start: self.cipher.nonce_len().0, - skip_end: self.cipher.tag_len().0, - } - } /// Encrypt the given data pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> { let old_nonce = self.nonce.advance(); @@ -380,10 +339,7 @@ impl Nonce { use ring::rand::SecureRandom; let mut raw = [0; 12]; rand.fill(&mut raw); - #[allow(unsafe_code)] - unsafe { - Self { raw } - } + Self { raw } } /// Length of this nonce in bytes pub const fn len() -> usize { @@ -398,10 +354,7 @@ impl Nonce { } /// Create Nonce from array pub fn from_slice(raw: [u8; 12]) -> Self { - #[allow(unsafe_code)] - unsafe { - Self { raw } - } + Self { raw } } /// Go to the next nonce pub fn advance(&mut self) { diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 72447f2..ec09d1f 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -14,12 +14,11 @@ use crate::{ enc::{ self, asym, hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, + sym::{CipherKind, CipherRecv}, }, Error, }; -use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{rc::Rc, sync::Arc, vec::Vec}; +use ::std::{rc::Rc, vec::Vec}; /// Information needed to reply after the key exchange #[derive(Debug, Clone)] @@ -98,10 +97,7 @@ impl HandshakeTracker { mut handshake: Handshake, handshake_raw: &mut [u8], ) -> Result { - use connection::handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; + use connection::handshake::{dirsync::DirSync, HandshakeData}; match handshake.data { HandshakeData::DirSync(ref mut ds) => match ds { DirSync::Req(ref mut req) => { diff --git a/src/inner/worker.rs b/src/inner/worker.rs index d11efdb..c7adf10 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -1,23 +1,25 @@ //! Worker thread implementation use crate::{ - auth::TokenChecker, + auth::{ServiceID, TokenChecker}, connection::{ self, handshake::{ - self, dirsync::{self, DirSync}, - Handshake, HandshakeClient, HandshakeData, + Handshake, HandshakeData, }, socket::{UdpClient, UdpServer}, - ConnList, Connection, IDSend, Packet, ID, + ConnList, Connection, IDSend, Packet, }, + dnssec, enc::{hkdf::HkdfSha3, sym::Secret}, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; use ::std::{rc::Rc, sync::Arc, vec::Vec}; /// This worker must be cpu-pinned -use ::tokio::{net::UdpSocket, sync::Mutex}; -use std::net::SocketAddr; +use ::tokio::{ + net::UdpSocket, + sync::{oneshot, Mutex}, +}; /// Track a raw Udp packet pub(crate) struct RawUdp { @@ -28,8 +30,15 @@ pub(crate) struct RawUdp { } pub(crate) enum Work { + /// ask the thread to report to the main thread the total number of + /// connections present + CountConnections(oneshot::Sender), + Connect((oneshot::Sender, dnssec::Record, ServiceID)), Recv(RawUdp), } +pub(crate) enum WorkAnswer { + UNUSED, +} /// Actual worker implementation. pub(crate) struct Worker { @@ -131,6 +140,13 @@ impl Worker { } }; match work { + Work::CountConnections(sender) => { + let conn_num = self.connections.len(); + let _ = sender.send(conn_num); + } + Work::Connect((send_res, dnssec_record, service_id)) => { + todo!() + } //TODO: reconf message to add channels Work::Recv(pkt) => { self.recv(pkt).await; @@ -285,7 +301,6 @@ impl Worker { return; } // track connection - use handshake::dirsync; let resp_data; if let dirsync::RespInner::ClearText(r_data) = ds_resp.data { @@ -313,6 +328,7 @@ impl Worker { return; } // create and track the connection to the service + // SECURITY: //FIXME: the Secret should be XORed with the client stored // secret (if any) let hkdf = HkdfSha3::new( @@ -328,7 +344,7 @@ impl Worker { service_connection.id_recv = cci.service_connection_id; service_connection.id_send = IDSend(resp_data.service_connection_id); - self.connections.track(service_connection.into()); + let _ = self.connections.track(service_connection.into()); return; } _ => {} diff --git a/src/lib.rs b/src/lib.rs index 02df264..1aed2e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,12 +20,9 @@ pub mod dnssec; pub mod enc; mod inner; -use ::std::{ - net::SocketAddr, - sync::{Arc, Weak}, - vec::Vec, -}; -use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; +use ::std::{sync::Arc, vec::Vec}; +use ::tokio::net::UdpSocket; +use auth::ServiceID; use crate::{ auth::TokenChecker, @@ -94,9 +91,7 @@ impl Drop for Fenrir { impl Fenrir { /// Create a new Fenrir endpoint pub fn new(config: &Config) -> Result { - let listen_num = config.listen.len(); let (sender, _) = ::tokio::sync::broadcast::channel(1); - let (work_send, work_recv) = ::async_channel::unbounded::(); let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), @@ -127,23 +122,23 @@ impl Fenrir { /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); - let mut toempty_sockets = self.sockets.rm_all(); + let toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); - old_thread_pool.into_iter().map(|th| th.join()); + let _ = old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); - let mut toempty_sockets = self.sockets.rm_all(); + let toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); - old_thread_pool.into_iter().map(|th| th.join()); + let _ = old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Add all UDP sockets found in config @@ -166,7 +161,7 @@ impl Fenrir { self._thread_work.clone(), arc_s.clone(), )); - self.sockets.add_socket(arc_s, join); + self.sockets.add_socket(arc_s, join).await; } Err(e) => { return Err(e); @@ -218,18 +213,19 @@ impl Fenrir { } } }; - work_queues[thread_idx].send(Work::Recv(RawUdp { - src: UdpClient(sock_sender), - dst: sock_receiver, - packet, - data, - })); + let _ = work_queues[thread_idx] + .send(Work::Recv(RawUdp { + src: UdpClient(sock_sender), + dst: sock_receiver, + packet, + data, + })) + .await; } Ok(()) } - /// Get the raw TXT record of a Fenrir domain - pub async fn resolv_str(&self, domain: &str) -> Result { + pub async fn resolv_txt(&self, domain: &str) -> Result { match &self.dnssec { Some(dnssec) => Ok(dnssec.resolv(domain).await?), None => Err(Error::NotInitialized), @@ -238,10 +234,60 @@ impl Fenrir { /// Get the raw TXT record of a Fenrir domain pub async fn resolv(&self, domain: &str) -> Result { - let record_str = self.resolv_str(domain).await?; + let record_str = self.resolv_txt(domain).await?; Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } + /// Connect to a service + pub async fn connect( + &self, + domain: &str, + service: ServiceID, + ) -> Result<(), Error> { + let resolved = self.resolv(domain).await?; + + // find the thread with less connections + + let th_num = self._thread_work.len(); + let mut conn_count = Vec::::with_capacity(th_num); + let mut wait_res = + Vec::<::tokio::sync::oneshot::Receiver>::with_capacity( + th_num, + ); + for th in self._thread_work.iter() { + let (send, recv) = ::tokio::sync::oneshot::channel(); + wait_res.push(recv); + let _ = th.send(Work::CountConnections(send)).await; + } + for ch in wait_res.into_iter() { + if let Ok(conn_num) = ch.await { + conn_count.push(conn_num); + } + } + if conn_count.len() != th_num { + return Err(Error::IO(::std::io::Error::new( + ::std::io::ErrorKind::NotConnected, + "can't connect to a thread", + ))); + } + let thread_idx = conn_count + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.cmp(b)) + .map(|(index, _)| index) + .unwrap(); + + // and tell that thread to connect somewhere + let (send, recv) = ::tokio::sync::oneshot::channel(); + let _ = self._thread_work[thread_idx] + .send(Work::Connect((send, resolved, service))) + .await; + + let _conn_res = recv.await; + + todo!() + } + /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. /// Work will be divided and rerouted so that there is no need to lock