Connect boilerplate, cleanup
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
e71167224c
commit
1259996201
|
@ -37,6 +37,7 @@
|
|||
#})
|
||||
clippy
|
||||
cargo-watch
|
||||
cargo-flamegraph
|
||||
cargo-license
|
||||
lld
|
||||
rust-bin.stable."1.69.0".default
|
||||
|
|
|
@ -19,8 +19,6 @@ use crate::{
|
|||
};
|
||||
|
||||
use ::arrayref::array_mut_ref;
|
||||
use ::std::{collections::VecDeque, num::NonZeroU64, vec::Vec};
|
||||
use trust_dns_client::rr::rdata::key::Protocol;
|
||||
|
||||
type Nonce = [u8; 16];
|
||||
|
||||
|
@ -304,7 +302,7 @@ impl RespInner {
|
|||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
RespInner::CipherText(len) => *len,
|
||||
RespInner::ClearText(d) => RespData::len(),
|
||||
RespInner::ClearText(_) => RespData::len(),
|
||||
}
|
||||
}
|
||||
/*
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::{
|
|||
enc::sym::{HeadLen, TagLen},
|
||||
};
|
||||
use ::num_traits::FromPrimitive;
|
||||
use ::std::{rc::Rc, sync::Arc};
|
||||
use ::std::rc::Rc;
|
||||
|
||||
/// Handshake errors
|
||||
#[derive(::thiserror::Error, Debug, Copy, Clone)]
|
||||
|
@ -145,10 +145,6 @@ impl Handshake {
|
|||
self.fenrir_version.serialize(&mut out[0]);
|
||||
self.data.serialize(head_len, tag_len, &mut out[1..]);
|
||||
}
|
||||
|
||||
pub(crate) fn work(&self, keys: &[HandshakeServer]) -> Result<(), Error> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
trait HandshakeParsing {
|
||||
|
|
|
@ -4,7 +4,7 @@ pub mod handshake;
|
|||
pub mod packet;
|
||||
pub mod socket;
|
||||
|
||||
use ::std::{rc::Rc, sync::Arc, vec::Vec};
|
||||
use ::std::{rc::Rc, vec::Vec};
|
||||
|
||||
pub use crate::connection::{
|
||||
handshake::Handshake,
|
||||
|
@ -110,7 +110,7 @@ pub(crate) struct ConnList {
|
|||
|
||||
impl ConnList {
|
||||
pub(crate) fn new(thread_id: ThreadTracker) -> Self {
|
||||
let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new();
|
||||
let bitmap_id = ::bitmaps::Bitmap::<1024>::new();
|
||||
const INITIAL_CAP: usize = 128;
|
||||
let mut ret = Self {
|
||||
thread_id,
|
||||
|
@ -120,6 +120,13 @@ impl ConnList {
|
|||
ret.connections.resize_with(INITIAL_CAP, || None);
|
||||
ret
|
||||
}
|
||||
pub fn len(&self) -> usize {
|
||||
let mut total: usize = 0;
|
||||
for bitmap in self.ids_used.iter() {
|
||||
total = total + bitmap.len()
|
||||
}
|
||||
total
|
||||
}
|
||||
/// Only *Reserve* a connection,
|
||||
/// without actually tracking it in self.connections
|
||||
pub(crate) fn reserve_first(
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
//! Socket related types and functions
|
||||
|
||||
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
|
||||
use ::std::{
|
||||
net::SocketAddr,
|
||||
sync::Arc,
|
||||
vec::{self, Vec},
|
||||
};
|
||||
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
|
||||
use ::arc_swap::ArcSwap;
|
||||
use ::std::{net::SocketAddr, sync::Arc, vec::Vec};
|
||||
use ::tokio::{net::UdpSocket, task::JoinHandle};
|
||||
|
||||
/// Pair to easily track the socket and its async listening handle
|
||||
pub type SocketTracker = (Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
|
||||
pub type SocketTracker =
|
||||
(Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>);
|
||||
|
||||
/// async free socket list
|
||||
pub(crate) struct SocketList {
|
||||
|
@ -48,7 +45,7 @@ impl SocketList {
|
|||
});
|
||||
}
|
||||
/// This method assumes no other `add_sockets` are being run
|
||||
pub(crate) async fn stop_all(mut self) {
|
||||
pub(crate) async fn stop_all(self) {
|
||||
let mut arc_list = self.list.into_inner();
|
||||
let list = loop {
|
||||
match Arc::try_unwrap(arc_list) {
|
||||
|
@ -63,7 +60,7 @@ impl SocketList {
|
|||
}
|
||||
};
|
||||
for (_socket, mut handle) in list.into_iter() {
|
||||
Arc::get_mut(&mut handle).unwrap().await;
|
||||
let _ = Arc::get_mut(&mut handle).unwrap().await;
|
||||
}
|
||||
}
|
||||
pub(crate) fn lock(&self) -> SocketListRef {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
//! Asymmetric key handling and wrappers
|
||||
|
||||
use ::num_traits::FromPrimitive;
|
||||
use ::std::vec::Vec;
|
||||
|
||||
use super::Error;
|
||||
use crate::enc::sym::Secret;
|
||||
|
|
|
@ -51,13 +51,10 @@ impl HkdfSha3 {
|
|||
/// Instantiate a new HKDF with Sha3-256
|
||||
pub fn new(salt: &[u8], key: Secret) -> Self {
|
||||
let hkdf = Hkdf::<Sha3_256>::new(Some(salt), key.as_ref());
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
Self {
|
||||
inner: HkdfInner {
|
||||
hkdf: ::core::mem::ManuallyDrop::new(hkdf),
|
||||
},
|
||||
}
|
||||
Self {
|
||||
inner: HkdfInner {
|
||||
hkdf: ::core::mem::ManuallyDrop::new(hkdf),
|
||||
},
|
||||
}
|
||||
}
|
||||
/// Get a secret generated from the key and a given context
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
//! Symmetric cypher stuff
|
||||
|
||||
use super::Error;
|
||||
use ::std::collections::VecDeque;
|
||||
use ::zeroize::Zeroize;
|
||||
|
||||
/// Secret, used for keys.
|
||||
|
@ -174,7 +173,7 @@ impl Cipher {
|
|||
}
|
||||
fn overhead(&self) -> usize {
|
||||
match self {
|
||||
Cipher::XChaCha20Poly1305(cipher) => {
|
||||
Cipher::XChaCha20Poly1305(_) => {
|
||||
let cipher = CipherKind::XChaCha20Poly1305;
|
||||
cipher.nonce_len().0 + cipher.tag_len().0
|
||||
}
|
||||
|
@ -189,9 +188,7 @@ impl Cipher {
|
|||
// FIXME: check minimum buffer size
|
||||
match self {
|
||||
Cipher::XChaCha20Poly1305(cipher) => {
|
||||
use ::chacha20poly1305::{
|
||||
aead::generic_array::GenericArray, AeadInPlace,
|
||||
};
|
||||
use ::chacha20poly1305::AeadInPlace;
|
||||
let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len();
|
||||
let data_len_notag = data.len() - tag_len;
|
||||
// write nonce
|
||||
|
@ -211,10 +208,9 @@ impl Cipher {
|
|||
Ok(())
|
||||
}
|
||||
Err(_) => Err(Error::Encrypt),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,35 +249,6 @@ impl CipherRecv {
|
|||
}
|
||||
}
|
||||
|
||||
/// Allocate some data, with additional indexes to track
|
||||
/// where nonce and tags are
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Data {
|
||||
data: Vec<u8>,
|
||||
skip_start: usize,
|
||||
skip_end: usize,
|
||||
}
|
||||
|
||||
impl Data {
|
||||
/// Get the slice where you will write the actual data
|
||||
/// this will skip the actual nonce and AEAD tag and give you
|
||||
/// only the space for the data
|
||||
pub fn get_slice(&mut self) -> &mut [u8] {
|
||||
&mut self.data[self.skip_start..self.skip_end]
|
||||
}
|
||||
fn get_tag_slice(&mut self) -> &mut [u8] {
|
||||
let start = self.data.len() - self.skip_end;
|
||||
&mut self.data[start..]
|
||||
}
|
||||
fn get_slice_full(&mut self) -> &mut [u8] {
|
||||
&mut self.data
|
||||
}
|
||||
/// Consume the data and return the whole raw vector
|
||||
pub fn get_raw(self) -> Vec<u8> {
|
||||
self.data
|
||||
}
|
||||
}
|
||||
|
||||
/// Send only cipher
|
||||
pub struct CipherSend {
|
||||
nonce: NonceSync,
|
||||
|
@ -308,14 +275,6 @@ impl CipherSend {
|
|||
cipher: Cipher::new(kind, secret),
|
||||
}
|
||||
}
|
||||
/// Allocate the memory for the data that will be encrypted
|
||||
pub fn make_data(&self, length: usize) -> Data {
|
||||
Data {
|
||||
data: Vec::with_capacity(length + self.cipher.overhead()),
|
||||
skip_start: self.cipher.nonce_len().0,
|
||||
skip_end: self.cipher.tag_len().0,
|
||||
}
|
||||
}
|
||||
/// Encrypt the given data
|
||||
pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> {
|
||||
let old_nonce = self.nonce.advance();
|
||||
|
@ -380,10 +339,7 @@ impl Nonce {
|
|||
use ring::rand::SecureRandom;
|
||||
let mut raw = [0; 12];
|
||||
rand.fill(&mut raw);
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
Self { raw }
|
||||
}
|
||||
Self { raw }
|
||||
}
|
||||
/// Length of this nonce in bytes
|
||||
pub const fn len() -> usize {
|
||||
|
@ -398,10 +354,7 @@ impl Nonce {
|
|||
}
|
||||
/// Create Nonce from array
|
||||
pub fn from_slice(raw: [u8; 12]) -> Self {
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
Self { raw }
|
||||
}
|
||||
Self { raw }
|
||||
}
|
||||
/// Go to the next nonce
|
||||
pub fn advance(&mut self) {
|
||||
|
|
|
@ -14,12 +14,11 @@ use crate::{
|
|||
enc::{
|
||||
self, asym,
|
||||
hkdf::HkdfSha3,
|
||||
sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen},
|
||||
sym::{CipherKind, CipherRecv},
|
||||
},
|
||||
Error,
|
||||
};
|
||||
use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption};
|
||||
use ::std::{rc::Rc, sync::Arc, vec::Vec};
|
||||
use ::std::{rc::Rc, vec::Vec};
|
||||
|
||||
/// Information needed to reply after the key exchange
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -98,10 +97,7 @@ impl HandshakeTracker {
|
|||
mut handshake: Handshake,
|
||||
handshake_raw: &mut [u8],
|
||||
) -> Result<HandshakeAction, Error> {
|
||||
use connection::handshake::{
|
||||
dirsync::{self, DirSync},
|
||||
HandshakeData,
|
||||
};
|
||||
use connection::handshake::{dirsync::DirSync, HandshakeData};
|
||||
match handshake.data {
|
||||
HandshakeData::DirSync(ref mut ds) => match ds {
|
||||
DirSync::Req(ref mut req) => {
|
||||
|
|
|
@ -1,23 +1,25 @@
|
|||
//! Worker thread implementation
|
||||
use crate::{
|
||||
auth::TokenChecker,
|
||||
auth::{ServiceID, TokenChecker},
|
||||
connection::{
|
||||
self,
|
||||
handshake::{
|
||||
self,
|
||||
dirsync::{self, DirSync},
|
||||
Handshake, HandshakeClient, HandshakeData,
|
||||
Handshake, HandshakeData,
|
||||
},
|
||||
socket::{UdpClient, UdpServer},
|
||||
ConnList, Connection, IDSend, Packet, ID,
|
||||
ConnList, Connection, IDSend, Packet,
|
||||
},
|
||||
dnssec,
|
||||
enc::{hkdf::HkdfSha3, sym::Secret},
|
||||
inner::{HandshakeAction, HandshakeTracker, ThreadTracker},
|
||||
};
|
||||
use ::std::{rc::Rc, sync::Arc, vec::Vec};
|
||||
/// This worker must be cpu-pinned
|
||||
use ::tokio::{net::UdpSocket, sync::Mutex};
|
||||
use std::net::SocketAddr;
|
||||
use ::tokio::{
|
||||
net::UdpSocket,
|
||||
sync::{oneshot, Mutex},
|
||||
};
|
||||
|
||||
/// Track a raw Udp packet
|
||||
pub(crate) struct RawUdp {
|
||||
|
@ -28,8 +30,15 @@ pub(crate) struct RawUdp {
|
|||
}
|
||||
|
||||
pub(crate) enum Work {
|
||||
/// ask the thread to report to the main thread the total number of
|
||||
/// connections present
|
||||
CountConnections(oneshot::Sender<usize>),
|
||||
Connect((oneshot::Sender<u16>, dnssec::Record, ServiceID)),
|
||||
Recv(RawUdp),
|
||||
}
|
||||
pub(crate) enum WorkAnswer {
|
||||
UNUSED,
|
||||
}
|
||||
|
||||
/// Actual worker implementation.
|
||||
pub(crate) struct Worker {
|
||||
|
@ -131,6 +140,13 @@ impl Worker {
|
|||
}
|
||||
};
|
||||
match work {
|
||||
Work::CountConnections(sender) => {
|
||||
let conn_num = self.connections.len();
|
||||
let _ = sender.send(conn_num);
|
||||
}
|
||||
Work::Connect((send_res, dnssec_record, service_id)) => {
|
||||
todo!()
|
||||
}
|
||||
//TODO: reconf message to add channels
|
||||
Work::Recv(pkt) => {
|
||||
self.recv(pkt).await;
|
||||
|
@ -285,7 +301,6 @@ impl Worker {
|
|||
return;
|
||||
}
|
||||
// track connection
|
||||
use handshake::dirsync;
|
||||
let resp_data;
|
||||
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
|
||||
{
|
||||
|
@ -313,6 +328,7 @@ impl Worker {
|
|||
return;
|
||||
}
|
||||
// create and track the connection to the service
|
||||
// SECURITY:
|
||||
//FIXME: the Secret should be XORed with the client stored
|
||||
// secret (if any)
|
||||
let hkdf = HkdfSha3::new(
|
||||
|
@ -328,7 +344,7 @@ impl Worker {
|
|||
service_connection.id_recv = cci.service_connection_id;
|
||||
service_connection.id_send =
|
||||
IDSend(resp_data.service_connection_id);
|
||||
self.connections.track(service_connection.into());
|
||||
let _ = self.connections.track(service_connection.into());
|
||||
return;
|
||||
}
|
||||
_ => {}
|
||||
|
|
90
src/lib.rs
90
src/lib.rs
|
@ -20,12 +20,9 @@ pub mod dnssec;
|
|||
pub mod enc;
|
||||
mod inner;
|
||||
|
||||
use ::std::{
|
||||
net::SocketAddr,
|
||||
sync::{Arc, Weak},
|
||||
vec::Vec,
|
||||
};
|
||||
use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle};
|
||||
use ::std::{sync::Arc, vec::Vec};
|
||||
use ::tokio::net::UdpSocket;
|
||||
use auth::ServiceID;
|
||||
|
||||
use crate::{
|
||||
auth::TokenChecker,
|
||||
|
@ -94,9 +91,7 @@ impl Drop for Fenrir {
|
|||
impl Fenrir {
|
||||
/// Create a new Fenrir endpoint
|
||||
pub fn new(config: &Config) -> Result<Self, Error> {
|
||||
let listen_num = config.listen.len();
|
||||
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
||||
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
|
||||
let endpoint = Fenrir {
|
||||
cfg: config.clone(),
|
||||
sockets: SocketList::new(),
|
||||
|
@ -127,23 +122,23 @@ impl Fenrir {
|
|||
/// asyncronous version for Drop
|
||||
fn stop_sync(&mut self) {
|
||||
let _ = self.stop_working.send(true);
|
||||
let mut toempty_sockets = self.sockets.rm_all();
|
||||
let toempty_sockets = self.sockets.rm_all();
|
||||
let task = ::tokio::task::spawn(toempty_sockets.stop_all());
|
||||
let _ = ::futures::executor::block_on(task);
|
||||
let mut old_thread_pool = Vec::new();
|
||||
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
|
||||
old_thread_pool.into_iter().map(|th| th.join());
|
||||
let _ = old_thread_pool.into_iter().map(|th| th.join());
|
||||
self.dnssec = None;
|
||||
}
|
||||
|
||||
/// Stop all workers, listeners
|
||||
pub async fn stop(&mut self) {
|
||||
let _ = self.stop_working.send(true);
|
||||
let mut toempty_sockets = self.sockets.rm_all();
|
||||
let toempty_sockets = self.sockets.rm_all();
|
||||
toempty_sockets.stop_all().await;
|
||||
let mut old_thread_pool = Vec::new();
|
||||
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
|
||||
old_thread_pool.into_iter().map(|th| th.join());
|
||||
let _ = old_thread_pool.into_iter().map(|th| th.join());
|
||||
self.dnssec = None;
|
||||
}
|
||||
/// Add all UDP sockets found in config
|
||||
|
@ -166,7 +161,7 @@ impl Fenrir {
|
|||
self._thread_work.clone(),
|
||||
arc_s.clone(),
|
||||
));
|
||||
self.sockets.add_socket(arc_s, join);
|
||||
self.sockets.add_socket(arc_s, join).await;
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(e);
|
||||
|
@ -218,18 +213,19 @@ impl Fenrir {
|
|||
}
|
||||
}
|
||||
};
|
||||
work_queues[thread_idx].send(Work::Recv(RawUdp {
|
||||
src: UdpClient(sock_sender),
|
||||
dst: sock_receiver,
|
||||
packet,
|
||||
data,
|
||||
}));
|
||||
let _ = work_queues[thread_idx]
|
||||
.send(Work::Recv(RawUdp {
|
||||
src: UdpClient(sock_sender),
|
||||
dst: sock_receiver,
|
||||
packet,
|
||||
data,
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the raw TXT record of a Fenrir domain
|
||||
pub async fn resolv_str(&self, domain: &str) -> Result<String, Error> {
|
||||
pub async fn resolv_txt(&self, domain: &str) -> Result<String, Error> {
|
||||
match &self.dnssec {
|
||||
Some(dnssec) => Ok(dnssec.resolv(domain).await?),
|
||||
None => Err(Error::NotInitialized),
|
||||
|
@ -238,10 +234,60 @@ impl Fenrir {
|
|||
|
||||
/// Get the raw TXT record of a Fenrir domain
|
||||
pub async fn resolv(&self, domain: &str) -> Result<dnssec::Record, Error> {
|
||||
let record_str = self.resolv_str(domain).await?;
|
||||
let record_str = self.resolv_txt(domain).await?;
|
||||
Ok(dnssec::Dnssec::parse_txt_record(&record_str)?)
|
||||
}
|
||||
|
||||
/// Connect to a service
|
||||
pub async fn connect(
|
||||
&self,
|
||||
domain: &str,
|
||||
service: ServiceID,
|
||||
) -> Result<(), Error> {
|
||||
let resolved = self.resolv(domain).await?;
|
||||
|
||||
// find the thread with less connections
|
||||
|
||||
let th_num = self._thread_work.len();
|
||||
let mut conn_count = Vec::<usize>::with_capacity(th_num);
|
||||
let mut wait_res =
|
||||
Vec::<::tokio::sync::oneshot::Receiver<usize>>::with_capacity(
|
||||
th_num,
|
||||
);
|
||||
for th in self._thread_work.iter() {
|
||||
let (send, recv) = ::tokio::sync::oneshot::channel();
|
||||
wait_res.push(recv);
|
||||
let _ = th.send(Work::CountConnections(send)).await;
|
||||
}
|
||||
for ch in wait_res.into_iter() {
|
||||
if let Ok(conn_num) = ch.await {
|
||||
conn_count.push(conn_num);
|
||||
}
|
||||
}
|
||||
if conn_count.len() != th_num {
|
||||
return Err(Error::IO(::std::io::Error::new(
|
||||
::std::io::ErrorKind::NotConnected,
|
||||
"can't connect to a thread",
|
||||
)));
|
||||
}
|
||||
let thread_idx = conn_count
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.cmp(b))
|
||||
.map(|(index, _)| index)
|
||||
.unwrap();
|
||||
|
||||
// and tell that thread to connect somewhere
|
||||
let (send, recv) = ::tokio::sync::oneshot::channel();
|
||||
let _ = self._thread_work[thread_idx]
|
||||
.send(Work::Connect((send, resolved, service)))
|
||||
.await;
|
||||
|
||||
let _conn_res = recv.await;
|
||||
|
||||
todo!()
|
||||
}
|
||||
|
||||
/// Start one working thread for each physical cpu
|
||||
/// threads are pinned to each cpu core.
|
||||
/// Work will be divided and rerouted so that there is no need to lock
|
||||
|
|
Loading…
Reference in New Issue