Compare commits

transport..main

No commits in common. "transport" and "main" have entirely different histories.

14 changed files with 310 additions and 1641 deletions

TODO
View File

@@ -1,6 +1 @@
* Wrapping for everything that wraps (sigh) * Wrapping for everything that wraps (sigh)
* track user connection (add u64 from user)
* API split
* split API in ThreadLocal, ThreadSafe
* split send/recv API in Centralized, Connection
* all are wrappers on ThreadLocal-Centralized
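The split the last four items describe is easier to see as code. A purely hypothetical sketch (none of these traits exist in the diff below) of the ThreadLocal/ThreadSafe and Centralized/Connection axes:

```rust
// Hypothetical sketch of the TODO's API split; all names invented.

/// Handle pinned to one thread (e.g. wrapping Rc, like the new
/// Connection(rc::Weak<Conn>) later in this diff).
pub trait ThreadLocal {}
/// Handle that may cross threads (e.g. wrapping Arc or a channel Sender).
pub trait ThreadSafe: Send + Sync {}

/// Centralized send/recv: one queue for the whole worker,
/// connections addressed by a tracker id.
pub trait Centralized {
    fn send(&mut self, conn: u64, stream: u16, data: Vec<u8>);
    fn recv(&mut self) -> Option<(u64, u16, Vec<u8>)>;
}

/// Connection-oriented send/recv: one handle per connection.
pub trait PerConnection {
    fn send(&mut self, stream: u16, data: Vec<u8>);
    fn recv(&mut self) -> Option<(u16, Vec<u8>)>;
}
```

Under this reading, the last item means the other three combinations stay thin wrappers around the one real ThreadLocal-Centralized implementation.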

View File

@@ -5,11 +5,11 @@
"systems": "systems" "systems": "systems"
}, },
"locked": { "locked": {
"lastModified": 1687171271, "lastModified": 1685518550,
"narHash": "sha256-BJlq+ozK2B1sJDQXS3tzJM5a+oVZmi1q0FlBK/Xqv7M=", "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=",
"owner": "numtide", "owner": "numtide",
"repo": "flake-utils", "repo": "flake-utils",
"rev": "abfb11bd1aec8ced1c9bb9adfe68018230f4fb3c", "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -38,11 +38,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1687555006, "lastModified": 1686921029,
"narHash": "sha256-GD2Kqb/DXQBRJcHqkM2qFZqbVenyO7Co/80JHRMg2U0=", "narHash": "sha256-J1bX9plPCFhTSh6E3TWn9XSxggBh/zDD4xigyaIQBy8=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "33223d479ffde3d05ac16c6dff04ae43cc27e577", "rev": "c7ff1b9b95620ce8728c0d7bd501c458e6da9e04",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -54,11 +54,11 @@
}, },
"nixpkgs-unstable": { "nixpkgs-unstable": {
"locked": { "locked": {
"lastModified": 1687502512, "lastModified": 1686960236,
"narHash": "sha256-dBL/01TayOSZYxtY4cMXuNCBk8UMLoqRZA+94xiFpJA=", "narHash": "sha256-AYCC9rXNLpUWzD9hm+askOfpliLEC9kwAo7ITJc4HIw=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "3ae20aa58a6c0d1ca95c9b11f59a2d12eebc511f", "rev": "04af42f3b31dba0ef742d254456dc4c14eedac86",
"type": "github" "type": "github"
}, },
"original": { "original": {
@@ -98,11 +98,11 @@
"nixpkgs": "nixpkgs_2" "nixpkgs": "nixpkgs_2"
}, },
"locked": { "locked": {
"lastModified": 1687660699, "lastModified": 1687055571,
"narHash": "sha256-crI/CA/OJc778I5qJhwhhl8/PKKzc0D7vvVxOtjfvSo=", "narHash": "sha256-UvLoO6u5n9TzY80BpM4DaacxvyJl7u9mm9CA72d309g=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "b3bd1d49f1ae609c1d68a66bba7a95a9a4256031", "rev": "2de557c780dcb127128ae987fca9d6c2b0d7dc0f",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@@ -3,9 +3,8 @@
use crate::{ use crate::{
auth::{Domain, ServiceID}, auth::{Domain, ServiceID},
connection::{ connection::{
self,
handshake::{self, Error, Handshake}, handshake::{self, Error, Handshake},
Connection, IDRecv, IDSend, Conn, IDRecv, IDSend,
}, },
enc::{ enc::{
self, self,
@@ -19,27 +18,20 @@ use crate::{ use crate::{
use ::tokio::sync::oneshot; use ::tokio::sync::oneshot;
pub(crate) struct Server { pub(crate) struct Server {
pub(crate) id: KeyID, pub id: KeyID,
pub(crate) key: PrivKey, pub key: PrivKey,
pub(crate) domains: Vec<Domain>, pub domains: Vec<Domain>,
} }
pub(crate) type ConnectAnswer = Result<ConnectOk, crate::Error>; pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>;
#[derive(Debug)]
pub(crate) struct ConnectOk {
pub(crate) auth_key_id: KeyID,
pub(crate) auth_id_send: IDSend,
pub(crate) authsrv_conn: connection::AuthSrvConn,
pub(crate) service_conn: Option<connection::ServiceConn>,
}
pub(crate) struct Client { pub(crate) struct Client {
pub(crate) service_id: ServiceID, pub service_id: ServiceID,
pub(crate) service_conn_id: IDRecv, pub service_conn_id: IDRecv,
pub(crate) connection: Connection, pub connection: Conn,
pub(crate) timeout: Option<::tokio::time::Instant>, pub timeout: Option<::tokio::task::JoinHandle<()>>,
pub(crate) answer: oneshot::Sender<ConnectAnswer>, pub answer: oneshot::Sender<ConnectAnswer>,
pub(crate) srv_key_id: KeyID, pub srv_key_id: KeyID,
} }
/// Tracks the keys used by the client and the handshake /// Tracks the keys used by the client and the handshake
@@ -86,7 +78,7 @@ impl ClientList {
pub_key: PubKey, pub_key: PubKey,
service_id: ServiceID, service_id: ServiceID,
service_conn_id: IDRecv, service_conn_id: IDRecv,
connection: Connection, connection: Conn,
answer: oneshot::Sender<ConnectAnswer>, answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID, srv_key_id: KeyID,
) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> { ) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> {
@@ -136,28 +128,26 @@ impl ClientList {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct AuthNeededInfo { pub(crate) struct AuthNeededInfo {
/// Parsed handshake packet /// Parsed handshake packet
pub(crate) handshake: Handshake, pub handshake: Handshake,
/// hkdf generated from the handshake /// hkdf generated from the handshake
pub(crate) hkdf: Hkdf, pub hkdf: Hkdf,
} }
/// Client information needed to fully establish the connection /// Client information needed to fully establish the connection
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct ClientConnectInfo { pub(crate) struct ClientConnectInfo {
/// The service ID that we are connecting to /// The service ID that we are connecting to
pub(crate) service_id: ServiceID, pub service_id: ServiceID,
/// The connection ID reserved for the service /// The connection ID reserved for the service
pub(crate) service_connection_id: IDRecv, pub service_connection_id: IDRecv,
/// Parsed handshake packet /// Parsed handshake packet
pub(crate) handshake: Handshake, pub handshake: Handshake,
/// Old timeout for the handshake completion /// Conn
pub(crate) old_timeout: ::tokio::time::Instant, pub connection: Conn,
/// Connection
pub(crate) connection: Connection,
/// where to wake up the waiting client /// where to wake up the waiting client
pub(crate) answer: oneshot::Sender<ConnectAnswer>, pub answer: oneshot::Sender<ConnectAnswer>,
/// server pub(crate)lic key id that we used on the handshake /// server public key id that we used on the handshake
pub(crate) srv_key_id: KeyID, pub srv_key_id: KeyID,
} }
/// Intermediate actions to be taken while parsing the handshake /// Intermediate actions to be taken while parsing the handshake
#[derive(Debug)] #[derive(Debug)]
@@ -241,7 +231,7 @@ impl Tracker {
pub_key: PubKey, pub_key: PubKey,
service_id: ServiceID, service_id: ServiceID,
service_conn_id: IDRecv, service_conn_id: IDRecv,
connection: Connection, connection: Conn,
answer: oneshot::Sender<ConnectAnswer>, answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID, srv_key_id: KeyID,
) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> { ) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> {
@@ -376,11 +366,13 @@ impl Tracker {
} }
let hshake = let hshake =
self.hshake_cli.remove(resp.client_key_id).unwrap(); self.hshake_cli.remove(resp.client_key_id).unwrap();
if let Some(timeout) = hshake.timeout {
timeout.abort();
}
return Ok(Action::ClientConnect(ClientConnectInfo { return Ok(Action::ClientConnect(ClientConnectInfo {
service_id: hshake.service_id, service_id: hshake.service_id,
service_connection_id: hshake.service_conn_id, service_connection_id: hshake.service_conn_id,
handshake, handshake,
old_timeout: hshake.timeout.unwrap(),
connection: hshake.connection, connection: hshake.connection,
answer: hshake.answer, answer: hshake.answer,
srv_key_id: hshake.srv_key_id, srv_key_id: hshake.srv_key_id,
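This hunk is the visible half of a design change: `Client::timeout` goes from a `::tokio::time::Instant` (a key into the old `Timers` map, removed later in this diff) to a `::tokio::task::JoinHandle<()>`, so completing a handshake just aborts the timeout task instead of removing a map entry. A minimal, self-contained sketch of that pattern (the crate's `Work` enum reduced to one variant; the real code spawns with `spawn_local` inside a LocalSet):

```rust
use tokio::sync::mpsc;

// Reduced stand-in for the crate's Work enum.
enum Work {
    DropHandshake(u64),
}

// Spawn the 10-second timeout as its own task, mirroring the new
// handshake_timeout() later in this diff.
fn spawn_timeout(
    queue: mpsc::UnboundedSender<Work>,
    key_id: u64,
) -> tokio::task::JoinHandle<()> {
    tokio::task::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
        // the receiver may already be gone at shutdown; ignore the error
        let _ = queue.send(Work::DropHandshake(key_id));
    })
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    let timeout = spawn_timeout(tx, 42);
    // On handshake completion the tracker aborts the task,
    // exactly like `timeout.abort()` in the hunk above.
    timeout.abort();
    assert!(rx.try_recv().is_err()); // no DropHandshake was enqueued
}
```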

View File

@@ -5,43 +5,21 @@ pub mod packet;
pub mod socket; pub mod socket;
pub mod stream; pub mod stream;
use ::core::num::Wrapping; use ::std::{rc::Rc, vec::Vec};
use ::std::{
collections::{BTreeMap, HashMap, VecDeque},
vec::Vec,
};
pub use crate::connection::{handshake::Handshake, packet::Packet}; pub use crate::connection::{handshake::Handshake, packet::Packet};
use crate::{ use crate::{
connection::{socket::UdpClient, stream::StreamData},
dnssec, dnssec,
enc::{ enc::{
self,
asym::PubKey, asym::PubKey,
hkdf::Hkdf, hkdf::Hkdf,
sym::{self, CipherRecv, CipherSend}, sym::{self, CipherRecv, CipherSend},
Random, Random,
}, },
inner::{worker, ThreadTracker}, inner::ThreadTracker,
}; };
use ::std::rc;
/// Connection errors
#[derive(::thiserror::Error, Debug, Copy, Clone)]
pub enum Error {
/// Can't decrypt packet
#[error("Decrypt error: {0}")]
Decrypt(#[from] crate::enc::Error),
/// Error in parsing a packet related to the connection
#[error("Chunk parsing: {0}")]
Parse(#[from] stream::Error),
/// No such Connection
#[error("No such connection")]
NoSuchConnection,
/// No such Stream
#[error("No such Stream")]
NoSuchStream,
}
/// Fenrir Connection ID /// Fenrir Connection ID
/// ///
@@ -148,71 +126,24 @@ impl ProtocolVersion {
} }
} }
/// Connection tracking id. Set by the user
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct UserTracker(pub ::core::num::NonZeroU64);
/// Unique tracker of connections
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct LibTracker(Wrapping<u64>);
impl LibTracker {
pub(crate) fn new(start: u16) -> Self {
Self(Wrapping(start as u64))
}
pub(crate) fn advance(&mut self, amount: u16) -> Self {
let old = self.0;
self.0 = self.0 + Wrapping(amount as u64);
LibTracker(old)
}
}
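`LibTracker::advance` is a post-increment: it hands back the current value and steps by the number of worker threads, so each thread allocates trackers in its own residue class and two threads can never produce the same id; `Wrapping` keeps the eventual overflow well-defined. A standalone check of that property, with the type copied from the deleted lines above:

```rust
use core::num::Wrapping;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LibTracker(Wrapping<u64>);

impl LibTracker {
    fn new(start: u16) -> Self {
        Self(Wrapping(start as u64))
    }
    /// Post-increment: return the current value, then step by `amount`.
    fn advance(&mut self, amount: u16) -> Self {
        let old = self.0;
        self.0 += Wrapping(amount as u64);
        LibTracker(old)
    }
}

fn main() {
    // Thread 2 of 4 yields 2, 6, 10, ... and never collides with
    // thread 1's 1, 5, 9, ...
    let mut t = LibTracker::new(2);
    let ids: Vec<u64> = (0..4).map(|_| t.advance(4).0 .0).collect();
    assert_eq!(ids, vec![2, 6, 10, 14]);
}
```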
/// Collection of connection tracking, but user-given and library generated
#[derive(Debug, Copy, Clone)]
pub struct ConnTracker {
/// Optional tracker set by the user
pub user: Option<UserTracker>,
/// library generated tracker. Unique and non-repeating
pub(crate) lib: LibTracker,
}
impl PartialEq for ConnTracker {
fn eq(&self, other: &Self) -> bool {
self.lib == other.lib
}
}
impl Eq for ConnTracker {}
/// Connection to an Authentication Server
#[derive(Debug, Copy, Clone)]
pub struct AuthSrvConn(pub ConnTracker);
/// Connection to a service
#[derive(Debug, Copy, Clone)]
pub struct ServiceConn(pub ConnTracker);
/*
* TODO: only on Thread{Local,Safe}::Connection oriented flows
/// The connection, as seen from a user of libFenrir /// The connection, as seen from a user of libFenrir
#[derive(Debug)] #[derive(Debug)]
pub struct Conn { pub struct Connection(rc::Weak<Conn>);
pub(crate) queue: ::async_channel::Sender<worker::Work>,
pub(crate) tracker: ConnTracker,
}
impl Conn { /// A single connection and its data
/// Queue some data to be sent in this connection #[derive(Debug)]
// TODO: send_and_wait, that waits for recipient ACK pub(crate) struct Conn {
pub async fn send(&mut self, stream: stream::ID, data: Vec<u8>) { /// Receiving Conn ID
use crate::inner::worker::Work; pub id_recv: IDRecv,
let _ = self /// Sending Conn ID
.queue pub id_send: IDSend,
.send(Work::UserSend((self.tracker.lib, stream, data))) /// The main hkdf used for all secrets in this connection
.await; pub hkdf: Hkdf,
/// Cipher for decrypting data
pub cipher_recv: CipherRecv,
/// Cipher for encrypting data
pub cipher_send: CipherSend,
} }
/// Get the library tracking id
pub fn tracker(&self) -> ConnTracker {
self.tracker
}
}
*/
/// Role: track the connection direction /// Role: track the connection direction
/// ///
@@ -229,48 +160,7 @@ pub enum Role {
Client, Client,
} }
#[derive(Debug)] impl Conn {
enum TimerKind {
None,
SendData(::tokio::time::Instant),
Keepalive(::tokio::time::Instant),
}
pub(crate) enum Enqueue {
TimerWait,
Immediate(::tokio::time::Instant),
}
/// A single connection and its data
#[derive(Debug)]
pub(crate) struct Connection {
/// Receiving Conn ID
pub(crate) id_recv: IDRecv,
/// Sending Conn ID
pub(crate) id_send: IDSend,
/// User-managed id to track this connection
/// the user can set this to better track this connection
pub(crate) user_tracker: Option<UserTracker>,
pub(crate) lib_tracker: LibTracker,
/// Sending address
pub(crate) send_addr: UdpClient,
/// The main hkdf used for all secrets in this connection
hkdf: Hkdf,
/// Cipher for decrypting data
pub(crate) cipher_recv: CipherRecv,
/// Cipher for encrypting data
pub(crate) cipher_send: CipherSend,
mtu: usize,
next_timer: TimerKind,
/// send queue for each Stream
send_queue: BTreeMap<stream::ID, stream::SendTracker>,
last_stream_sent: stream::ID,
/// receive queue for each Stream
recv_queue: BTreeMap<stream::ID, stream::Stream>,
streams_ready: VecDeque<stream::ID>,
}
impl Connection {
pub(crate) fn new( pub(crate) fn new(
hkdf: Hkdf, hkdf: Hkdf,
cipher: sym::Kind, cipher: sym::Kind,
@@ -288,194 +178,19 @@ impl Connection {
let cipher_recv = CipherRecv::new(cipher, secret_recv); let cipher_recv = CipherRecv::new(cipher, secret_recv);
let cipher_send = CipherSend::new(cipher, secret_send, rand); let cipher_send = CipherSend::new(cipher, secret_send, rand);
use ::std::net::{IpAddr, Ipv4Addr, SocketAddr};
Self { Self {
id_recv: IDRecv(ID::Handshake), id_recv: IDRecv(ID::Handshake),
id_send: IDSend(ID::Handshake), id_send: IDSend(ID::Handshake),
user_tracker: None,
lib_tracker: LibTracker::new(0),
// will be overwritten
send_addr: UdpClient(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
31337,
)),
hkdf, hkdf,
cipher_recv, cipher_recv,
cipher_send, cipher_send,
mtu: 1200,
next_timer: TimerKind::None,
send_queue: BTreeMap::new(),
last_stream_sent: stream::ID(0),
recv_queue: BTreeMap::new(),
streams_ready: VecDeque::with_capacity(4),
}
}
pub(crate) fn get_data(&mut self) -> Option<Vec<(stream::ID, Vec<u8>)>> {
if self.streams_ready.is_empty() {
return None;
}
let ret_len = self.streams_ready.len();
let mut ret = Vec::with_capacity(ret_len);
while let Some(stream_id) = self.streams_ready.pop_front() {
let stream = match self.recv_queue.get_mut(&stream_id) {
Some(stream) => stream,
None => continue,
};
ret.push((stream_id, stream.get()));
}
Some(ret)
}
pub(crate) fn recv(
&mut self,
mut udp: crate::RawUdp,
) -> Result<StreamData, Error> {
let mut data = &mut udp.data[ID::len()..];
let aad = enc::sym::AAD(&[]);
self.cipher_recv.decrypt(aad, &mut data)?;
let mut bytes_parsed = 0;
let mut chunks = Vec::with_capacity(2);
loop {
let chunk = match stream::Chunk::deserialize(&data[bytes_parsed..])
{
Ok(chunk) => chunk,
Err(e) => {
return Err(e.into());
}
};
bytes_parsed = bytes_parsed + chunk.len();
chunks.push(chunk);
if bytes_parsed == data.len() {
break;
}
}
let mut data_ready = StreamData::NotReady;
for chunk in chunks.into_iter() {
let stream_id = chunk.id;
let stream = match self.recv_queue.get_mut(&stream_id) {
Some(stream) => stream,
None => {
::tracing::debug!("Ignoring chunk for unknown stream::ID");
continue;
}
};
match stream.recv(chunk) {
Ok(status) => {
if !self.streams_ready.contains(&stream_id) {
self.streams_ready.push_back(stream_id);
}
data_ready = data_ready | status;
}
Err(e) => ::tracing::debug!("stream: {:?}: {:?}", stream_id, e),
}
}
Ok(data_ready)
}
pub(crate) fn enqueue(
&mut self,
stream: stream::ID,
data: Vec<u8>,
) -> Result<Enqueue, Error> {
let stream = match self.send_queue.get_mut(&stream) {
None => return Err(Error::NoSuchStream),
Some(stream) => stream,
};
stream.enqueue(data);
let instant;
let ret;
self.next_timer = match self.next_timer {
TimerKind::None | TimerKind::Keepalive(_) => {
instant = ::tokio::time::Instant::now();
ret = Enqueue::Immediate(instant);
TimerKind::SendData(instant)
}
TimerKind::SendData(old_timer) => {
// There already is some data to be sent
// wait for this timer,
// or risk going over max transmission rate
ret = Enqueue::TimerWait;
TimerKind::SendData(old_timer)
}
};
Ok(ret)
}
pub(crate) fn write_pkt<'a>(
&mut self,
raw: &'a mut [u8],
) -> Result<&'a [u8], enc::Error> {
assert!(raw.len() >= self.mtu, "I should have at least 1200 MTU");
if self.send_queue.len() == 0 {
return Err(enc::Error::NotEnoughData(0));
}
raw[..ID::len()]
.copy_from_slice(&self.id_send.0.as_u64().to_le_bytes());
let data_from = ID::len() + self.cipher_send.nonce_len().0;
let data_max_to = raw.len() - self.cipher_send.tag_len().0;
let mut chunk_from = data_from;
let mut available_len = data_max_to - data_from;
use std::ops::Bound::{Excluded, Included};
let last_stream = self.last_stream_sent;
// Loop over our streams, write them to the packet.
// Notes:
// * to avoid starvation, just round-robin them all for now
// * we can enqueue multiple times the same stream
// This is useful especially for Datagram streams
'queueloop: {
for (id, stream) in self
.send_queue
.range_mut((Included(last_stream), Included(stream::ID::max())))
{
if available_len < stream::Chunk::headers_len() + 1 {
break 'queueloop;
}
let bytes =
stream.serialize(*id, &mut raw[chunk_from..data_max_to]);
if bytes == 0 {
break 'queueloop;
}
available_len = available_len - bytes;
chunk_from = chunk_from + bytes;
self.last_stream_sent = *id;
}
if available_len > 0 {
for (id, stream) in self.send_queue.range_mut((
Included(stream::ID::min()),
Excluded(last_stream),
)) {
if available_len < stream::Chunk::headers_len() + 1 {
break 'queueloop;
}
let bytes = stream
.serialize(*id, &mut raw[chunk_from..data_max_to]);
if bytes == 0 {
break 'queueloop;
}
available_len = available_len - bytes;
chunk_from = chunk_from + bytes;
self.last_stream_sent = *id;
}
}
}
if chunk_from == data_from {
return Err(enc::Error::NotEnoughData(0));
}
let data_to = chunk_from + self.cipher_send.tag_len().0;
// encrypt
let aad = sym::AAD(&[]);
match self.cipher_send.encrypt(aad, &mut raw[data_from..data_to]) {
Ok(_) => Ok(&raw[..data_to]),
Err(e) => Err(e),
} }
} }
} }
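The two `range_mut` passes in the deleted `write_pkt` are the whole round-robin scheduler: serve streams from `last_stream_sent` up to `ID::max()`, then wrap around from `ID::min()` to just before `last_stream_sent`. The same iteration order in isolation (plain `range`, since this sketch only reads):

```rust
use std::collections::BTreeMap;
use std::ops::Bound::{Excluded, Included};

fn round_robin_order(queue: &BTreeMap<u16, &str>, last_sent: u16) -> Vec<u16> {
    let mut order = Vec::new();
    // pass 1: from the last stream served to the end of the id space
    for (id, _) in queue.range((Included(last_sent), Included(u16::MAX))) {
        order.push(*id);
    }
    // pass 2: wrap around to the streams before `last_sent`
    for (id, _) in queue.range((Included(u16::MIN), Excluded(last_sent))) {
        order.push(*id);
    }
    order
}

fn main() {
    let queue: BTreeMap<u16, &str> =
        [(1, "a"), (3, "b"), (7, "c")].into_iter().collect();
    // Like the deleted code, pass 1 re-includes last_stream_sent itself.
    assert_eq!(round_robin_order(&queue, 3), vec![3, 7, 1]);
}
```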
pub(crate) struct ConnList { pub(crate) struct ConnList {
thread_id: ThreadTracker, thread_id: ThreadTracker,
connections: Vec<Option<Connection>>, connections: Vec<Option<Rc<Conn>>>,
user_tracker: BTreeMap<LibTracker, usize>,
last_tracked: LibTracker,
/// Bitmap to track which connection ids are used or free /// Bitmap to track which connection ids are used or free
ids_used: Vec<::bitmaps::Bitmap<1024>>, ids_used: Vec<::bitmaps::Bitmap<1024>>,
} }
@@ -491,43 +206,11 @@ impl ConnList {
let mut ret = Self { let mut ret = Self {
thread_id, thread_id,
connections: Vec::with_capacity(INITIAL_CAP), connections: Vec::with_capacity(INITIAL_CAP),
user_tracker: BTreeMap::new(),
last_tracked: LibTracker(Wrapping(0)),
ids_used: vec![bitmap_id], ids_used: vec![bitmap_id],
}; };
ret.connections.resize_with(INITIAL_CAP, || None); ret.connections.resize_with(INITIAL_CAP, || None);
ret ret
} }
pub fn get_id_mut(&mut self, id: ID) -> Result<&mut Connection, Error> {
let conn_id = match id {
ID::ID(conn_id) => conn_id,
ID::Handshake => {
return Err(Error::NoSuchConnection);
}
};
let id_in_thread: usize =
(conn_id.get() / (self.thread_id.total as u64)) as usize;
if let Some(conn) = &mut self.connections[id_in_thread] {
Ok(conn)
} else {
return Err(Error::NoSuchConnection);
}
}
pub fn get_mut(
&mut self,
tracker: LibTracker,
) -> Result<&mut Connection, Error> {
let idx = if let Some(idx) = self.user_tracker.get(&tracker) {
*idx
} else {
return Err(Error::NoSuchConnection);
};
if let Some(conn) = &mut self.connections[idx] {
Ok(conn)
} else {
return Err(Error::NoSuchConnection);
}
}
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
let mut total: usize = 0; let mut total: usize = 0;
for bitmap in self.ids_used.iter() { for bitmap in self.ids_used.iter() {
@@ -537,23 +220,7 @@ impl ConnList {
} }
/// Only *Reserve* a connection, /// Only *Reserve* a connection,
/// without actually tracking it in self.connections /// without actually tracking it in self.connections
pub(crate) fn reserve_and_track<'a>(
&'a mut self,
mut conn: Connection,
) -> (LibTracker, &'a mut Connection) {
let (id_conn, id_in_thread) = self.reserve_first_with_idx();
conn.id_recv = id_conn;
let tracker = self.get_new_tracker(id_in_thread);
conn.lib_tracker = tracker;
self.connections[id_in_thread] = Some(conn);
(tracker, self.connections[id_in_thread].as_mut().unwrap())
}
/// Only *Reserve* a connection,
/// without actually tracking it in self.connections
pub(crate) fn reserve_first(&mut self) -> IDRecv { pub(crate) fn reserve_first(&mut self) -> IDRecv {
self.reserve_first_with_idx().0
}
fn reserve_first_with_idx(&mut self) -> (IDRecv, usize) {
// uhm... bad things are going on here: // uhm... bad things are going on here:
// * id must be initialized, but only because: // * id must be initialized, but only because:
// * rust does not understand that after the `!found` id is always // * rust does not understand that after the `!found` id is always
@@ -591,13 +258,10 @@ impl ConnList {
let actual_id = ((id_in_thread as u64) * (self.thread_id.total as u64)) let actual_id = ((id_in_thread as u64) * (self.thread_id.total as u64))
+ (self.thread_id.id as u64); + (self.thread_id.id as u64);
let new_id = IDRecv(ID::new_u64(actual_id)); let new_id = IDRecv(ID::new_u64(actual_id));
(new_id, id_in_thread) new_id
} }
/// NOTE: does NOT check if the connection has been previously reserved! /// NOTE: does NOT check if the connection has been previously reserved!
pub(crate) fn track( pub(crate) fn track(&mut self, conn: Rc<Conn>) -> Result<(), ()> {
&mut self,
mut conn: Connection,
) -> Result<LibTracker, ()> {
let conn_id = match conn.id_recv { let conn_id = match conn.id_recv {
IDRecv(ID::Handshake) => { IDRecv(ID::Handshake) => {
return Err(()); return Err(());
@@ -606,22 +270,8 @@ impl ConnList {
}; };
let id_in_thread: usize = let id_in_thread: usize =
(conn_id.get() / (self.thread_id.total as u64)) as usize; (conn_id.get() / (self.thread_id.total as u64)) as usize;
let tracker = self.get_new_tracker(id_in_thread);
conn.lib_tracker = tracker;
self.connections[id_in_thread] = Some(conn); self.connections[id_in_thread] = Some(conn);
Ok(tracker) Ok(())
}
fn get_new_tracker(&mut self, id_in_thread: usize) -> LibTracker {
let mut tracker;
loop {
tracker = self.last_tracked.advance(self.thread_id.total);
if self.user_tracker.get(&tracker).is_none() {
// like, never gonna happen, it's 64 bit
let _ = self.user_tracker.insert(tracker, id_in_thread);
break;
}
}
tracker
} }
pub(crate) fn remove(&mut self, id: IDRecv) { pub(crate) fn remove(&mut self, id: IDRecv) {
if let IDRecv(ID::ID(raw_id)) = id { if let IDRecv(ID::ID(raw_id)) = id {
@@ -653,6 +303,7 @@ enum MapEntry {
Present(IDSend), Present(IDSend),
Reserved, Reserved,
} }
use ::std::collections::HashMap;
/// Link the public key of the authentication server to a connection id /// Link the public key of the authentication server to a connection id
/// so that we can reuse that connection to ask for more authentications /// so that we can reuse that connection to ask for more authentications
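`MapEntry` is a small reserve-then-fill scheme: while a handshake to an authentication server is in flight the map holds `Reserved`, so a concurrent request will not open a second connection to the same server; on completion the entry becomes `Present(IDSend)`. A sketch of that usage, with the key simplified to a byte array (the real map is keyed on the server's public key, and this flow is an assumption drawn from the `Reserved` variant plus the comment above):

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, Copy)]
struct IDSend(u64); // stand-in for the crate's IDSend

enum MapEntry {
    Present(IDSend),
    Reserved,
}

struct AuthConnections {
    by_server: HashMap<[u8; 32], MapEntry>, // simplified key type
}

impl AuthConnections {
    /// Some(()) tells the caller to start a handshake; None means a
    /// connection already exists or is being established.
    fn try_reserve(&mut self, server: [u8; 32]) -> Option<()> {
        match self.by_server.get(&server) {
            Some(_) => None, // Present or Reserved: reuse / wait
            None => {
                self.by_server.insert(server, MapEntry::Reserved);
                Some(())
            }
        }
    }
    /// Called once the handshake completes.
    fn fulfill(&mut self, server: [u8; 32], id: IDSend) {
        self.by_server.insert(server, MapEntry::Present(id));
    }
}
```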

View File

@@ -1,12 +1,10 @@
//! Errors while parsing streams //! Errors while parsing streams
/// Crypto errors /// Crypto errors
#[derive(::thiserror::Error, Debug, Copy, Clone)] #[derive(::thiserror::Error, Debug, Copy, Clone)]
pub enum Error { pub enum Error {
/// Error while parsing key material /// Error while parsing key material
#[error("Not enough data for stream chunk: {0}")] #[error("Not enough data for stream chunk: {0}")]
NotEnoughData(usize), NotEnoughData(usize),
/// Sequence outside of the window
#[error("Sequence out of the sliding window")]
OutOfWindow,
} }

View File

@@ -19,7 +19,7 @@ pub enum Kind {
} }
/// Id of the stream /// Id of the stream
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] #[derive(Debug, Copy, Clone)]
pub struct ID(pub u16); pub struct ID(pub u16);
impl ID { impl ID {
@@ -27,14 +27,6 @@ impl ID {
pub const fn len() -> usize { pub const fn len() -> usize {
2 2
} }
/// Minimum possible Stream ID (u16::MIN)
pub const fn min() -> Self {
Self(u16::MIN)
}
/// Maximum possible Stream ID (u16::MAX)
pub const fn max() -> Self {
Self(u16::MAX)
}
} }
/// length of the chunk /// length of the chunk
@@ -48,30 +40,6 @@ impl ChunkLen {
} }
} }
//TODO: make pub?
#[derive(Debug, Copy, Clone)]
pub(crate) struct SequenceStart(pub(crate) Sequence);
impl SequenceStart {
pub(crate) fn plus_u32(&self, other: u32) -> Sequence {
self.0.plus_u32(other)
}
pub(crate) fn offset(&self, seq: Sequence) -> usize {
if self.0 .0 <= seq.0 {
(seq.0 - self.0 .0).0 as usize
} else {
(seq.0 + (Sequence::max().0 - self.0 .0)).0 as usize
}
}
}
// SequenceEnd is INCLUSIVE
#[derive(Debug, Copy, Clone)]
pub(crate) struct SequenceEnd(pub(crate) Sequence);
impl SequenceEnd {
pub(crate) fn plus_u32(&self, other: u32) -> Sequence {
self.0.plus_u32(other)
}
}
/// Sequence number to rebuild the stream correctly /// Sequence number to rebuild the stream correctly
#[derive(Debug, Copy, Clone)] #[derive(Debug, Copy, Clone)]
pub struct Sequence(pub ::core::num::Wrapping<u32>); pub struct Sequence(pub ::core::num::Wrapping<u32>);
@@ -80,52 +48,14 @@ impl Sequence {
const SEQ_NOFLAG: u32 = 0x3FFFFFFF; const SEQ_NOFLAG: u32 = 0x3FFFFFFF;
/// return a new sequence number, starting at random /// return a new sequence number, starting at random
pub fn new(rand: &Random) -> Self { pub fn new(rand: &Random) -> Self {
let mut raw_seq: [u8; 4] = [0; 4]; let seq: u32 = 0;
rand.fill(&mut raw_seq); rand.fill(&mut seq.to_le_bytes());
let seq = u32::from_le_bytes(raw_seq);
Self(::core::num::Wrapping(seq & Self::SEQ_NOFLAG)) Self(::core::num::Wrapping(seq & Self::SEQ_NOFLAG))
} }
/// Length of the serialized field /// Length of the serialized field
pub const fn len() -> usize { pub const fn len() -> usize {
4 4
} }
/// Maximum possible sequence
pub const fn max() -> Self {
Self(::core::num::Wrapping(Self::SEQ_NOFLAG))
}
pub(crate) fn is_between(
&self,
start: SequenceStart,
end: SequenceEnd,
) -> bool {
if start.0 .0 < end.0 .0 {
start.0 .0 <= self.0 && self.0 <= end.0 .0
} else {
start.0 .0 <= self.0 || self.0 <= end.0 .0
}
}
pub(crate) fn remaining_window(&self, end: SequenceEnd) -> u32 {
if self.0 <= end.0 .0 {
(end.0 .0 .0 - self.0 .0) + 1
} else {
end.0 .0 .0 + 1 + (Self::max().0 - self.0).0
}
}
pub(crate) fn plus_u32(self, other: u32) -> Self {
Self(::core::num::Wrapping(
(self.0 .0 + other) & Self::SEQ_NOFLAG,
))
}
}
impl ::core::ops::Add for Sequence {
type Output = Self;
fn add(self, other: Self) -> Self {
Self(::core::num::Wrapping(
(self.0 + other.0).0 & Self::SEQ_NOFLAG,
))
}
} }
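Two notes on this hunk. First, the new side of `Sequence::new` looks buggy: `rand.fill(&mut seq.to_le_bytes())` randomizes a temporary copy of the bytes, so `seq` stays 0 and every stream would start at sequence 0; the deleted side's named buffer was the correct shape (assuming `Random::fill` writes into a byte slice, as both sides use it). A standalone demonstration of the difference:

```rust
// Pretend-random fill into a byte slice, standing in for Random::fill.
fn fill(buf: &mut [u8]) {
    buf.iter_mut().for_each(|b| *b = 0xAB);
}

fn main() {
    // New-side pattern: to_le_bytes() returns a temporary array,
    // so seq itself is never touched.
    let seq: u32 = 0;
    fill(&mut seq.to_le_bytes());
    assert_eq!(seq, 0);

    // Deleted-side pattern: fill a named buffer, then convert.
    let mut raw_seq = [0u8; 4];
    fill(&mut raw_seq);
    let seq = u32::from_le_bytes(raw_seq);
    assert_ne!(seq, 0);
}
```

Second, the deleted helpers (`is_between`, `remaining_window`, `plus_u32`) all work modulo `SEQ_NOFLAG + 1`: the top two bits of the wire value are flags, leaving a 30-bit sequence space, and a receive window may straddle the wrap point, where `start <= x || x <= end` replaces the usual `&&`. A self-contained check of that wrap case:

```rust
const SEQ_NOFLAG: u32 = 0x3FFF_FFFF; // 30-bit sequence space

fn plus_u32(seq: u32, other: u32) -> u32 {
    seq.wrapping_add(other) & SEQ_NOFLAG
}

/// Window membership with wraparound, mirroring the deleted is_between().
fn is_between(x: u32, start: u32, end: u32) -> bool {
    if start < end {
        start <= x && x <= end
    } else {
        start <= x || x <= end
    }
}

fn main() {
    // A window of 0x200 sequences that straddles the wrap point.
    let start = SEQ_NOFLAG - 0xFF; // 0x3FFF_FF00
    let end = plus_u32(start, 0x1FF); // wraps to 0x0000_00FF
    assert!(is_between(SEQ_NOFLAG, start, end)); // before the wrap
    assert!(is_between(0x80, start, end)); // after the wrap
    assert!(!is_between(0x1000, start, end)); // outside the window
}
```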
/// Chunk of data representing a stream /// Chunk of data representing a stream
@@ -149,10 +79,6 @@ impl<'a> Chunk<'a> {
const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F; const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F;
const FLAG_START_BITMASK: u8 = 0x80; const FLAG_START_BITMASK: u8 = 0x80;
const FLAG_END_BITMASK: u8 = 0x40; const FLAG_END_BITMASK: u8 = 0x40;
/// Return the length of the header of a Chunk
pub const fn headers_len() -> usize {
ID::len() + ChunkLen::len() + Sequence::len()
}
/// Returns the total length of the chunk, including headers /// Returns the total length of the chunk, including headers
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
ID::len() + ChunkLen::len() + Sequence::len() + self.data.len() ID::len() + ChunkLen::len() + Sequence::len() + self.data.len()
@@ -238,26 +164,6 @@ impl Tracker {
} }
} }
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum StreamData {
/// not enough data to return something to the user
NotReady = 0,
/// we can return something to the user
Ready,
}
impl ::core::ops::BitOr for StreamData {
type Output = Self;
// Required method
fn bitor(self, other: Self) -> Self::Output {
if self == StreamData::Ready || other == StreamData::Ready {
StreamData::Ready
} else {
StreamData::NotReady
}
}
}
/// Actual stream-tracking structure /// Actual stream-tracking structure
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct Stream { pub(crate) struct Stream {
@@ -274,79 +180,4 @@ impl Stream {
data: Tracker::new(kind, rand), data: Tracker::new(kind, rand),
} }
} }
pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<StreamData, Error> {
match &mut self.data {
Tracker::ROB(tracker) => tracker.recv(chunk),
}
}
pub(crate) fn get(&mut self) -> Vec<u8> {
match &mut self.data {
Tracker::ROB(tracker) => tracker.get(),
}
}
}
/// Track what has been sent and what has been ACK'd from a stream
#[derive(Debug)]
pub(crate) struct SendTracker {
queue: Vec<Vec<u8>>,
sent: Vec<usize>,
ackd: Vec<usize>,
chunk_started: bool,
is_datagram: bool,
next_sequence: Sequence,
}
impl SendTracker {
pub(crate) fn new(rand: &Random) -> Self {
Self {
queue: Vec::with_capacity(4),
sent: Vec::with_capacity(4),
ackd: Vec::with_capacity(4),
chunk_started: false,
is_datagram: false,
next_sequence: Sequence::new(rand),
}
}
/// Enqueue user data to be sent
pub(crate) fn enqueue(&mut self, data: Vec<u8>) {
self.queue.push(data);
self.sent.push(0);
self.ackd.push(0);
}
/// Write the user data to the buffer and mark it as sent
pub(crate) fn get(&mut self, out: &mut [u8]) -> usize {
let data = match self.queue.get(0) {
Some(data) => data,
None => return 0,
};
let len = ::std::cmp::min(out.len(), data.len());
out[..len].copy_from_slice(&data[self.sent[0]..len]);
self.sent[0] = self.sent[0] + len;
len
}
/// Mark the sent data as successfully received from the receiver
pub(crate) fn ack(&mut self, size: usize) {
todo!()
}
pub(crate) fn serialize(&mut self, id: ID, raw: &mut [u8]) -> usize {
let max_data_len = raw.len() - Chunk::headers_len();
let data_len = ::std::cmp::min(max_data_len, self.queue[0].len());
let flag_start = !self.chunk_started;
let flag_end = self.is_datagram && data_len == self.queue[0].len();
let chunk = Chunk {
id,
flag_start,
flag_end,
sequence: self.next_sequence,
data: &self.queue[0][..data_len],
};
self.next_sequence = Sequence(
self.next_sequence.0 + ::core::num::Wrapping(data_len as u32),
);
if chunk.flag_end {
self.chunk_started = false;
}
chunk.serialize(raw);
data_len
}
} }
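One more deleted-side wrinkle: in `SendTracker::get`, `out[..len].copy_from_slice(&data[self.sent[0]..len])` copies slices of different lengths whenever `self.sent[0] > 0`, which panics at runtime; the source range must be shifted by the bytes already sent. A corrected standalone version of that resume-from-offset copy (assuming `sent[0]` tracks how much of `queue[0]` was already written):

```rust
/// Copy the unsent tail of `data` into `out`, returning bytes copied.
/// `sent` is how much of `data` was already written out earlier.
fn resume_copy(out: &mut [u8], data: &[u8], sent: &mut usize) -> usize {
    let remaining = &data[*sent..];
    let len = out.len().min(remaining.len());
    out[..len].copy_from_slice(&remaining[..len]); // equal lengths by construction
    *sent += len;
    len
}

fn main() {
    let data = b"0123456789";
    let mut sent = 0;
    let mut buf = [0u8; 4];
    assert_eq!(resume_copy(&mut buf, data, &mut sent), 4);
    assert_eq!(&buf, b"0123");
    assert_eq!(resume_copy(&mut buf, data, &mut sent), 4);
    assert_eq!(&buf, b"4567"); // resumes where the last call stopped
}
```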

View File

@@ -0,0 +1,29 @@
//! Implementation of the Reliable, Ordered, Bytestream transmission model
//! AKA: TCP-like
use crate::{
connection::stream::{Chunk, Error, Sequence},
enc::Random,
};
/// Reliable, Ordered, Bytestream stream tracker
/// AKA: TCP-like
#[derive(Debug, Clone)]
pub(crate) struct ReliableOrderedBytestream {
window_start: Sequence,
window_len: usize,
data: Vec<u8>,
}
impl ReliableOrderedBytestream {
pub(crate) fn new(rand: &Random) -> Self {
Self {
window_start: Sequence::new(rand),
window_len: 1048576, // 1MB. should be enough for anybody. (lol)
data: Vec::new(),
}
}
pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> {
todo!()
}
}

View File

@@ -1,210 +0,0 @@
//! Implementation of the Reliable, Ordered, Bytestream transmission model
//! AKA: TCP-like
use crate::{
connection::stream::{
Chunk, Error, Sequence, SequenceEnd, SequenceStart, StreamData,
},
enc::Random,
};
#[cfg(test)]
mod tests;
/// Reliable, Ordered, Bytestream stream tracker
/// AKA: TCP-like
#[derive(Debug, Clone)]
pub(crate) struct ReliableOrderedBytestream {
pub(crate) window_start: SequenceStart,
window_end: SequenceEnd,
pivot: u32,
data: Vec<u8>,
missing: Vec<(Sequence, Sequence)>,
}
impl ReliableOrderedBytestream {
pub(crate) fn new(rand: &Random) -> Self {
let window_len = 1048576; // 1MB. should be enough for anybody. (lol)
let window_start = SequenceStart(Sequence::new(rand));
let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1));
let mut data = Vec::with_capacity(window_len as usize);
data.resize(data.capacity(), 0);
Self {
window_start,
window_end,
pivot: window_len,
data,
missing: [(window_start.0, window_end.0)].to_vec(),
}
}
pub(crate) fn with_window_size(rand: &Random, size: u32) -> Self {
assert!(
size < Sequence::max().0 .0,
"Max window size is {}",
Sequence::max().0 .0
);
let window_len = size;
let window_start = SequenceStart(Sequence::new(rand));
let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1));
let mut data = Vec::with_capacity(window_len as usize);
data.resize(data.capacity(), 0);
Self {
window_start,
window_end,
pivot: window_len,
data,
missing: [(window_start.0, window_end.0)].to_vec(),
}
}
pub(crate) fn window_size(&self) -> u32 {
self.data.len() as u32
}
pub(crate) fn get(&mut self) -> Vec<u8> {
if self.missing.len() == 0 {
let (first, second) = self.data.split_at(self.pivot as usize);
let mut ret = Vec::with_capacity(self.data.len());
ret.extend_from_slice(first);
ret.extend_from_slice(second);
self.window_start =
SequenceStart(self.window_start.plus_u32(ret.len() as u32));
self.window_end =
SequenceEnd(self.window_end.plus_u32(ret.len() as u32));
self.data.clear();
return ret;
}
let data_len = self.window_start.offset(self.missing[0].0);
let last_missing_idx = self.missing.len() - 1;
let mut last_missing = &mut self.missing[last_missing_idx];
last_missing.1 = last_missing.1.plus_u32(data_len as u32);
self.window_start =
SequenceStart(self.window_start.plus_u32(data_len as u32));
self.window_end =
SequenceEnd(self.window_end.plus_u32(data_len as u32));
let mut ret = Vec::with_capacity(data_len);
let (first, second) = self.data[..].split_at(self.pivot as usize);
let first_len = ::core::cmp::min(data_len, first.len());
let second_len = data_len - first_len;
ret.extend_from_slice(&first[..first_len]);
ret.extend_from_slice(&second[..second_len]);
self.pivot =
((self.pivot as usize + data_len) % self.data.len()) as u32;
ret
}
pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<StreamData, Error> {
if !chunk
.sequence
.is_between(self.window_start, self.window_end)
{
return Err(Error::OutOfWindow);
}
// make sure we consider only the bytes inside the sliding window
let maxlen = ::std::cmp::min(
chunk.sequence.remaining_window(self.window_end) as usize,
chunk.data.len(),
);
if maxlen == 0 {
// empty window or empty chunk, but we don't care
return Err(Error::OutOfWindow);
}
// translate Sequences to offsets in self.data
let data = &chunk.data[..maxlen];
let offset = self.window_start.offset(chunk.sequence);
let offset_end = offset + chunk.data.len() - 1;
// Find the chunks we are missing that we can copy,
// and fix the missing tracker
let mut copy_ranges = Vec::new();
let mut to_delete = Vec::new();
let mut to_add = Vec::new();
// note: the ranges are (INCLUSIVE, INCLUSIVE)
for (idx, el) in self.missing.iter_mut().enumerate() {
let missing_from = self.window_start.offset(el.0);
if missing_from > offset_end {
break;
}
let missing_to = self.window_start.offset(el.1);
if missing_to < offset {
continue;
}
if missing_from >= offset && missing_from <= offset_end {
if missing_to <= offset_end {
// [.....chunk.....]
// [..missing..]
to_delete.push(idx);
copy_ranges.push((missing_from, missing_to));
} else {
// [....chunk....]
// [...missing...]
copy_ranges.push((missing_from, offset_end));
el.0 =
el.0.plus_u32(((offset_end - missing_from) + 1) as u32);
}
} else if missing_from < offset {
if missing_to > offset_end {
// [..chunk..]
// [....missing....]
to_add.push((
el.0.plus_u32(((offset_end - missing_from) + 1) as u32),
el.1,
));
el.1 = el.0.plus_u32(((offset - missing_from) - 1) as u32);
copy_ranges.push((offset, offset_end));
} else if offset <= missing_to {
// [....chunk....]
// [...missing...]
copy_ranges.push((offset, (missing_to - 0)));
el.1 =
el.0.plus_u32(((offset_end - missing_from) - 1) as u32);
}
}
}
{
let mut deleted = 0;
for idx in to_delete.into_iter() {
self.missing.remove(idx + deleted);
deleted = deleted + 1;
}
}
self.missing.append(&mut to_add);
self.missing
.sort_by(|(from_a, _), (from_b, _)| from_a.0 .0.cmp(&from_b.0 .0));
// copy only the missing data
let (first, second) = self.data[..].split_at_mut(self.pivot as usize);
for (from, to) in copy_ranges.into_iter() {
let to = to + 1;
if from <= first.len() {
let first_from = from;
let first_to = ::core::cmp::min(first.len(), to);
let data_first_from = from - offset;
let data_first_to = first_to - offset;
first[first_from..first_to]
.copy_from_slice(&data[data_first_from..data_first_to]);
let second_to = to - first_to;
let data_second_to = data_first_to + second_to;
second[..second_to]
.copy_from_slice(&data[data_first_to..data_second_to]);
} else {
let second_from = from - first.len();
let second_to = to - first.len();
let data_from = from - offset;
let data_to = to - offset;
second[second_from..second_to]
.copy_from_slice(&data[data_from..data_to]);
}
}
if self.missing.len() == 0
|| self.window_start.offset(self.missing[0].0) == 0
{
Ok(StreamData::Ready)
} else {
Ok(StreamData::NotReady)
}
}
}
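The core of the deleted `recv` is bookkeeping on `missing`: a sorted list of inclusive `(from, to)` holes in the window, where each incoming chunk either leaves a hole alone, consumes it, trims one end, or splits it in two — the four commented cases above. The same bookkeeping in isolation, on plain offsets instead of wrapping sequences:

```rust
/// Subtract the inclusive range [from, to] from a sorted list of
/// inclusive holes, returning what is still missing. Mirrors the
/// four overlap cases in the deleted recv().
fn fill_range(
    missing: &[(usize, usize)],
    from: usize,
    to: usize,
) -> Vec<(usize, usize)> {
    let mut out = Vec::new();
    for &(a, b) in missing {
        if b < from || a > to {
            out.push((a, b)); // no overlap: hole untouched
        } else {
            if a < from {
                out.push((a, from - 1)); // keep the part left of the chunk
            }
            if b > to {
                out.push((to + 1, b)); // keep the part right of the chunk
            }
            // a hole fully covered by the chunk simply disappears
        }
    }
    out
}

fn main() {
    // A chunk in the middle of one big hole splits it
    // (the "[..chunk..] / [....missing....]" case).
    assert_eq!(fill_range(&[(0, 99)], 40, 59), vec![(0, 39), (60, 99)]);
    // A chunk at the head of a hole only trims it.
    assert_eq!(
        fill_range(&[(0, 39), (60, 99)], 0, 9),
        vec![(10, 39), (60, 99)]
    );
}
```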

View File

@@ -1,249 +0,0 @@
use crate::{
connection::stream::{self, rob::*, Chunk},
enc::Random,
};
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_sequential() {
let rand = Random::new();
let mut rob = ReliableOrderedBytestream::with_window_size(&rand, 1048576);
let mut data = Vec::with_capacity(1024);
data.resize(data.capacity(), 0);
rand.fill(&mut data[..]);
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..512],
};
let got = rob.get();
assert!(&got[..] == &[], "rob: got data?");
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..512] == &got[..],
"ROB1: DIFF: {:?} {:?}",
&data[..512].len(),
&got[..].len()
);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: true,
sequence: start.plus_u32(512),
data: &data[512..],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[512..] == &got[..],
"ROB2: DIFF: {:?} {:?}",
&data[512..].len(),
&got[..].len()
);
}
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_retransmit() {
let rand = Random::new();
let max_window: usize = 100;
let mut rob =
ReliableOrderedBytestream::with_window_size(&rand, max_window as u32);
let mut data = Vec::with_capacity(120);
data.resize(data.capacity(), 0);
for i in 0..data.len() {
data[i] = i as u8;
}
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..40],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..60],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(40),
data: &data[40..60],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(80),
data: &data[80..],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..90],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(max_window as u32),
data: &data[max_window..],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: true,
sequence: start.plus_u32(90),
data: &data[90..max_window],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..max_window] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[..max_window],
&got[..],
);
}
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_rolling() {
let rand = Random::new();
let max_window: usize = 100;
let mut rob =
ReliableOrderedBytestream::with_window_size(&rand, max_window as u32);
let mut data = Vec::with_capacity(120);
data.resize(data.capacity(), 0);
for i in 0..data.len() {
data[i] = i as u8;
}
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..40],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..100],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..40] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[..40],
&got[..],
);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(40),
data: &data[40..],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[40..] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[40..],
&got[..],
);
}
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_rolling_second_case() {
let rand = Random::new();
let max_window: usize = 100;
let mut rob =
ReliableOrderedBytestream::with_window_size(&rand, max_window as u32);
let mut data = Vec::with_capacity(120);
data.resize(data.capacity(), 0);
for i in 0..data.len() {
data[i] = i as u8;
}
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..40],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..100],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..40] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[..40],
&got[..],
);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(40),
data: &data[40..100],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(100),
data: &data[100..],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[40..] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[40..],
&got[..],
);
}

View File

@@ -241,14 +241,6 @@ impl CipherSend {
pub fn kind(&self) -> Kind { pub fn kind(&self) -> Kind {
self.cipher.kind() self.cipher.kind()
} }
/// Get the length of the nonce for this cipher
pub fn nonce_len(&self) -> NonceLen {
self.cipher.nonce_len()
}
/// Get the length of the authentication tag for this cipher
pub fn tag_len(&self) -> TagLen {
self.cipher.tag_len()
}
} }
/// XChaCha20Poly1305 cipher /// XChaCha20Poly1305 cipher

View File

@@ -4,10 +4,6 @@
pub(crate) mod worker; pub(crate) mod worker;
use crate::inner::worker::Work;
use ::std::{collections::BTreeMap, vec::Vec};
use ::tokio::time::Instant;
/// Track the total number of threads and our index /// Track the total number of threads and our index
/// 65K cpus should be enough for anybody /// 65K cpus should be enough for anybody
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
@@ -16,109 +12,3 @@ pub(crate) struct ThreadTracker {
/// Note: starts from 1 /// Note: starts from 1
pub id: u16, pub id: u16,
} }
pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration =
if cfg!(linux) || cfg!(macos) {
::std::time::Duration::from_millis(1)
} else {
// windows
::std::time::Duration::from_millis(16)
};
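One note on this deleted static: `linux` and `macos` are not cfg predicates Rust defines, so both `cfg!()` calls evaluate to false and every platform silently got the 16 ms Windows default. The conventional spelling would have been:

```rust
// Corrected predicates (sketch): cfg! keys on target_os,
// not on bare platform names.
pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration =
    if cfg!(target_os = "linux") || cfg!(target_os = "macos") {
        ::std::time::Duration::from_millis(1)
    } else {
        // windows
        ::std::time::Duration::from_millis(16)
    };
```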
pub(crate) async fn set_minimum_sleep_resolution() {
let nanosleep = ::std::time::Duration::from_nanos(1);
let mut tests: usize = 3;
while tests > 0 {
let pre_sleep = ::std::time::Instant::now();
::tokio::time::sleep(nanosleep).await;
let post_sleep = ::std::time::Instant::now();
let slept_for = post_sleep - pre_sleep;
#[allow(unsafe_code)]
unsafe {
if slept_for < SLEEP_RESOLUTION {
SLEEP_RESOLUTION = slept_for;
}
}
tests = tests - 1;
}
}
/// Sleeping has a coarser resolution than we would like for packet pacing.
/// So we sleep for however long we need, then chunk up all the work here:
/// we will end up chunking the work in SLEEP_RESOLUTION slices, then busy-wait
/// for more precise timing
pub(crate) struct Timers {
times: BTreeMap<Instant, Work>,
}
impl Timers {
pub(crate) fn new() -> Self {
Self {
times: BTreeMap::new(),
}
}
pub(crate) fn get_next(&self) -> ::tokio::time::Sleep {
match self.times.keys().next() {
Some(entry) => ::tokio::time::sleep_until((*entry).into()),
None => {
::tokio::time::sleep(::std::time::Duration::from_secs(3600))
}
}
}
pub(crate) fn add(
&mut self,
duration: ::tokio::time::Duration,
work: Work,
) -> ::tokio::time::Instant {
// the returned time is the key in the map.
// Make sure it is unique.
//
// We can be pretty sure we won't do a lot of stuff
// in a single nanosecond, so if we hit a time that is already present
// just add a nanosecond and retry
let mut time = ::tokio::time::Instant::now() + duration;
let mut work = work;
loop {
if let Some(old_val) = self.times.insert(time, work) {
work = self.times.insert(time, old_val).unwrap();
time = time + ::std::time::Duration::from_nanos(1);
} else {
break;
}
}
time
}
pub(crate) fn remove(&mut self, time: ::tokio::time::Instant) {
let _ = self.times.remove(&time);
}
/// Get all the work from now up until now + SLEEP_RESOLUTION
pub(crate) fn get_work(&mut self) -> Vec<Work> {
let now: ::tokio::time::Instant = ::std::time::Instant::now().into();
let mut ret = Vec::with_capacity(4);
let mut count_rm = 0;
#[allow(unsafe_code)]
let next_instant = unsafe { now + SLEEP_RESOLUTION };
let mut iter = self.times.iter_mut().peekable();
loop {
match iter.peek() {
None => break,
Some(next) => {
if *next.0 > next_instant {
break;
}
}
}
let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0));
let mut entry = iter.next().unwrap();
::core::mem::swap(&mut work, &mut entry.1);
ret.push(work);
count_rm = count_rm + 1;
}
while count_rm > 0 {
self.times.pop_first();
count_rm = count_rm - 1;
}
ret
}
}
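How the deleted `Timers` was driven by the old `work_loop` further down: `get_next()` produces a `Sleep` for the earliest deadline to park on inside `select!`, and after waking, `get_work()` drains everything due within one `SLEEP_RESOLUTION` slice. A reduced sketch of that loop, with `Work` shrunk to a string label and the drain simplified:

```rust
use std::collections::BTreeMap;
use tokio::time::{Duration, Instant};

struct Timers {
    times: BTreeMap<Instant, &'static str>, // Work reduced to a label
}

impl Timers {
    fn add(&mut self, after: Duration, work: &'static str) -> Instant {
        let mut t = Instant::now() + after;
        // nanosecond bump on collision, as in the deleted add()
        while self.times.contains_key(&t) {
            t += Duration::from_nanos(1);
        }
        self.times.insert(t, work);
        t
    }
    fn get_next(&self) -> tokio::time::Sleep {
        match self.times.keys().next() {
            Some(t) => tokio::time::sleep_until(*t),
            None => tokio::time::sleep(Duration::from_secs(3600)),
        }
    }
    fn get_work(&mut self) -> Vec<&'static str> {
        let now = Instant::now();
        let due: Vec<Instant> =
            self.times.range(..=now).map(|(t, _)| *t).collect();
        due.iter().filter_map(|t| self.times.remove(t)).collect()
    }
}

#[tokio::main]
async fn main() {
    let mut timers = Timers { times: BTreeMap::new() };
    timers.add(Duration::from_millis(5), "drop-handshake");
    let next = timers.get_next();
    tokio::pin!(next); // the old work_loop pins it before select!
    (&mut next).await;
    assert_eq!(timers.get_work(), vec!["drop-handshake"]);
}
```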

View File

@@ -11,19 +11,17 @@ use crate::{
}, },
packet::{self, Packet}, packet::{self, Packet},
socket::{UdpClient, UdpServer}, socket::{UdpClient, UdpServer},
stream, AuthSrvConn, ConnList, ConnTracker, Connection, IDSend, Conn, ConnList, IDSend,
LibTracker, ServiceConn,
}, },
dnssec, dnssec,
enc::{ enc::{
self,
asym::{self, KeyID, PrivKey, PubKey}, asym::{self, KeyID, PrivKey, PubKey},
hkdf::{self, Hkdf}, hkdf::{self, Hkdf},
sym, Random, Secret, sym, Random, Secret,
}, },
inner::ThreadTracker, inner::ThreadTracker,
}; };
use ::std::{collections::VecDeque, sync::Arc, vec::Vec}; use ::std::{sync::Arc, vec::Vec};
/// This worker must be cpu-pinned /// This worker must be cpu-pinned
use ::tokio::{ use ::tokio::{
net::UdpSocket, net::UdpSocket,
@@ -46,25 +44,6 @@ pub(crate) struct ConnectInfo {
// TODO: UserID, Token information // TODO: UserID, Token information
} }
/// return to the user the data received from a connection
#[derive(Debug, Clone)]
pub struct ConnData {
/// Connection tracking information
pub conn: ConnTracker,
/// received data, for each stream
pub data: Vec<(stream::ID, Vec<u8>)>,
}
/// Connection event. Mostly used to give the data to the user
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Event {
/// Work loop has exited. nothing more to do
End,
/// Data from a connection
Data(ConnData),
}
pub(crate) enum Work { pub(crate) enum Work {
/// ask the thread to report to the main thread the total number of /// ask the thread to report to the main thread the total number of
/// connections present /// connections present
@@ -72,8 +51,6 @@ pub(crate) enum Work {
Connect(ConnectInfo), Connect(ConnectInfo),
DropHandshake(KeyID), DropHandshake(KeyID),
Recv(RawUdp), Recv(RawUdp),
UserSend((LibTracker, stream::ID, Vec<u8>)),
SendData((LibTracker, ::tokio::time::Instant)),
} }
/// Actual worker implementation. /// Actual worker implementation.
@@ -87,15 +64,11 @@ pub struct Worker {
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<Arc<UdpSocket>>, sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>, queue: ::async_channel::Receiver<Work>,
queue_sender: ::async_channel::Sender<Work>,
queue_timeouts_recv: mpsc::UnboundedReceiver<Work>, queue_timeouts_recv: mpsc::UnboundedReceiver<Work>,
queue_timeouts_send: mpsc::UnboundedSender<Work>, queue_timeouts_send: mpsc::UnboundedSender<Work>,
thread_channels: Vec<::async_channel::Sender<Work>>, thread_channels: Vec<::async_channel::Sender<Work>>,
connections: ConnList, connections: ConnList,
// connections untracked by the user (the user still needs to get(..) them)
untracked_connections: VecDeque<LibTracker>,
handshakes: handshake::Tracker, handshakes: handshake::Tracker,
work_timers: super::Timers,
} }
#[allow(unsafe_code)] #[allow(unsafe_code)]
@@ -109,7 +82,6 @@ impl Worker {
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<Arc<UdpSocket>>, sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>, queue: ::async_channel::Receiver<Work>,
queue_sender: ::async_channel::Sender<Work>,
) -> ::std::io::Result<Self> { ) -> ::std::io::Result<Self> {
let (queue_timeouts_send, queue_timeouts_recv) = let (queue_timeouts_send, queue_timeouts_recv) =
mpsc::unbounded_channel(); mpsc::unbounded_channel();
@@ -146,93 +118,17 @@ impl Worker {
token_check, token_check,
sockets, sockets,
queue, queue,
queue_sender,
queue_timeouts_recv, queue_timeouts_recv,
queue_timeouts_send, queue_timeouts_send,
thread_channels: Vec::new(), thread_channels: Vec::new(),
connections: ConnList::new(thread_id), connections: ConnList::new(thread_id),
untracked_connections: VecDeque::with_capacity(8),
handshakes, handshakes,
work_timers: super::Timers::new(),
}) })
} }
/// return a handle to the worker that you can use to send data
/// The handle will enqueue work in the main worker and is thread-local safe
///
/// While this does not require `&mut` on the `Worker`, everything
/// will be put in the work queue,
/// so you might get less immediate results in a few cases
pub fn handle(&self) -> Handle {
Handle {
queue: self.queue_sender.clone(),
}
}
/// change the UserTracker in the connection
///
/// This is `unsafe` because you will be responsible for manually updating
/// any copy of the `ConnTracker` you might have cloned around
#[allow(unsafe_code)]
pub unsafe fn set_connection_tracker(
&mut self,
tracker: ConnTracker,
new_id: connection::UserTracker,
) -> Result<ConnTracker, crate::Error> {
let conn = self.connections.get_mut(tracker.lib)?;
conn.user_tracker = Some(new_id);
Ok(ConnTracker {
lib: tracker.lib,
user: Some(new_id),
})
}
/// Enqueue data to send
pub fn send(
&mut self,
tracker: LibTracker,
stream: stream::ID,
data: Vec<u8>,
) -> Result<(), crate::Error> {
let conn = self.connections.get_mut(tracker)?;
conn.enqueue(stream, data)?;
Ok(())
}
/// Returns new connections, if any
///
/// You can provide an optional tracker, different from the library tracker.
///
/// Differently from the library tracker, you can change this later on,
/// but you will be responsible to change it on every `ConnTracker`
/// you might have cloned elsewhere
pub fn try_get_connection(
&mut self,
tracker: Option<connection::UserTracker>,
) -> Option<ConnTracker> {
let ret_tracker = ConnTracker {
lib: self.untracked_connections.pop_front()?,
user: None,
};
match tracker {
Some(tracker) => {
#[allow(unsafe_code)]
match unsafe {
self.set_connection_tracker(ret_tracker, tracker)
} {
Ok(tracker) => Some(tracker),
Err(_) => {
// we had a connection, but it expired before the user
// remembered to get it. Just remove it from the queue.
None
}
}
}
None => Some(ret_tracker),
}
}
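The removed `Handle` is the piece the TODO's ThreadSafe side hints at: callers never touch the `Worker` directly, they clone a channel sender and enqueue `Work` for the loop to drain. A reduced, self-contained model of that pattern (the real `Work::UserSend` carries a `LibTracker` and `stream::ID` instead of the bare integers used here):

```rust
/// Reduced model of the removed Handle / Work-queue pattern.
enum Work {
    UserSend((u64 /* tracker */, u16 /* stream */, Vec<u8>)),
}

#[derive(Clone)]
struct Handle {
    queue: async_channel::Sender<Work>,
}

impl Handle {
    /// Mirrors the removed Conn::send: fire-and-forget enqueue.
    async fn send(&self, tracker: u64, stream: u16, data: Vec<u8>) {
        let _ = self
            .queue
            .send(Work::UserSend((tracker, stream, data)))
            .await;
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = async_channel::unbounded();
    let handle = Handle { queue: tx };
    handle.send(7, 1, b"hello".to_vec()).await;
    // The worker side: this is what the work loop's queue.recv() sees.
    if let Ok(Work::UserSend((tracker, stream, data))) = rx.recv().await {
        assert_eq!((tracker, stream, data.len()), (7, 1, 5));
    }
}
```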
/// Continuously loop and process work as needed /// Continuously loop and process work as needed
pub async fn work_loop(&mut self) -> Result<Event, crate::Error> { pub async fn work_loop(&mut self) {
'mainloop: loop { 'mainloop: loop {
let next_timer = self.work_timers.get_next();
::tokio::pin!(next_timer);
let work = ::tokio::select! { let work = ::tokio::select! {
tell_stopped = self.stop_working.recv() => { tell_stopped = self.stop_working.recv() => {
if let Ok(stop_ch) = tell_stopped { if let Ok(stop_ch) = tell_stopped {
@@ -241,13 +137,6 @@ impl Worker {
} }
break; break;
} }
() = &mut next_timer => {
let work_list = self.work_timers.get_work();
for w in work_list.into_iter() {
let _ = self.queue_sender.send(w).await;
}
continue 'mainloop;
}
maybe_timeout = self.queue.recv() => { maybe_timeout = self.queue.recv() => {
match maybe_timeout { match maybe_timeout {
Ok(work) => work, Ok(work) => work,
@@ -404,14 +293,12 @@ impl Worker {
// are PubKey::Exchange // are PubKey::Exchange
unreachable!() unreachable!()
} }
let mut conn = Connection::new( let mut conn = Conn::new(
hkdf, hkdf,
cipher_selected, cipher_selected,
connection::Role::Client, connection::Role::Client,
&self.rand, &self.rand,
); );
let dest = UdpClient(addr.as_sockaddr().unwrap());
conn.send_addr = dest;
let auth_recv_id = self.connections.reserve_first(); let auth_recv_id = self.connections.reserve_first();
let service_conn_id = self.connections.reserve_first(); let service_conn_id = self.connections.reserve_first();
@@ -502,13 +389,15 @@ impl Worker {
// send always from the first socket // send always from the first socket
// FIXME: select based on routing table // FIXME: select based on routing table
let sender = self.sockets[0].local_addr().unwrap(); let sender = self.sockets[0].local_addr().unwrap();
let dest = UdpClient(addr.as_sockaddr().unwrap());
// start the timeout right before sending the packet // start the timeout right before sending the packet
let time_drop = self.work_timers.add( hshake.timeout = Some(::tokio::task::spawn_local(
::tokio::time::Duration::from_secs(10), Self::handshake_timeout(
Work::DropHandshake(client_key_id), self.queue_timeouts_send.clone(),
); client_key_id,
hshake.timeout = Some(time_drop); ),
));
// send packet // send packet
self.send_packet(raw, dest, UdpServer(sender)).await; self.send_packet(raw, dest, UdpServer(sender)).await;
@@ -524,70 +413,21 @@ impl Worker {
} }
}; };
} }
Work::Recv(pkt) => match self.recv(pkt).await { Work::Recv(pkt) => {
Ok(event) => return Ok(event), self.recv(pkt).await;
Err(_) => continue 'mainloop,
},
Work::UserSend((tracker, stream, data)) => {
let conn = match self.connections.get_mut(tracker) {
Ok(conn) => conn,
Err(_) => continue 'mainloop,
};
use connection::Enqueue;
if let Ok(enqueued) = conn.enqueue(stream, data) {
match enqueued {
Enqueue::Immediate(instant) => {
let _ = self
.queue_sender
.send(Work::SendData((tracker, instant)))
.await;
}
Enqueue::TimerWait => {}
} }
} }
} }
Work::SendData((tracker, instant)) => {
// make sure we don't process events before they are
// actually needed.
// This is basically busy waiting with extra steps,
// but we don't want to spawn lots of timers and
// we don't really have a fine-grained sleep that is
// multiplatform
let now = ::tokio::time::Instant::now();
if instant <= now {
let _ = self
.queue_sender
.send(Work::SendData((tracker, instant)))
.await;
continue;
} }
async fn handshake_timeout(
let mut raw: Vec<u8> = Vec::with_capacity(1200); timeout_queue: mpsc::UnboundedSender<Work>,
raw.resize(raw.capacity(), 0); key_id: KeyID,
let conn = match self.connections.get_mut(tracker) { ) {
Ok(conn) => conn, ::tokio::time::sleep(::std::time::Duration::from_secs(10)).await;
Err(_) => continue, let _ = timeout_queue.send(Work::DropHandshake(key_id));
};
let pkt = match conn.write_pkt(&mut raw) {
Ok(pkt) => pkt,
Err(enc::Error::NotEnoughData(0)) => continue,
Err(e) => {
::tracing::error!("Packet generation: {:?}", e);
continue;
}
};
let dest = conn.send_addr;
let src = UdpServer(self.sockets[0].local_addr().unwrap());
let len = pkt.len();
raw.truncate(len);
let _ = self.send_packet(raw, dest, src);
}
}
}
Ok(Event::End)
} }
/// Read and do stuff with the raw udp packet /// Read and do stuff with the raw udp packet
async fn recv(&mut self, mut udp: RawUdp) -> Result<Event, ()> { async fn recv(&mut self, mut udp: RawUdp) {
if udp.packet.id.is_handshake() { if udp.packet.id.is_handshake() {
let handshake = match Handshake::deserialize( let handshake = match Handshake::deserialize(
&udp.data[connection::ID::len()..], &udp.data[connection::ID::len()..],
@@ -595,7 +435,7 @@ impl Worker {
Ok(handshake) => handshake, Ok(handshake) => handshake,
Err(e) => { Err(e) => {
::tracing::debug!("Handshake parsing: {}", e); ::tracing::debug!("Handshake parsing: {}", e);
return Err(()); return;
} }
}; };
let action = match self.handshakes.recv_handshake( let action = match self.handshakes.recv_handshake(
@ -605,38 +445,9 @@ impl Worker {
Ok(action) => action, Ok(action) => action,
Err(err) => { Err(err) => {
::tracing::debug!("Handshake recv error {}", err); ::tracing::debug!("Handshake recv error {}", err);
return Err(()); return;
} }
}; };
self.recv_handshake(udp, action).await;
Err(())
} else {
self.recv_packet(udp)
}
}
/// Receive a non-handshake packet
fn recv_packet(&mut self, udp: RawUdp) -> Result<Event, ()> {
let conn = match self.connections.get_id_mut(udp.packet.id) {
Ok(conn) => conn,
Err(_) => return Err(()),
};
match conn.recv(udp) {
Ok(stream::StreamData::NotReady) => Err(()),
Ok(stream::StreamData::Ready) => Ok(Event::Data(ConnData {
conn: ConnTracker {
user: conn.user_tracker,
lib: conn.lib_tracker,
},
data: conn.get_data().unwrap(),
})),
Err(e) => {
::tracing::trace!("Conn Recv: {:?}", e.to_string());
Err(())
}
}
}
/// Receive a handshake packet
async fn recv_handshake(&mut self, udp: RawUdp, action: handshake::Action) {
match action { match action {
handshake::Action::AuthNeeded(authinfo) => { handshake::Action::AuthNeeded(authinfo) => {
let req; let req;
@ -661,7 +472,8 @@ impl Worker {
let maybe_auth_check = { let maybe_auth_check = {
match &self.token_check { match &self.token_check {
None => { None => {
if req_data.auth.user == auth::USERID_ANONYMOUS { if req_data.auth.user == auth::USERID_ANONYMOUS
{
Ok(true) Ok(true)
} else { } else {
Ok(false) Ok(false)
@ -703,20 +515,16 @@ impl Worker {
let head_len = req.cipher.nonce_len(); let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len(); let tag_len = req.cipher.tag_len();
let mut auth_conn = Connection::new( let mut auth_conn = Conn::new(
authinfo.hkdf, authinfo.hkdf,
req.cipher, req.cipher,
connection::Role::Server, connection::Role::Server,
&self.rand, &self.rand,
); );
auth_conn.id_send = IDSend(req_data.id); auth_conn.id_send = IDSend(req_data.id);
auth_conn.send_addr = udp.src;
// track connection // track connection
let auth_id_recv = self.connections.reserve_first(); let auth_id_recv = self.connections.reserve_first();
auth_conn.id_recv = auth_id_recv; auth_conn.id_recv = auth_id_recv;
let (tracker, auth_conn) =
self.connections.reserve_and_track(auth_conn);
self.untracked_connections.push_back(tracker);
let resp_data = dirsync::resp::Data { let resp_data = dirsync::resp::Data {
client_nonce: req_data.nonce, client_nonce: req_data.nonce,
@ -736,9 +544,9 @@ impl Worker {
connection::ID::len() + resp.encrypted_offset(); connection::ID::len() + resp.encrypted_offset();
let encrypt_until = let encrypt_until =
encrypt_from + resp.encrypted_length(head_len, tag_len); encrypt_from + resp.encrypted_length(head_len, tag_len);
let resp_handshake = Handshake::new(handshake::Data::DirSync( let resp_handshake = Handshake::new(
DirSync::Resp(resp), handshake::Data::DirSync(DirSync::Resp(resp)),
)); );
let packet = Packet { let packet = Packet {
id: connection::ID::new_handshake(), id: connection::ID::new_handshake(),
data: packet::Data::Handshake(resp_handshake), data: packet::Data::Handshake(resp_handshake),
@ -758,7 +566,6 @@ impl Worker {
self.send_packet(raw_out, udp.src, udp.dst).await; self.send_packet(raw_out, udp.src, udp.dst).await;
} }
handshake::Action::ClientConnect(cci) => { handshake::Action::ClientConnect(cci) => {
self.work_timers.remove(cci.old_timeout);
let ds_resp; let ds_resp;
if let handshake::Data::DirSync(DirSync::Resp(resp)) = if let handshake::Data::DirSync(DirSync::Resp(resp)) =
cci.handshake.data cci.handshake.data
@ -770,7 +577,9 @@ impl Worker {
} }
// track connection // track connection
let resp_data; let resp_data;
if let dirsync::resp::State::ClearText(r_data) = ds_resp.data { if let dirsync::resp::State::ClearText(r_data) =
ds_resp.data
{
resp_data = r_data; resp_data = r_data;
} else { } else {
::tracing::error!( ::tracing::error!(
@ -778,31 +587,20 @@ impl Worker {
); );
unreachable!(); unreachable!();
} }
let auth_id_send = IDSend(resp_data.id); let auth_srv_conn = IDSend(resp_data.id);
let mut conn = cci.connection; let mut conn = cci.connection;
conn.id_send = auth_id_send; conn.id_send = auth_srv_conn;
let id_recv = conn.id_recv; let id_recv = conn.id_recv;
let cipher = conn.cipher_recv.kind(); let cipher = conn.cipher_recv.kind();
// track the connection to the authentication server // track the connection to the authentication server
let track_auth_conn = match self.connections.track(conn) { if self.connections.track(conn.into()).is_err() {
Ok(track_auth_conn) => track_auth_conn, ::tracing::error!("Could not track new connection");
Err(_) => {
::tracing::error!(
"Could not track new auth srv connection"
);
self.connections.remove(id_recv); self.connections.remove(id_recv);
// FIXME: proper connection closing
let _ = cci.answer.send(Err( let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(), handshake::Error::InternalTracking.into(),
)); ));
return; return;
} }
};
let authsrv_conn = AuthSrvConn(ConnTracker {
lib: track_auth_conn,
user: None,
});
let mut service_conn = None;
if cci.service_id != auth::SERVICEID_AUTH { if cci.service_id != auth::SERVICEID_AUTH {
// create and track the connection to the service // create and track the connection to the service
// SECURITY: xor with secrets // SECURITY: xor with secrets
@ -813,7 +611,7 @@ impl Worker {
cci.service_id.as_bytes(), cci.service_id.as_bytes(),
resp_data.service_key, resp_data.service_key,
); );
let mut service_connection = Connection::new( let mut service_connection = Conn::new(
hkdf, hkdf,
cipher, cipher,
connection::Role::Client, connection::Role::Client,
@ -822,39 +620,16 @@ impl Worker {
service_connection.id_recv = cci.service_connection_id; service_connection.id_recv = cci.service_connection_id;
service_connection.id_send = service_connection.id_send =
IDSend(resp_data.service_connection_id); IDSend(resp_data.service_connection_id);
let track_serv_conn = let _ =
match self.connections.track(service_connection) { self.connections.track(service_connection.into());
Ok(track_serv_conn) => track_serv_conn,
Err(_) => {
::tracing::error!(
"Could not track new service connection"
);
self.connections
.remove(cci.service_connection_id);
// FIXME: proper connection closing
// FIXME: drop auth srv connection if we just
// established it
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
} }
}; let _ =
service_conn = Some(ServiceConn(ConnTracker { cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn)));
lib: track_serv_conn,
user: None,
}));
}
let _ = cci.answer.send(Ok(handshake::tracker::ConnectOk {
auth_key_id: cci.srv_key_id,
auth_id_send,
authsrv_conn,
service_conn,
}));
} }
handshake::Action::Nothing => {} handshake::Action::Nothing => {}
}; };
} }
}
async fn send_packet( async fn send_packet(
&self, &self,
data: Vec<u8>, data: Vec<u8>,
@ -878,16 +653,3 @@ impl Worker {
let _res = src_sock.send_to(&data, client.0).await; let _res = src_sock.send_to(&data, client.0).await;
} }
} }
/// Handle to send work asyncronously to the worker
#[derive(Debug, Clone)]
pub struct Handle {
queue: ::async_channel::Sender<Work>,
}
impl Handle {
// TODO
// pub fn send(..)
// pub fn set_connection_id(..)
// try_get_new_connections()
}
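
Aside: on the `main` side of this diff, handshake expiry is no longer driven by the worker's own timer bookkeeping; a detached task simply sleeps and then queues a `Work::DropHandshake` for the worker loop. Below is a minimal, self-contained sketch of that pattern (illustrative only: the `u64` alias for `KeyID` and the single-variant `Work` enum are stand-ins, not the crate's real definitions; requires the tokio crate with the rt, macros and time features):

    use tokio::sync::mpsc;

    // Stand-ins for the crate types (illustrative only).
    type KeyID = u64;
    enum Work {
        DropHandshake(KeyID),
    }

    // Mirrors the new `handshake_timeout`: sleep, then tell the worker
    // loop to drop the half-open handshake if it is still pending.
    async fn handshake_timeout(
        timeout_queue: mpsc::UnboundedSender<Work>,
        key_id: KeyID,
    ) {
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
        let _ = timeout_queue.send(Work::DropHandshake(key_id));
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        // The worker would spawn this right after emitting a handshake.
        tokio::spawn(handshake_timeout(tx, 42));
        if let Some(Work::DropHandshake(id)) = rx.recv().await {
            println!("dropping pending handshake for key {id}");
        }
    }

Pushing the timeout into its own task is what lets `main` drop the transport branch's timer machinery (`work_timers`, `set_minimum_sleep_resolution`) removed elsewhere in this diff.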


@@ -34,12 +34,12 @@ use crate::{
         AuthServerConnections, Packet,
     },
     inner::{
-        worker::{ConnectInfo, Event, RawUdp, Work, Worker},
+        worker::{ConnectInfo, RawUdp, Work, Worker},
         ThreadTracker,
     },
 };
 pub use config::Config;
-pub use connection::{AuthSrvConn, ServiceConn};
+pub use connection::Connection;
 /// Main fenrir library errors
 #[derive(::thiserror::Error, Debug)]
@@ -59,15 +59,15 @@ pub enum Error {
     /// Handshake errors
     #[error("Handshake: {0:?}")]
     Handshake(#[from] handshake::Error),
-    /// Key error
-    #[error("key: {0:?}")]
-    Key(#[from] crate::enc::Error),
     /// Resolution problems. wrong or incomplete DNSSEC data
     #[error("DNSSEC resolution: {0}")]
     Resolution(String),
     /// Wrapper on encryption errors
-    #[error("Crypto: {0}")]
-    Crypto(#[from] enc::Error),
-    /// Wrapper on connection errors
-    #[error("Connection: {0}")]
-    Connection(#[from] connection::Error),
+    #[error("Encrypt: {0}")]
+    Encrypt(enc::Error),
 }
 pub(crate) enum StopWorking {
@@ -176,7 +176,6 @@ impl Fenrir {
         config: &Config,
         tokio_rt: Arc<::tokio::runtime::Runtime>,
     ) -> Result<Self, Error> {
-        inner::set_minimum_sleep_resolution().await;
         let (sender, _) = ::tokio::sync::broadcast::channel(1);
         let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
         // bind sockets early so we can change "port 0" (aka: random)
@@ -215,7 +214,6 @@ impl Fenrir {
     pub async fn with_workers(
         config: &Config,
     ) -> Result<(Self, Vec<Worker>), Error> {
-        inner::set_minimum_sleep_resolution().await;
         let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
         let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
         // bind sockets early so we can change "port 0" (aka: random)
@@ -384,7 +382,7 @@ impl Fenrir {
         &self,
         domain: &Domain,
         service: ServiceID,
-    ) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
+    ) -> Result<(), Error> {
         let resolved = self.resolv(domain).await?;
         self.connect_resolved(resolved, domain, service).await
     }
@@ -394,7 +392,7 @@ impl Fenrir {
         resolved: dnssec::Record,
         domain: &Domain,
         service: ServiceID,
-    ) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
+    ) -> Result<(), Error> {
         loop {
             // check if we already have a connection to that auth. srv
             let is_reserved = {
@@ -462,28 +460,29 @@ impl Fenrir {
             .await;
         match recv.await {
-            Ok(res) => match res {
+            Ok(res) => {
+                match res {
                     Err(e) => {
-                        let mut conn_auth_lock = self.conn_auth_srv.lock().await;
+                        let mut conn_auth_lock =
+                            self.conn_auth_srv.lock().await;
                         conn_auth_lock.remove_reserved(&resolved);
                         Err(e)
                     }
-                    Ok(connections) => {
+                    Ok((key_id, id_send)) => {
                         let key = resolved
                             .public_keys
                             .iter()
-                            .find(|k| k.0 == connections.auth_key_id)
+                            .find(|k| k.0 == key_id)
                             .unwrap();
-                        let mut conn_auth_lock = self.conn_auth_srv.lock().await;
-                        conn_auth_lock.add(
-                            &key.1,
-                            connections.auth_id_send,
-                            &resolved,
-                        );
-                        Ok((connections.authsrv_conn, connections.service_conn))
-            },
+                        let mut conn_auth_lock =
+                            self.conn_auth_srv.lock().await;
+                        conn_auth_lock.add(&key.1, id_send, &resolved);
+                        //FIXME: user needs to somehow track the connection
+                        Ok(())
+                    }
+                }
+            }
             Err(e) => {
                 // Thread dropped the sender. no more thread?
                 let mut conn_auth_lock = self.conn_auth_srv.lock().await;
@@ -525,7 +524,6 @@ impl Fenrir {
             self.token_check.clone(),
             socks,
             work_recv,
-            work_send.clone(),
         )
         .await?;
         // don't keep around private keys too much
@@ -549,6 +547,7 @@ impl Fenrir {
         }
         Ok(worker)
     }
+    // needs to be called before add_sockets
     /// Start one working thread for each physical cpu
     /// threads are pinned to each cpu core.
@@ -590,7 +589,6 @@ impl Fenrir {
         let th_tokio_rt = tokio_rt.clone();
         let th_config = self.cfg.clone();
         let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
-        let th_work_send = work_send.clone();
         let th_stop_working = self.stop_working.subscribe();
         let th_token_check = self.token_check.clone();
         let th_sockets = sockets.clone();
@@ -631,23 +629,13 @@ impl Fenrir {
                     th_token_check,
                     th_sockets,
                     work_recv,
-                    th_work_send,
                 )
                 .await
                 {
                     Ok(worker) => worker,
                    Err(_) => return,
                 };
-                loop {
-                    match worker.work_loop().await {
-                        Ok(_) => continue,
-                        Ok(Event::End) => break,
-                        Err(e) => {
-                            ::tracing::error!("Worker: {:?}", e);
-                            break;
-                        }
-                    }
-                }
+                worker.work_loop().await
            });
         });
         loop {
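
For context on the API change above: on `main`, `connect()` and `connect_resolved()` return `Result<(), Error>`, and the worker answers the pending connect over a oneshot channel with a bare `(KeyID, IDSend)` tuple instead of the transport branch's `ConnectOk` struct. A self-contained sketch of that answer-channel flow (illustrative only: the `String` error and the type bodies are stand-ins, not the crate's real definitions):

    use tokio::sync::oneshot;

    // Stand-ins for the crate types (illustrative only).
    type KeyID = u64;
    #[derive(Debug)]
    struct IDSend(u64);
    type ConnectAnswer = Result<(KeyID, IDSend), String>;

    #[tokio::main]
    async fn main() {
        let (answer_tx, answer_rx) = oneshot::channel::<ConnectAnswer>();

        // Worker side: on ClientConnect it reports which server key was
        // used and the ID to send with, once the handshake completes.
        tokio::spawn(async move {
            let _ = answer_tx.send(Ok((1, IDSend(77))));
        });

        // Library side (connect_resolved): record the tuple internally
        // and hand the caller only Ok(()), as in the hunk above.
        match answer_rx.await {
            Ok(Ok((key_id, id_send))) => {
                println!("tracked auth srv: key {key_id}, {id_send:?}");
            }
            Ok(Err(e)) => eprintln!("handshake failed: {e}"),
            Err(_) => eprintln!("worker dropped the answer channel"),
        }
    }

The caller no longer receives connection handles; the `//FIXME` in the hunk above notes that user-side tracking of the new connection still has to be designed.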


@@ -88,7 +88,7 @@ async fn test_connection_dirsync() {
         .connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH)
         .await
     {
-        Ok((_, _)) => {}
+        Ok(()) => {}
         Err(e) => {
             assert!(false, "Err on client connection: {:?}", e);
         }