Compare commits

...

10 Commits

Author SHA1 Message Date
9ca4123c37
Review conn tracking, data reporting
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 19:11:26 +02:00
c3c8238730
Stream ROB: Reconstruct data TCP-like
this was more convoluted than I thought.
maybe someone will simplify this.

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 19:11:26 +02:00
a810fc9a9e
Stream enqueue and serialize to the packet
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 19:11:26 +02:00
9c67210e3e
User conn tracking, enqueue data, timers
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 19:11:26 +02:00
2fe91d5dd3
Give the user a tracker for conn interactions
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 19:11:22 +02:00
11d6b4e467
Merge branch 'namespace' 2023-06-28 19:07:05 +02:00
5dff5c8c9a
Namespace split the dirsync request/response
There was no big problem, but it was messy

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 18:58:24 +02:00
bf877cf86e
Rename lots of stuff to properly use namespaces
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 18:58:23 +02:00
cd7be0ff69
Stream stubs, start using namespaces as intended
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-28 18:58:21 +02:00
d6825980fd
Merge branch 'handshake_dir_sync'
Have a working handshake and a connection,
sounds like something we might need. eventually.
2023-06-28 18:51:43 +02:00
25 changed files with 2279 additions and 919 deletions

3
TODO
View File

@ -1 +1,4 @@
* Wrapping for everything that wraps (sigh)
* track user connection (add u64 from user)
* split API in LocalThread and ThreadSafe
* split send/recv API in Centralized, Decentralized

24
flake.lock generated
View File

@ -5,11 +5,11 @@
"systems": "systems"
},
"locked": {
"lastModified": 1681202837,
"narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=",
"lastModified": 1687171271,
"narHash": "sha256-BJlq+ozK2B1sJDQXS3tzJM5a+oVZmi1q0FlBK/Xqv7M=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "cfacdce06f30d2b68473a46042957675eebb3401",
"rev": "abfb11bd1aec8ced1c9bb9adfe68018230f4fb3c",
"type": "github"
},
"original": {
@ -38,11 +38,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1684922889,
"narHash": "sha256-l0WZAmln8959O7RdYUJ3gnAIM9OPKFLKHKGX4q+Blrk=",
"lastModified": 1687555006,
"narHash": "sha256-GD2Kqb/DXQBRJcHqkM2qFZqbVenyO7Co/80JHRMg2U0=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "04aaf8511678a0d0f347fdf1e8072fe01e4a509e",
"rev": "33223d479ffde3d05ac16c6dff04ae43cc27e577",
"type": "github"
},
"original": {
@ -54,11 +54,11 @@
},
"nixpkgs-unstable": {
"locked": {
"lastModified": 1684844536,
"narHash": "sha256-M7HhXYVqAuNb25r/d3FOO0z4GxPqDIZp5UjHFbBgw0Q=",
"lastModified": 1687502512,
"narHash": "sha256-dBL/01TayOSZYxtY4cMXuNCBk8UMLoqRZA+94xiFpJA=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "d30264c2691128adc261d7c9388033645f0e742b",
"rev": "3ae20aa58a6c0d1ca95c9b11f59a2d12eebc511f",
"type": "github"
},
"original": {
@ -98,11 +98,11 @@
"nixpkgs": "nixpkgs_2"
},
"locked": {
"lastModified": 1684894917,
"narHash": "sha256-kwKCfmliHIxKuIjnM95TRcQxM/4AAEIZ+4A9nDJ6cJs=",
"lastModified": 1687660699,
"narHash": "sha256-crI/CA/OJc778I5qJhwhhl8/PKKzc0D7vvVxOtjfvSo=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "9ea38d547100edcf0da19aaebbdffa2810585495",
"rev": "b3bd1d49f1ae609c1d68a66bba7a95a9a4256031",
"type": "github"
},
"original": {

View File

@ -2,11 +2,10 @@
//! Configuration to initialize the Fenrir networking library
use crate::{
connection::handshake::HandshakeID,
connection::handshake,
enc::{
asym::{KeyExchangeKind, KeyID, PrivKey, PubKey},
hkdf::HkdfKind,
sym::CipherKind,
hkdf, sym,
},
};
use ::std::{
@ -44,13 +43,13 @@ pub struct Config {
/// List of DNS resolvers to use
pub resolvers: Vec<SocketAddr>,
/// Supported handshakes
pub handshakes: Vec<HandshakeID>,
pub handshakes: Vec<handshake::ID>,
/// Supported key exchanges
pub key_exchanges: Vec<KeyExchangeKind>,
/// Supported Hkdfs
pub hkdfs: Vec<HkdfKind>,
pub hkdfs: Vec<hkdf::Kind>,
/// Supported Ciphers
pub ciphers: Vec<CipherKind>,
pub ciphers: Vec<sym::Kind>,
/// list of authentication servers
/// clients will have this empty
pub servers: Vec<AuthServer>,
@ -73,10 +72,10 @@ impl Default for Config {
),
],
resolvers: Vec::new(),
handshakes: [HandshakeID::DirectorySynchronized].to_vec(),
handshakes: [handshake::ID::DirectorySynchronized].to_vec(),
key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(),
hkdfs: [HkdfKind::Sha3].to_vec(),
ciphers: [CipherKind::XChaCha20Poly1305].to_vec(),
hkdfs: [hkdf::Kind::Sha3].to_vec(),
ciphers: [sym::Kind::XChaCha20Poly1305].to_vec(),
servers: Vec::new(),
server_keys: Vec::new(),
}

View File

@ -0,0 +1,77 @@
//! Directory synchronized handshake
//! 1-RTT connection
//!
//! The simplest, fastest handshake supported by Fenrir
//! Downside: It does not offer protection from DDos,
//! no perfect forward secrecy
//!
//! To grant a form of perfect forward secrecy, the server should periodically
//! change the DNSSEC public/private keys
use crate::enc::{
sym::{NonceLen, TagLen},
Random,
};
pub mod req;
pub mod resp;
// TODO: merge with crate::enc::sym::Nonce
/// random nonce
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Nonce(pub(crate) [u8; 16]);
impl Nonce {
/// Create a new random Nonce
pub fn new(rnd: &Random) -> Self {
use ::core::mem::MaybeUninit;
let mut out: MaybeUninit<[u8; 16]>;
#[allow(unsafe_code)]
unsafe {
out = MaybeUninit::uninit();
let _ = rnd.fill(out.assume_init_mut());
Self(out.assume_init())
}
}
/// Length of the serialized Nonce
pub const fn len() -> usize {
16
}
}
impl From<&[u8; 16]> for Nonce {
fn from(raw: &[u8; 16]) -> Self {
Self(raw.clone())
}
}
/// Parsed handshake
#[derive(Debug, Clone, PartialEq)]
pub enum DirSync {
/// Directory synchronized handshake: client request
Req(req::Req),
/// Directory synchronized handshake: server response
Resp(resp::Resp),
}
impl DirSync {
/// actual length of the dirsync handshake data
pub fn len(&self, head_len: NonceLen, tag_len: TagLen) -> usize {
match self {
DirSync::Req(req) => req.len(),
DirSync::Resp(resp) => resp.len(head_len, tag_len),
}
}
/// Serialize into raw bytes
/// NOTE: assumes that there is exactly asa much buffer as needed
pub fn serialize(
&self,
head_len: NonceLen,
tag_len: TagLen,
out: &mut [u8],
) {
match self {
DirSync::Req(req) => req.serialize(head_len, tag_len, out),
DirSync::Resp(resp) => resp.serialize(head_len, tag_len, out),
}
}
}

View File

@ -1,85 +1,22 @@
//! Directory synchronized handshake
//! 1-RTT connection
//!
//! The simplest, fastest handshake supported by Fenrir
//! Downside: It does not offer protection from DDos,
//! no perfect forward secrecy
//!
//! To grant a form of perfect forward secrecy, the server should periodically
//! change the DNSSEC public/private keys
//! Directory synchronized handshake, Request parsing
use super::{Error, HandshakeData};
use crate::{
auth,
connection::{ProtocolVersion, ID},
connection::{
handshake::{
self,
dirsync::{DirSync, Nonce},
Error,
},
ProtocolVersion, ID,
},
enc::{
asym::{ExchangePubKey, KeyExchangeKind, KeyID},
hkdf::HkdfKind,
sym::{CipherKind, HeadLen, TagLen},
Random, Secret,
hkdf,
sym::{self, NonceLen, TagLen},
},
};
// TODO: merge with crate::enc::sym::Nonce
/// random nonce
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Nonce(pub(crate) [u8; 16]);
impl Nonce {
/// Create a new random Nonce
pub fn new(rnd: &Random) -> Self {
use ::core::mem::MaybeUninit;
let mut out: MaybeUninit<[u8; 16]>;
#[allow(unsafe_code)]
unsafe {
out = MaybeUninit::uninit();
let _ = rnd.fill(out.assume_init_mut());
Self(out.assume_init())
}
}
/// Length of the serialized Nonce
pub const fn len() -> usize {
16
}
}
impl From<&[u8; 16]> for Nonce {
fn from(raw: &[u8; 16]) -> Self {
Self(raw.clone())
}
}
/// Parsed handshake
#[derive(Debug, Clone, PartialEq)]
pub enum DirSync {
/// Directory synchronized handshake: client request
Req(Req),
/// Directory synchronized handshake: server response
Resp(Resp),
}
impl DirSync {
/// actual length of the dirsync handshake data
pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize {
match self {
DirSync::Req(req) => req.len(),
DirSync::Resp(resp) => resp.len(head_len, tag_len),
}
}
/// Serialize into raw bytes
/// NOTE: assumes that there is exactly asa much buffer as needed
pub fn serialize(
&self,
head_len: HeadLen,
tag_len: TagLen,
out: &mut [u8],
) {
match self {
DirSync::Req(req) => req.serialize(head_len, tag_len, out),
DirSync::Resp(resp) => resp.serialize(head_len, tag_len, out),
}
}
}
/// Client request of a directory synchronized handshake
#[derive(Debug, Clone, PartialEq)]
pub struct Req {
@ -88,13 +25,13 @@ pub struct Req {
/// Selected key exchange
pub exchange: KeyExchangeKind,
/// Selected hkdf
pub hkdf: HkdfKind,
pub hkdf: hkdf::Kind,
/// Selected cipher
pub cipher: CipherKind,
pub cipher: sym::Kind,
/// Client ephemeral public key used for key exchanges
pub exchange_key: ExchangePubKey,
/// encrypted data
pub data: ReqInner,
pub data: State,
// SECURITY: TODO: Add padding to min: 1200 bytes
// to avoid amplification attaks
// also: 1200 < 1280 to allow better vpn compatibility
@ -105,30 +42,30 @@ impl Req {
/// NOTE: starts from the beginning of the fenrir packet
pub fn encrypted_offset(&self) -> usize {
ProtocolVersion::len()
+ crate::handshake::HandshakeID::len()
+ handshake::ID::len()
+ KeyID::len()
+ KeyExchangeKind::len()
+ HkdfKind::len()
+ CipherKind::len()
+ hkdf::Kind::len()
+ sym::Kind::len()
+ self.exchange_key.kind().pub_len()
}
/// return the total length of the cleartext data
pub fn encrypted_length(
&self,
head_len: HeadLen,
head_len: NonceLen,
tag_len: TagLen,
) -> usize {
match &self.data {
ReqInner::ClearText(data) => data.len() + head_len.0 + tag_len.0,
ReqInner::CipherText(length) => *length,
State::ClearText(data) => data.len() + head_len.0 + tag_len.0,
State::CipherText(length) => *length,
}
}
/// actual length of the directory synchronized request
pub fn len(&self) -> usize {
KeyID::len()
+ KeyExchangeKind::len()
+ HkdfKind::len()
+ CipherKind::len()
+ hkdf::Kind::len()
+ sym::Kind::len()
+ self.exchange_key.kind().pub_len()
+ self.cipher.nonce_len().0
+ self.data.len()
@ -138,7 +75,7 @@ impl Req {
/// NOTE: assumes that there is exactly as much buffer as needed
pub fn serialize(
&self,
head_len: HeadLen,
head_len: NonceLen,
tag_len: TagLen,
out: &mut [u8],
) {
@ -150,7 +87,7 @@ impl Req {
let written_next = 5 + key_len;
self.exchange_key.serialize_into(&mut out[5..written_next]);
let written = written_next;
if let ReqInner::ClearText(data) = &self.data {
if let State::ClearText(data) = &self.data {
let from = written + head_len.0;
let to = out.len() - tag_len.0;
data.serialize(&mut out[from..to]);
@ -160,8 +97,8 @@ impl Req {
}
}
impl super::HandshakeParsing for Req {
fn deserialize(raw: &[u8]) -> Result<HandshakeData, Error> {
impl handshake::Parsing for Req {
fn deserialize(raw: &[u8]) -> Result<handshake::Data, Error> {
const MIN_PKT_LEN: usize = 10;
if raw.len() < MIN_PKT_LEN {
return Err(Error::NotEnoughData);
@ -173,25 +110,25 @@ impl super::HandshakeParsing for Req {
Some(exchange) => exchange,
None => return Err(Error::Parsing),
};
let hkdf: HkdfKind = match HkdfKind::from_u8(raw[3]) {
let hkdf: hkdf::Kind = match hkdf::Kind::from_u8(raw[3]) {
Some(exchange) => exchange,
None => return Err(Error::Parsing),
};
let cipher: CipherKind = match CipherKind::from_u8(raw[4]) {
let cipher: sym::Kind = match sym::Kind::from_u8(raw[4]) {
Some(cipher) => cipher,
None => return Err(Error::Parsing),
};
const CURR_SIZE: usize = KeyID::len()
+ KeyExchangeKind::len()
+ HkdfKind::len()
+ CipherKind::len();
+ hkdf::Kind::len()
+ sym::Kind::len();
let (exchange_key, len) =
match ExchangePubKey::deserialize(&raw[CURR_SIZE..]) {
Ok(exchange_key) => exchange_key,
Err(e) => return Err(e.into()),
};
let data = ReqInner::CipherText(raw.len() - (CURR_SIZE + len));
Ok(HandshakeData::DirSync(DirSync::Req(Self {
let data = State::CipherText(raw.len() - (CURR_SIZE + len));
Ok(handshake::Data::DirSync(DirSync::Req(Self {
key_id,
exchange,
hkdf,
@ -204,18 +141,18 @@ impl super::HandshakeParsing for Req {
/// Quick way to avoid mixing cipher and clear text
#[derive(Debug, Clone, PartialEq)]
pub enum ReqInner {
pub enum State {
/// Data is still encrytped, we only keep the length
CipherText(usize),
/// Client data, decrypted and parsed
ClearText(ReqData),
ClearText(Data),
}
impl ReqInner {
impl State {
/// The length of the data
pub fn len(&self) -> usize {
match self {
ReqInner::CipherText(len) => *len,
ReqInner::ClearText(data) => data.len(),
State::CipherText(len) => *len,
State::ClearText(data) => data.len(),
}
}
/// parse the cleartext
@ -224,19 +161,19 @@ impl ReqInner {
raw: &[u8],
) -> Result<(), Error> {
let clear = match self {
ReqInner::CipherText(len) => {
State::CipherText(len) => {
assert!(
*len > raw.len(),
"DirSync::ReqInner::CipherText length mismatch"
"DirSync::State::CipherText length mismatch"
);
match ReqData::deserialize(raw) {
match Data::deserialize(raw) {
Ok(clear) => clear,
Err(e) => return Err(e),
}
}
_ => return Err(Error::Parsing),
};
*self = ReqInner::ClearText(clear);
*self = State::ClearText(clear);
Ok(())
}
}
@ -321,7 +258,7 @@ impl AuthInfo {
/// Decrypted request data
#[derive(Debug, Clone, PartialEq)]
pub struct ReqData {
pub struct Data {
/// Random nonce, the client can use this to track multiple key exchanges
pub nonce: Nonce,
/// Client key id so the client can use and rotate keys
@ -331,7 +268,7 @@ pub struct ReqData {
/// Authentication data
pub auth: AuthInfo,
}
impl ReqData {
impl Data {
/// actual length of the request data
pub fn len(&self) -> usize {
Nonce::len() + KeyID::len() + ID::len() + self.auth.len()
@ -383,179 +320,3 @@ impl ReqData {
})
}
}
/// Quick way to avoid mixing cipher and clear text
#[derive(Debug, Clone, PartialEq)]
pub enum RespInner {
/// Server data, still in ciphertext
CipherText(usize),
/// Parsed, cleartext server data
ClearText(RespData),
}
impl RespInner {
/// The length of the data
pub fn len(&self) -> usize {
match self {
RespInner::CipherText(len) => *len,
RespInner::ClearText(_) => RespData::len(),
}
}
/// parse the cleartext
pub fn deserialize_as_cleartext(
&mut self,
raw: &[u8],
) -> Result<(), Error> {
let clear = match self {
RespInner::CipherText(len) => {
assert!(
*len > raw.len(),
"DirSync::RespInner::CipherText length mismatch"
);
match RespData::deserialize(raw) {
Ok(clear) => clear,
Err(e) => return Err(e),
}
}
_ => return Err(Error::Parsing),
};
*self = RespInner::ClearText(clear);
Ok(())
}
/// Serialize the still cleartext data
pub fn serialize(&self, out: &mut [u8]) {
if let RespInner::ClearText(clear) = &self {
clear.serialize(out);
}
}
}
/// Server response in a directory synchronized handshake
#[derive(Debug, Clone, PartialEq)]
pub struct Resp {
/// Tells the client with which key the exchange was done
pub client_key_id: KeyID,
/// actual response data, might be encrypted
pub data: RespInner,
}
impl super::HandshakeParsing for Resp {
fn deserialize(raw: &[u8]) -> Result<HandshakeData, Error> {
const MIN_PKT_LEN: usize = 68;
if raw.len() < MIN_PKT_LEN {
return Err(Error::NotEnoughData);
}
let client_key_id: KeyID =
KeyID(u16::from_le_bytes(raw[0..KeyID::len()].try_into().unwrap()));
Ok(HandshakeData::DirSync(DirSync::Resp(Self {
client_key_id,
data: RespInner::CipherText(raw[KeyID::len()..].len()),
})))
}
}
impl Resp {
/// return the offset of the encrypted data
/// NOTE: starts from the beginning of the fenrir packet
pub fn encrypted_offset(&self) -> usize {
ProtocolVersion::len()
+ crate::connection::handshake::HandshakeID::len()
+ KeyID::len()
}
/// return the total length of the cleartext data
pub fn encrypted_length(
&self,
head_len: HeadLen,
tag_len: TagLen,
) -> usize {
match &self.data {
RespInner::ClearText(_data) => {
RespData::len() + head_len.0 + tag_len.0
}
RespInner::CipherText(len) => *len,
}
}
/// Total length of the response handshake
pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize {
KeyID::len() + head_len.0 + self.data.len() + tag_len.0
}
/// Serialize into raw bytes
/// NOTE: assumes that there is exactly as much buffer as needed
pub fn serialize(
&self,
head_len: HeadLen,
_tag_len: TagLen,
out: &mut [u8],
) {
out[0..KeyID::len()]
.copy_from_slice(&self.client_key_id.0.to_le_bytes());
let start_data = KeyID::len() + head_len.0;
let end_data = start_data + self.data.len();
self.data.serialize(&mut out[start_data..end_data]);
}
}
/// Decrypted response data
#[derive(Debug, Clone, PartialEq)]
pub struct RespData {
/// Client nonce, copied from the request
pub client_nonce: Nonce,
/// Server Connection ID
pub id: ID,
/// Service Connection ID
pub service_connection_id: ID,
/// Service encryption key
pub service_key: Secret,
}
impl RespData {
/// Return the expected length for buffer allocation
pub fn len() -> usize {
Nonce::len() + ID::len() + ID::len() + Secret::len()
}
/// Serialize the data into a buffer
/// NOTE: assumes that there is exactly asa much buffer as needed
pub fn serialize(&self, out: &mut [u8]) {
let mut start = 0;
let mut end = Nonce::len();
out[start..end].copy_from_slice(&self.client_nonce.0);
start = end;
end = end + ID::len();
self.id.serialize(&mut out[start..end]);
start = end;
end = end + ID::len();
self.service_connection_id.serialize(&mut out[start..end]);
start = end;
end = end + Secret::len();
out[start..end].copy_from_slice(self.service_key.as_ref());
}
/// Parse the cleartext raw data
pub fn deserialize(raw: &[u8]) -> Result<Self, Error> {
let raw_sized: &[u8; 16] = raw[..Nonce::len()].try_into().unwrap();
let client_nonce: Nonce = raw_sized.into();
let end = Nonce::len() + ID::len();
let id: ID =
u64::from_le_bytes(raw[Nonce::len()..end].try_into().unwrap())
.into();
if id.is_handshake() {
return Err(Error::Parsing);
}
let parsed = end;
let end = parsed + ID::len();
let service_connection_id: ID =
u64::from_le_bytes(raw[parsed..end].try_into().unwrap()).into();
if service_connection_id.is_handshake() {
return Err(Error::Parsing);
}
let parsed = end;
let end = parsed + Secret::len();
let raw_secret: &[u8; 32] = raw[parsed..end].try_into().unwrap();
let service_key = raw_secret.into();
Ok(Self {
client_nonce,
id,
service_connection_id,
service_key,
})
}
}

View File

@ -0,0 +1,189 @@
//! Directory synchronized handshake, Response parsing
use crate::{
connection::{
handshake::{
self,
dirsync::{DirSync, Nonce},
Error,
},
ProtocolVersion, ID,
},
enc::{
asym::KeyID,
sym::{NonceLen, TagLen},
Secret,
},
};
/// Server response in a directory synchronized handshake
#[derive(Debug, Clone, PartialEq)]
pub struct Resp {
/// Tells the client with which key the exchange was done
pub client_key_id: KeyID,
/// actual response data, might be encrypted
pub data: State,
}
impl handshake::Parsing for Resp {
fn deserialize(raw: &[u8]) -> Result<handshake::Data, Error> {
const MIN_PKT_LEN: usize = 68;
if raw.len() < MIN_PKT_LEN {
return Err(Error::NotEnoughData);
}
let client_key_id: KeyID =
KeyID(u16::from_le_bytes(raw[0..KeyID::len()].try_into().unwrap()));
Ok(handshake::Data::DirSync(DirSync::Resp(Self {
client_key_id,
data: State::CipherText(raw[KeyID::len()..].len()),
})))
}
}
impl Resp {
/// return the offset of the encrypted data
/// NOTE: starts from the beginning of the fenrir packet
pub fn encrypted_offset(&self) -> usize {
ProtocolVersion::len() + handshake::ID::len() + KeyID::len()
}
/// return the total length of the cleartext data
pub fn encrypted_length(
&self,
head_len: NonceLen,
tag_len: TagLen,
) -> usize {
match &self.data {
State::ClearText(_data) => Data::len() + head_len.0 + tag_len.0,
State::CipherText(len) => *len,
}
}
/// Total length of the response handshake
pub fn len(&self, head_len: NonceLen, tag_len: TagLen) -> usize {
KeyID::len() + head_len.0 + self.data.len() + tag_len.0
}
/// Serialize into raw bytes
/// NOTE: assumes that there is exactly as much buffer as needed
pub fn serialize(
&self,
head_len: NonceLen,
_tag_len: TagLen,
out: &mut [u8],
) {
out[0..KeyID::len()]
.copy_from_slice(&self.client_key_id.0.to_le_bytes());
let start_data = KeyID::len() + head_len.0;
let end_data = start_data + self.data.len();
self.data.serialize(&mut out[start_data..end_data]);
}
}
/// Quick way to avoid mixing cipher and clear text
#[derive(Debug, Clone, PartialEq)]
pub enum State {
/// Server data, still in ciphertext
CipherText(usize),
/// Parsed, cleartext server data
ClearText(Data),
}
impl State {
/// The length of the data
pub fn len(&self) -> usize {
match self {
State::CipherText(len) => *len,
State::ClearText(_) => Data::len(),
}
}
/// parse the cleartext
pub fn deserialize_as_cleartext(
&mut self,
raw: &[u8],
) -> Result<(), Error> {
let clear = match self {
State::CipherText(len) => {
assert!(
*len > raw.len(),
"DirSync::State::CipherText length mismatch"
);
match Data::deserialize(raw) {
Ok(clear) => clear,
Err(e) => return Err(e),
}
}
_ => return Err(Error::Parsing),
};
*self = State::ClearText(clear);
Ok(())
}
/// Serialize the still cleartext data
pub fn serialize(&self, out: &mut [u8]) {
if let State::ClearText(clear) = &self {
clear.serialize(out);
}
}
}
/// Decrypted response data
#[derive(Debug, Clone, PartialEq)]
pub struct Data {
/// Client nonce, copied from the request
pub client_nonce: Nonce,
/// Server Connection ID
pub id: ID,
/// Service Connection ID
pub service_connection_id: ID,
/// Service encryption key
pub service_key: Secret,
}
impl Data {
/// Return the expected length for buffer allocation
pub fn len() -> usize {
Nonce::len() + ID::len() + ID::len() + Secret::len()
}
/// Serialize the data into a buffer
/// NOTE: assumes that there is exactly asa much buffer as needed
pub fn serialize(&self, out: &mut [u8]) {
let mut start = 0;
let mut end = Nonce::len();
out[start..end].copy_from_slice(&self.client_nonce.0);
start = end;
end = end + ID::len();
self.id.serialize(&mut out[start..end]);
start = end;
end = end + ID::len();
self.service_connection_id.serialize(&mut out[start..end]);
start = end;
end = end + Secret::len();
out[start..end].copy_from_slice(self.service_key.as_ref());
}
/// Parse the cleartext raw data
pub fn deserialize(raw: &[u8]) -> Result<Self, Error> {
let raw_sized: &[u8; 16] = raw[..Nonce::len()].try_into().unwrap();
let client_nonce: Nonce = raw_sized.into();
let end = Nonce::len() + ID::len();
let id: ID =
u64::from_le_bytes(raw[Nonce::len()..end].try_into().unwrap())
.into();
if id.is_handshake() {
return Err(Error::Parsing);
}
let parsed = end;
let end = parsed + ID::len();
let service_connection_id: ID =
u64::from_le_bytes(raw[parsed..end].try_into().unwrap()).into();
if service_connection_id.is_handshake() {
return Err(Error::Parsing);
}
let parsed = end;
let end = parsed + Secret::len();
let raw_secret: &[u8; 32] = raw[parsed..end].try_into().unwrap();
let service_key = raw_secret.into();
Ok(Self {
client_nonce,
id,
service_connection_id,
service_key,
})
}
}

View File

@ -4,10 +4,11 @@ pub mod dirsync;
#[cfg(test)]
mod tests;
pub(crate) mod tracker;
pub(crate) use tracker::{Action, Tracker};
use crate::{
connection::ProtocolVersion,
enc::sym::{HeadLen, TagLen},
enc::sym::{NonceLen, TagLen},
};
use ::num_traits::FromPrimitive;
@ -56,7 +57,7 @@ pub enum Error {
::strum_macros::IntoStaticStr,
)]
#[repr(u8)]
pub enum HandshakeID {
pub enum ID {
/// 1-RTT Directory synchronized handshake. Fast, no forward secrecy
#[strum(serialize = "directory_synchronized")]
DirectorySynchronized = 0,
@ -67,7 +68,7 @@ pub enum HandshakeID {
#[strum(serialize = "stateless")]
Stateless,
}
impl HandshakeID {
impl ID {
/// The length of the serialized field
pub const fn len() -> usize {
1
@ -75,28 +76,28 @@ impl HandshakeID {
}
/// Parsed handshake
#[derive(Debug, Clone, PartialEq)]
pub enum HandshakeData {
pub enum Data {
/// Directory synchronized handhsake
DirSync(dirsync::DirSync),
}
impl HandshakeData {
impl Data {
/// actual length of the handshake data
pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize {
pub fn len(&self, head_len: NonceLen, tag_len: TagLen) -> usize {
match self {
HandshakeData::DirSync(d) => d.len(head_len, tag_len),
Data::DirSync(d) => d.len(head_len, tag_len),
}
}
/// Serialize into raw bytes
/// NOTE: assumes that there is exactly asa much buffer as needed
pub fn serialize(
&self,
head_len: HeadLen,
head_len: NonceLen,
tag_len: TagLen,
out: &mut [u8],
) {
match self {
HandshakeData::DirSync(d) => d.serialize(head_len, tag_len, out),
Data::DirSync(d) => d.serialize(head_len, tag_len, out),
}
}
}
@ -133,19 +134,19 @@ pub struct Handshake {
/// Fenrir Protocol version
pub fenrir_version: ProtocolVersion,
/// enum for the parsed data
pub data: HandshakeData,
pub data: Data,
}
impl Handshake {
/// Build new handshake from the data
pub fn new(data: HandshakeData) -> Self {
pub fn new(data: Data) -> Self {
Handshake {
fenrir_version: ProtocolVersion::V0,
data,
}
}
/// return the total length of the handshake
pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize {
pub fn len(&self, head_len: NonceLen, tag_len: TagLen) -> usize {
ProtocolVersion::len()
+ HandshakeKind::len()
+ self.data.len(head_len, tag_len)
@ -165,9 +166,11 @@ impl Handshake {
None => return Err(Error::Parsing),
};
let data = match handshake_kind {
HandshakeKind::DirSyncReq => dirsync::Req::deserialize(&raw[2..])?,
HandshakeKind::DirSyncReq => {
dirsync::req::Req::deserialize(&raw[2..])?
}
HandshakeKind::DirSyncResp => {
dirsync::Resp::deserialize(&raw[2..])?
dirsync::resp::Resp::deserialize(&raw[2..])?
}
};
Ok(Self {
@ -179,13 +182,13 @@ impl Handshake {
/// NOTE: assumes that there is exactly as much buffer as needed
pub fn serialize(
&self,
head_len: HeadLen,
head_len: NonceLen,
tag_len: TagLen,
out: &mut [u8],
) {
out[0] = self.fenrir_version as u8;
out[1] = match &self.data {
HandshakeData::DirSync(d) => match d {
Data::DirSync(d) => match d {
dirsync::DirSync::Req(_) => HandshakeKind::DirSyncReq,
dirsync::DirSync::Resp(_) => HandshakeKind::DirSyncResp,
},
@ -194,6 +197,6 @@ impl Handshake {
}
}
trait HandshakeParsing {
fn deserialize(raw: &[u8]) -> Result<HandshakeData, Error>;
trait Parsing {
fn deserialize(raw: &[u8]) -> Result<Data, Error>;
}

View File

@ -1,13 +1,16 @@
use crate::{
auth,
connection::{handshake::*, ID},
connection::{
handshake::{self, dirsync, Handshake},
ID,
},
enc::{self, asym::KeyID},
};
#[test]
fn test_handshake_dirsync_req() {
let rand = enc::Random::new();
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
let cipher = enc::sym::Kind::XChaCha20Poly1305;
let (_, exchange_key) =
match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand)
@ -19,11 +22,11 @@ fn test_handshake_dirsync_req() {
}
};
let data = dirsync::ReqInner::ClearText(dirsync::ReqData {
let data = dirsync::req::State::ClearText(dirsync::req::Data {
nonce: dirsync::Nonce::new(&rand),
client_key_id: KeyID(2424),
id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()),
auth: dirsync::AuthInfo {
auth: dirsync::req::AuthInfo {
user: auth::UserID::new(&rand),
token: auth::Token::new_anonymous(&rand),
service_id: auth::SERVICEID_AUTH,
@ -31,16 +34,16 @@ fn test_handshake_dirsync_req() {
},
});
let h_req = Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Req(
dirsync::Req {
let h_req = Handshake::new(handshake::Data::DirSync(
dirsync::DirSync::Req(dirsync::req::Req {
key_id: KeyID(4224),
exchange: enc::asym::KeyExchangeKind::X25519DiffieHellman,
hkdf: enc::hkdf::HkdfKind::Sha3,
cipher: enc::sym::CipherKind::XChaCha20Poly1305,
hkdf: enc::hkdf::Kind::Sha3,
cipher: enc::sym::Kind::XChaCha20Poly1305,
exchange_key,
data,
},
)));
}),
));
let mut bytes = Vec::<u8>::with_capacity(
h_req.len(cipher.nonce_len(), cipher.tag_len()),
@ -55,7 +58,7 @@ fn test_handshake_dirsync_req() {
return;
}
};
if let HandshakeData::DirSync(dirsync::DirSync::Req(r_a)) =
if let handshake::Data::DirSync(dirsync::DirSync::Req(r_a)) =
&mut deserialized.data
{
let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0;
@ -74,11 +77,11 @@ fn test_handshake_dirsync_req() {
#[test]
fn test_handshake_dirsync_reqsp() {
let rand = enc::Random::new();
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
let cipher = enc::sym::Kind::XChaCha20Poly1305;
let service_key = enc::Secret::new_rand(&rand);
let data = dirsync::RespInner::ClearText(dirsync::RespData {
let data = dirsync::resp::State::ClearText(dirsync::resp::Data {
client_nonce: dirsync::Nonce::new(&rand),
id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()),
service_connection_id: ID::ID(
@ -87,8 +90,8 @@ fn test_handshake_dirsync_reqsp() {
service_key,
});
let h_resp = Handshake::new(HandshakeData::DirSync(
dirsync::DirSync::Resp(dirsync::Resp {
let h_resp = Handshake::new(handshake::Data::DirSync(
dirsync::DirSync::Resp(dirsync::resp::Resp {
client_key_id: KeyID(4444),
data,
}),
@ -107,7 +110,7 @@ fn test_handshake_dirsync_reqsp() {
return;
}
};
if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) =
if let handshake::Data::DirSync(dirsync::DirSync::Resp(r_a)) =
&mut deserialized.data
{
let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0;

View File

@ -10,40 +10,47 @@ use crate::{
enc::{
self,
asym::{self, KeyID, PrivKey, PubKey},
hkdf::{Hkdf, HkdfKind},
sym::{CipherKind, CipherRecv},
hkdf::{self, Hkdf},
sym::{self, CipherRecv},
},
inner::ThreadTracker,
};
use ::tokio::sync::oneshot;
pub(crate) struct HandshakeServer {
pub id: KeyID,
pub key: PrivKey,
pub domains: Vec<Domain>,
pub(crate) struct Server {
pub(crate) id: KeyID,
pub(crate) key: PrivKey,
pub(crate) domains: Vec<Domain>,
}
pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>;
pub(crate) type ConnectAnswer = Result<ConnectOk, crate::Error>;
#[derive(Debug)]
pub(crate) struct ConnectOk {
pub(crate) auth_key_id: KeyID,
pub(crate) auth_id_send: IDSend,
pub(crate) authsrv_conn: connection::AuthSrvConn,
pub(crate) service_conn: Option<connection::ServiceConn>,
}
pub(crate) struct HandshakeClient {
pub service_id: ServiceID,
pub service_conn_id: IDRecv,
pub connection: Connection,
pub timeout: Option<::tokio::task::JoinHandle<()>>,
pub answer: oneshot::Sender<ConnectAnswer>,
pub srv_key_id: KeyID,
pub(crate) struct Client {
pub(crate) service_id: ServiceID,
pub(crate) service_conn_id: IDRecv,
pub(crate) connection: Connection,
pub(crate) timeout: Option<::tokio::time::Instant>,
pub(crate) answer: oneshot::Sender<ConnectAnswer>,
pub(crate) srv_key_id: KeyID,
}
/// Tracks the keys used by the client and the handshake
/// they are associated with
pub(crate) struct HandshakeClientList {
pub(crate) struct ClientList {
used: Vec<::bitmaps::Bitmap<1024>>, // index = KeyID
keys: Vec<Option<(PrivKey, PubKey)>>,
list: Vec<Option<HandshakeClient>>,
list: Vec<Option<Client>>,
}
impl HandshakeClientList {
impl ClientList {
pub(crate) fn new() -> Self {
Self {
used: [::bitmaps::Bitmap::<1024>::new()].to_vec(),
@ -51,13 +58,13 @@ impl HandshakeClientList {
list: Vec::with_capacity(16),
}
}
pub(crate) fn get(&self, id: KeyID) -> Option<&HandshakeClient> {
pub(crate) fn get(&self, id: KeyID) -> Option<&Client> {
if id.0 as usize >= self.list.len() {
return None;
}
self.list[id.0 as usize].as_ref()
}
pub(crate) fn remove(&mut self, id: KeyID) -> Option<HandshakeClient> {
pub(crate) fn remove(&mut self, id: KeyID) -> Option<Client> {
if id.0 as usize >= self.list.len() {
return None;
}
@ -82,8 +89,7 @@ impl HandshakeClientList {
connection: Connection,
answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID,
) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender<ConnectAnswer>>
{
) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> {
let maybe_free_key_idx =
self.used.iter().enumerate().find_map(|(idx, bmap)| {
match bmap.first_false_index() {
@ -112,7 +118,7 @@ impl HandshakeClientList {
self.list.push(None);
}
self.keys[free_key_idx] = Some((priv_key, pub_key));
self.list[free_key_idx] = Some(HandshakeClient {
self.list[free_key_idx] = Some(Client {
service_id,
service_conn_id,
connection,
@ -130,30 +136,32 @@ impl HandshakeClientList {
#[derive(Debug, Clone)]
pub(crate) struct AuthNeededInfo {
/// Parsed handshake packet
pub handshake: Handshake,
pub(crate) handshake: Handshake,
/// hkdf generated from the handshake
pub hkdf: Hkdf,
pub(crate) hkdf: Hkdf,
}
/// Client information needed to fully establish the conenction
#[derive(Debug)]
pub(crate) struct ClientConnectInfo {
/// The service ID that we are connecting to
pub service_id: ServiceID,
pub(crate) service_id: ServiceID,
/// The service ID that we are connecting to
pub service_connection_id: IDRecv,
pub(crate) service_connection_id: IDRecv,
/// Parsed handshake packet
pub handshake: Handshake,
pub(crate) handshake: Handshake,
/// Old timeout for the handshake completion
pub(crate) old_timeout: ::tokio::time::Instant,
/// Connection
pub connection: Connection,
pub(crate) connection: Connection,
/// where to wake up the waiting client
pub answer: oneshot::Sender<ConnectAnswer>,
/// server public key id that we used on the handshake
pub srv_key_id: KeyID,
pub(crate) answer: oneshot::Sender<ConnectAnswer>,
/// server pub(crate)lic key id that we used on the handshake
pub(crate) srv_key_id: KeyID,
}
/// Intermediate actions to be taken while parsing the handshake
#[derive(Debug)]
pub(crate) enum HandshakeAction {
pub(crate) enum Action {
/// Parsing finished, all ok, nothing to do
Nothing,
/// Packet parsed, now go perform authentication
@ -167,20 +175,20 @@ pub(crate) enum HandshakeAction {
/// Each of them will handle a subset of all handshakes.
/// Each handshake is routed to a different tracker by checking
/// core = (udp_src_sender_port % total_threads) - 1
pub(crate) struct HandshakeTracker {
pub(crate) struct Tracker {
thread_id: ThreadTracker,
key_exchanges: Vec<asym::KeyExchangeKind>,
ciphers: Vec<CipherKind>,
ciphers: Vec<sym::Kind>,
/// ephemeral keys used server side in key exchange
keys_srv: Vec<HandshakeServer>,
keys_srv: Vec<Server>,
/// ephemeral keys used client side in key exchange
hshake_cli: HandshakeClientList,
hshake_cli: ClientList,
}
impl HandshakeTracker {
impl Tracker {
pub(crate) fn new(
thread_id: ThreadTracker,
ciphers: Vec<CipherKind>,
ciphers: Vec<sym::Kind>,
key_exchanges: Vec<asym::KeyExchangeKind>,
) -> Self {
Self {
@ -188,7 +196,7 @@ impl HandshakeTracker {
ciphers,
key_exchanges,
keys_srv: Vec::new(),
hshake_cli: HandshakeClientList::new(),
hshake_cli: ClientList::new(),
}
}
pub(crate) fn add_server_key(
@ -199,7 +207,7 @@ impl HandshakeTracker {
if self.keys_srv.iter().find(|&k| k.id == id).is_some() {
return Err(());
}
self.keys_srv.push(HandshakeServer {
self.keys_srv.push(Server {
id,
key,
domains: Vec::new(),
@ -236,8 +244,7 @@ impl HandshakeTracker {
connection: Connection,
answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID,
) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender<ConnectAnswer>>
{
) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> {
self.hshake_cli.add(
priv_key,
pub_key,
@ -248,10 +255,7 @@ impl HandshakeTracker {
srv_key_id,
)
}
pub(crate) fn remove_client(
&mut self,
key_id: KeyID,
) -> Option<HandshakeClient> {
pub(crate) fn remove_client(&mut self, key_id: KeyID) -> Option<Client> {
self.hshake_cli.remove(key_id)
}
pub(crate) fn timeout_client(
@ -269,10 +273,10 @@ impl HandshakeTracker {
&mut self,
mut handshake: Handshake,
handshake_raw: &mut [u8],
) -> Result<HandshakeAction, Error> {
use connection::handshake::{dirsync::DirSync, HandshakeData};
) -> Result<Action, Error> {
use handshake::dirsync::DirSync;
match handshake.data {
HandshakeData::DirSync(ref mut ds) => match ds {
handshake::Data::DirSync(ref mut ds) => match ds {
DirSync::Req(ref mut req) => {
if !self.key_exchanges.contains(&req.exchange) {
return Err(enc::Error::UnsupportedKeyExchange.into());
@ -310,7 +314,8 @@ impl HandshakeTracker {
Ok(shared_key) => shared_key,
Err(e) => return Err(handshake::Error::Key(e).into()),
};
let hkdf = Hkdf::new(HkdfKind::Sha3, b"fenrir", shared_key);
let hkdf =
Hkdf::new(hkdf::Kind::Sha3, b"fenrir", shared_key);
let secret_recv = hkdf.get_secret(b"to_server");
let cipher_recv = CipherRecv::new(req.cipher, secret_recv);
use crate::enc::sym::AAD;
@ -334,7 +339,7 @@ impl HandshakeTracker {
}
}
return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo {
return Ok(Action::AuthNeeded(AuthNeededInfo {
handshake,
hkdf,
}));
@ -371,19 +376,15 @@ impl HandshakeTracker {
}
let hshake =
self.hshake_cli.remove(resp.client_key_id).unwrap();
if let Some(timeout) = hshake.timeout {
timeout.abort();
}
return Ok(HandshakeAction::ClientConnect(
ClientConnectInfo {
service_id: hshake.service_id,
service_connection_id: hshake.service_conn_id,
handshake,
connection: hshake.connection,
answer: hshake.answer,
srv_key_id: hshake.srv_key_id,
},
));
return Ok(Action::ClientConnect(ClientConnectInfo {
service_id: hshake.service_id,
service_connection_id: hshake.service_conn_id,
handshake,
old_timeout: hshake.timeout.unwrap(),
connection: hshake.connection,
answer: hshake.answer,
srv_key_id: hshake.srv_key_id,
}));
}
},
}

View File

@ -3,30 +3,124 @@
pub mod handshake;
pub mod packet;
pub mod socket;
pub mod stream;
use ::std::{rc::Rc, vec::Vec};
pub use crate::connection::{
handshake::Handshake,
packet::{ConnectionID as ID, Packet, PacketData},
use ::core::num::Wrapping;
use ::std::{
collections::{BTreeMap, HashMap},
vec::Vec,
};
pub use crate::connection::{handshake::Handshake, packet::Packet};
use crate::{
connection::{socket::UdpClient, stream::StreamData},
dnssec,
enc::{
self,
asym::PubKey,
hkdf::Hkdf,
sym::{CipherKind, CipherRecv, CipherSend},
sym::{self, CipherRecv, CipherSend},
Random,
},
inner::ThreadTracker,
inner::{worker, ThreadTracker},
};
/// Connaction errors
#[derive(::thiserror::Error, Debug, Copy, Clone)]
pub(crate) enum Error {
/// Can't decrypt packet
#[error("Decrypt error: {0}")]
Decrypt(#[from] crate::enc::Error),
#[error("Chunk parsing: {0}")]
Parse(#[from] stream::Error),
}
/// Fenrir Connection ID
///
/// 0 is special as it represents the handshake
/// Connection IDs are to be considered u64 little endian
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum ID {
/// Connection id 0 represent the handshake
Handshake,
/// Non-zero id can represent any connection
ID(::core::num::NonZeroU64),
}
impl ID {
/// Set the conenction id to handshake
pub fn new_handshake() -> Self {
Self::Handshake
}
/// New id from u64. PLZ NON ZERO
pub(crate) fn new_u64(raw: u64) -> Self {
#[allow(unsafe_code)]
unsafe {
ID::ID(::core::num::NonZeroU64::new_unchecked(raw))
}
}
pub(crate) fn as_u64(&self) -> u64 {
match self {
ID::Handshake => 0,
ID::ID(id) => id.get(),
}
}
/// New random service ID
pub fn new_rand(rand: &Random) -> Self {
let mut raw = [0; 8];
let mut num = 0;
while num == 0 {
rand.fill(&mut raw);
num = u64::from_le_bytes(raw);
}
#[allow(unsafe_code)]
unsafe {
ID::ID(::core::num::NonZeroU64::new_unchecked(num))
}
}
/// Quick check to know if this is an handshake
pub fn is_handshake(&self) -> bool {
*self == ID::Handshake
}
/// length if the connection ID in bytes
pub const fn len() -> usize {
8
}
/// write the ID to a buffer
pub fn serialize(&self, out: &mut [u8]) {
match self {
ID::Handshake => out[..8].copy_from_slice(&[0; 8]),
ID::ID(id) => out[..8].copy_from_slice(&id.get().to_le_bytes()),
}
}
}
impl From<u64> for ID {
fn from(raw: u64) -> Self {
if raw == 0 {
ID::Handshake
} else {
#[allow(unsafe_code)]
unsafe {
ID::ID(::core::num::NonZeroU64::new_unchecked(raw))
}
}
}
}
impl From<[u8; 8]> for ID {
fn from(raw: [u8; 8]) -> Self {
let raw_u64 = u64::from_le_bytes(raw);
raw_u64.into()
}
}
/// strong typedef for receiving connection id
#[derive(Debug, Copy, Clone, PartialEq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct IDRecv(pub ID);
/// strong typedef for sending connection id
#[derive(Debug, Copy, Clone, PartialEq)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct IDSend(pub ID);
/// Version of the fenrir protocol in use
@ -47,24 +141,56 @@ impl ProtocolVersion {
}
}
/// A single connection and its data
#[derive(Debug)]
pub struct Connection {
/// Receiving Connection ID
pub id_recv: IDRecv,
/// Sending Connection ID
pub id_send: IDSend,
/// The main hkdf used for all secrets in this connection
pub hkdf: Hkdf,
/// Cipher for decrypting data
pub cipher_recv: CipherRecv,
/// Cipher for encrypting data
pub cipher_send: CipherSend,
/// Unique tracker of connections
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct ConnTracker(Wrapping<u64>);
impl ConnTracker {
pub(crate) fn new(start: u16) -> Self {
Self(Wrapping(start as u64))
}
pub(crate) fn advance(&mut self, amount: u16) -> Self {
let old = self.0;
self.0 = self.0 + Wrapping(amount as u64);
ConnTracker(old)
}
}
/// Role: used to set the correct secrets
/// * Server: Connection is Incoming
/// * Client: Connection is Outgoing
/// Connection to an Authentication Server
#[derive(Debug)]
pub struct AuthSrvConn(pub Conn);
/// Connection to a service
#[derive(Debug)]
pub struct ServiceConn(pub Conn);
/// The connection, as seen from a user of libFenrir
#[derive(Debug)]
pub struct Conn {
pub(crate) queue: ::async_channel::Sender<worker::Work>,
pub(crate) fast: ConnTracker,
}
impl Conn {
/// Queue some data to be sent in this connection
// TODO: send_and_wait, that wait for recipient ACK
pub async fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
use crate::inner::worker::Work;
let _ = self
.queue
.send(Work::UserSend((self.tracker(), stream, data)))
.await;
}
/// Get the library tracking id
pub fn tracker(&self) -> ConnTracker {
self.fast
}
}
/// Role: track the connection direction
///
/// The Role is used to select the correct secrets, and track the direction
/// of the connection
/// * Server: Conn is Incoming
/// * Client: Conn is Outgoing
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
pub enum Role {
@ -74,10 +200,54 @@ pub enum Role {
Client,
}
#[derive(Debug)]
enum TimerKind {
None,
SendData(::tokio::time::Instant),
Keepalive(::tokio::time::Instant),
}
pub(crate) enum Enqueue {
NoSuchStream,
TimerWait,
Immediate(::tokio::time::Instant),
}
/// Connection tracking id. Set by the user
#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)]
pub struct UserTracker(pub ::core::num::NonZeroU64);
/// A single connection and its data
#[derive(Debug)]
pub(crate) struct Connection {
/// Receiving Conn ID
pub(crate) id_recv: IDRecv,
/// Sending Conn ID
pub(crate) id_send: IDSend,
/// User-managed id to track this connection
/// the user can set this to better track this connection
pub(crate) user_tracker: Option<UserTracker>,
/// Sending address
pub(crate) send_addr: UdpClient,
/// The main hkdf used for all secrets in this connection
hkdf: Hkdf,
/// Cipher for decrypting data
pub(crate) cipher_recv: CipherRecv,
/// Cipher for encrypting data
pub(crate) cipher_send: CipherSend,
mtu: usize,
next_timer: TimerKind,
/// send queue for each Stream
send_queue: BTreeMap<stream::ID, stream::SendTracker>,
last_stream_sent: stream::ID,
/// receive queue for each Stream
recv_queue: BTreeMap<stream::ID, stream::Stream>,
}
impl Connection {
pub(crate) fn new(
hkdf: Hkdf,
cipher: CipherKind,
cipher: sym::Kind,
role: Role,
rand: &Random,
) -> Self {
@ -92,21 +262,172 @@ impl Connection {
let cipher_recv = CipherRecv::new(cipher, secret_recv);
let cipher_send = CipherSend::new(cipher, secret_send, rand);
use ::std::net::{IpAddr, Ipv4Addr, SocketAddr};
Self {
id_recv: IDRecv(ID::Handshake),
id_send: IDSend(ID::Handshake),
user_tracker: None,
// will be overwritten
send_addr: UdpClient(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
31337,
)),
hkdf,
cipher_recv,
cipher_send,
mtu: 1200,
next_timer: TimerKind::None,
send_queue: BTreeMap::new(),
last_stream_sent: stream::ID(0),
recv_queue: BTreeMap::new(),
}
}
pub(crate) fn recv(
&mut self,
mut udp: crate::RawUdp,
) -> Result<StreamData, Error> {
let mut data = &mut udp.data[ID::len()..];
let aad = enc::sym::AAD(&[]);
self.cipher_recv.decrypt(aad, &mut data)?;
let mut bytes_parsed = 0;
let mut chunks = Vec::with_capacity(2);
loop {
let chunk = match stream::Chunk::deserialize(&data[bytes_parsed..])
{
Ok(chunk) => chunk,
Err(e) => {
return Err(e.into());
}
};
bytes_parsed = bytes_parsed + chunk.len();
chunks.push(chunk);
if bytes_parsed == data.len() {
break;
}
}
let mut data_ready = StreamData::NotReady;
for chunk in chunks.into_iter() {
let stream_id = chunk.id;
let stream = match self.recv_queue.get_mut(&stream_id) {
Some(stream) => stream,
None => {
::tracing::debug!("Ignoring chunk for unknown stream::ID");
continue;
}
};
match stream.recv(chunk) {
Ok(status) => data_ready = data_ready | status,
Err(e) => ::tracing::debug!("stream: {:?}: {:?}", stream_id, e),
}
}
Ok(data_ready)
}
pub(crate) fn enqueue(
&mut self,
stream: stream::ID,
data: Vec<u8>,
) -> Enqueue {
let stream = match self.send_queue.get_mut(&stream) {
None => return Enqueue::NoSuchStream,
Some(stream) => stream,
};
stream.enqueue(data);
let instant;
let ret;
self.next_timer = match self.next_timer {
TimerKind::None | TimerKind::Keepalive(_) => {
instant = ::tokio::time::Instant::now();
ret = Enqueue::Immediate(instant);
TimerKind::SendData(instant)
}
TimerKind::SendData(old_timer) => {
// There already is some data to be sent
// wait for this timer,
// or risk going over max transmission rate
ret = Enqueue::TimerWait;
TimerKind::SendData(old_timer)
}
};
ret
}
pub(crate) fn write_pkt<'a>(
&mut self,
raw: &'a mut [u8],
) -> Result<&'a [u8], enc::Error> {
assert!(raw.len() >= self.mtu, "I should have at least 1200 MTU");
if self.send_queue.len() == 0 {
return Err(enc::Error::NotEnoughData(0));
}
raw[..ID::len()]
.copy_from_slice(&self.id_send.0.as_u64().to_le_bytes());
let data_from = ID::len() + self.cipher_send.nonce_len().0;
let data_max_to = raw.len() - self.cipher_send.tag_len().0;
let mut chunk_from = data_from;
let mut available_len = data_max_to - data_from;
use std::ops::Bound::{Excluded, Included};
let last_stream = self.last_stream_sent;
// Loop over our streams, write them to the packet.
// Notes:
// * to avoid starvation, just round-robin them all for now
// * we can enqueue multiple times the same stream
// This is useful especially for Datagram streams
'queueloop: {
for (id, stream) in self
.send_queue
.range_mut((Included(last_stream), Included(stream::ID::max())))
{
if available_len < stream::Chunk::headers_len() + 1 {
break 'queueloop;
}
let bytes =
stream.serialize(*id, &mut raw[chunk_from..data_max_to]);
if bytes == 0 {
break 'queueloop;
}
available_len = available_len - bytes;
chunk_from = chunk_from + bytes;
self.last_stream_sent = *id;
}
if available_len > 0 {
for (id, stream) in self.send_queue.range_mut((
Included(stream::ID::min()),
Excluded(last_stream),
)) {
if available_len < stream::Chunk::headers_len() + 1 {
break 'queueloop;
}
let bytes = stream
.serialize(*id, &mut raw[chunk_from..data_max_to]);
if bytes == 0 {
break 'queueloop;
}
available_len = available_len - bytes;
chunk_from = chunk_from + bytes;
self.last_stream_sent = *id;
}
}
}
if chunk_from == data_from {
return Err(enc::Error::NotEnoughData(0));
}
let data_to = chunk_from + self.cipher_send.tag_len().0;
// encrypt
let aad = sym::AAD(&[]);
match self.cipher_send.encrypt(aad, &mut raw[data_from..data_to]) {
Ok(_) => Ok(&raw[..data_to]),
Err(e) => Err(e),
}
}
}
// PERF: Arc<RwLock<ConnList>> loks a bit too much, need to find
// faster ways to do this
pub(crate) struct ConnList {
thread_id: ThreadTracker,
connections: Vec<Option<Rc<Connection>>>,
connections: Vec<Option<Connection>>,
user_tracker: BTreeMap<ConnTracker, usize>,
last_tracked: ConnTracker,
/// Bitmap to track which connection ids are used or free
ids_used: Vec<::bitmaps::Bitmap<1024>>,
}
@ -122,11 +443,35 @@ impl ConnList {
let mut ret = Self {
thread_id,
connections: Vec::with_capacity(INITIAL_CAP),
user_tracker: BTreeMap::new(),
last_tracked: ConnTracker(Wrapping(0)),
ids_used: vec![bitmap_id],
};
ret.connections.resize_with(INITIAL_CAP, || None);
ret
}
pub fn get_id_mut(&mut self, id: ID) -> Option<&mut Connection> {
let conn_id = match id {
ID::Handshake => {
return None;
}
ID::ID(conn_id) => conn_id,
};
let id_in_thread: usize =
(conn_id.get() / (self.thread_id.total as u64)) as usize;
(&mut self.connections[id_in_thread]).into()
}
pub fn get_mut(&mut self, tracker: ConnTracker) -> Option<&mut Connection> {
let idx = if let Some(idx) = self.user_tracker.get(&tracker) {
*idx
} else {
return None;
};
match &mut self.connections[idx] {
None => None,
Some(conn) => Some(conn),
}
}
pub fn len(&self) -> usize {
let mut total: usize = 0;
for bitmap in self.ids_used.iter() {
@ -177,7 +522,10 @@ impl ConnList {
new_id
}
/// NOTE: does NOT check if the connection has been previously reserved!
pub(crate) fn track(&mut self, conn: Rc<Connection>) -> Result<(), ()> {
pub(crate) fn track(
&mut self,
conn: Connection,
) -> Result<ConnTracker, ()> {
let conn_id = match conn.id_recv {
IDRecv(ID::Handshake) => {
return Err(());
@ -187,7 +535,16 @@ impl ConnList {
let id_in_thread: usize =
(conn_id.get() / (self.thread_id.total as u64)) as usize;
self.connections[id_in_thread] = Some(conn);
Ok(())
let mut tracked;
loop {
tracked = self.last_tracked.advance(self.thread_id.total);
if self.user_tracker.get(&tracked).is_none() {
// like, never gonna happen, it's 64 bit
let _ = self.user_tracker.insert(tracked, id_in_thread);
break;
}
}
Ok(tracked)
}
pub(crate) fn remove(&mut self, id: IDRecv) {
if let IDRecv(ID::ID(raw_id)) = id {
@ -219,7 +576,6 @@ enum MapEntry {
Present(IDSend),
Reserved,
}
use ::std::collections::HashMap;
/// Link the public key of the authentication server to a connection id
/// so that we can reuse that connection to ask for more authentications

View File

@ -1,125 +1,44 @@
//
//! Raw packet handling, encryption, decryption, parsing
use crate::enc::{
sym::{HeadLen, TagLen},
Random,
use crate::{
connection,
enc::sym::{NonceLen, TagLen},
};
/// Fenrir Connection id
/// 0 is special as it represents the handshake
/// Connection IDs are to be considered u64 little endian
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ConnectionID {
/// Connection id 0 represent the handshake
Handshake,
/// Non-zero id can represent any connection
ID(::core::num::NonZeroU64),
}
impl ConnectionID {
/// Set the conenction id to handshake
pub fn new_handshake() -> Self {
Self::Handshake
}
/// New id from u64. PLZ NON ZERO
pub(crate) fn new_u64(raw: u64) -> Self {
#[allow(unsafe_code)]
unsafe {
ConnectionID::ID(::core::num::NonZeroU64::new_unchecked(raw))
}
}
pub(crate) fn as_u64(&self) -> u64 {
match self {
ConnectionID::Handshake => 0,
ConnectionID::ID(id) => id.get(),
}
}
/// New random service ID
pub fn new_rand(rand: &Random) -> Self {
let mut raw = [0; 8];
let mut num = 0;
while num == 0 {
rand.fill(&mut raw);
num = u64::from_le_bytes(raw);
}
#[allow(unsafe_code)]
unsafe {
ConnectionID::ID(::core::num::NonZeroU64::new_unchecked(num))
}
}
/// Quick check to know if this is an handshake
pub fn is_handshake(&self) -> bool {
*self == ConnectionID::Handshake
}
/// length if the connection ID in bytes
pub const fn len() -> usize {
8
}
/// write the ID to a buffer
pub fn serialize(&self, out: &mut [u8]) {
match self {
ConnectionID::Handshake => out[..8].copy_from_slice(&[0; 8]),
ConnectionID::ID(id) => {
out[..8].copy_from_slice(&id.get().to_le_bytes())
}
}
}
}
impl From<u64> for ConnectionID {
fn from(raw: u64) -> Self {
if raw == 0 {
ConnectionID::Handshake
} else {
#[allow(unsafe_code)]
unsafe {
ConnectionID::ID(::core::num::NonZeroU64::new_unchecked(raw))
}
}
}
}
impl From<[u8; 8]> for ConnectionID {
fn from(raw: [u8; 8]) -> Self {
let raw_u64 = u64::from_le_bytes(raw);
raw_u64.into()
}
}
/// Enumerate the possible data in a fenrir packet
#[derive(Debug, Clone)]
pub enum PacketData {
pub enum Data {
/// A parsed handshake packet
Handshake(super::Handshake),
/// Raw packet. we only have the connection ID and packet length
Raw(usize),
}
impl PacketData {
impl Data {
/// total length of the data in bytes
pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize {
pub fn len(&self, head_len: NonceLen, tag_len: TagLen) -> usize {
match self {
PacketData::Handshake(h) => h.len(head_len, tag_len),
PacketData::Raw(len) => *len,
Data::Handshake(h) => h.len(head_len, tag_len),
Data::Raw(len) => *len,
}
}
/// serialize data into bytes
/// NOTE: assumes that there is exactly asa much buffer as needed
pub fn serialize(
&self,
head_len: HeadLen,
head_len: NonceLen,
tag_len: TagLen,
out: &mut [u8],
) {
assert!(
self.len(head_len, tag_len) == out.len(),
"PacketData: wrong buffer length"
"Data: wrong buffer length"
);
match self {
PacketData::Handshake(h) => h.serialize(head_len, tag_len, out),
PacketData::Raw(_) => {
::tracing::error!("Tried to serialize a raw PacketData!");
Data::Handshake(h) => h.serialize(head_len, tag_len, out),
Data::Raw(_) => {
::tracing::error!("Tried to serialize a raw Data!");
}
}
}
@ -131,9 +50,9 @@ const MIN_PACKET_BYTES: usize = 16;
#[derive(Debug, Clone)]
pub struct Packet {
/// Id of the Fenrir connection.
pub id: ConnectionID,
pub id: connection::ID,
/// actual data inside the packet
pub data: PacketData,
pub data: Data,
}
impl Packet {
@ -146,27 +65,30 @@ impl Packet {
let raw_id: [u8; 8] = (raw[..8]).try_into().expect("unreachable");
Ok(Packet {
id: raw_id.into(),
data: PacketData::Raw(raw.len()),
data: Data::Raw(raw.len()),
})
}
/// get the total length of the packet
pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize {
ConnectionID::len() + self.data.len(head_len, tag_len)
pub fn len(&self, head_len: NonceLen, tag_len: TagLen) -> usize {
connection::ID::len() + self.data.len(head_len, tag_len)
}
/// serialize packet into buffer
/// NOTE: assumes that there is exactly asa much buffer as needed
pub fn serialize(
&self,
head_len: HeadLen,
head_len: NonceLen,
tag_len: TagLen,
out: &mut [u8],
) {
assert!(
out.len() > ConnectionID::len(),
out.len() > connection::ID::len(),
"Packet: not enough buffer to serialize"
);
self.id.serialize(&mut out[0..ConnectionID::len()]);
self.data
.serialize(head_len, tag_len, &mut out[ConnectionID::len()..]);
self.id.serialize(&mut out[0..connection::ID::len()]);
self.data.serialize(
head_len,
tag_len,
&mut out[connection::ID::len()..],
);
}
}

View File

@ -0,0 +1,12 @@
//! Errors while parsing streams
/// Crypto errors
#[derive(::thiserror::Error, Debug, Copy, Clone)]
pub enum Error {
/// Error while parsing key material
#[error("Not enough data for stream chunk: {0}")]
NotEnoughData(usize),
/// Sequence outside of the window
#[error("Sequence out of the sliding window")]
OutOfWindow,
}

View File

@ -0,0 +1,347 @@
//! Here we implement the multiplexing stream feature of Fenrir
//!
//! For now we will only have the TCP-like, reliable, in-order delivery
mod errors;
mod rob;
pub use errors::Error;
use crate::{connection::stream::rob::ReliableOrderedBytestream, enc::Random};
/// Kind of stream. any combination of:
/// reliable/unreliable ordered/unordered, bytestream/datagram
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
pub enum Kind {
/// ROB: Reliable, Ordered, Bytestream
/// AKA: TCP-like
ROB = 0,
}
/// Id of the stream
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ID(pub u16);
impl ID {
/// Length of the serialized field
pub const fn len() -> usize {
2
}
/// Minimum possible Stream ID (u16::MIN)
pub const fn min() -> Self {
Self(u16::MIN)
}
/// Maximum possible Stream ID (u16::MAX)
pub const fn max() -> Self {
Self(u16::MAX)
}
}
/// length of the chunk
#[derive(Debug, Copy, Clone)]
pub struct ChunkLen(pub u16);
impl ChunkLen {
/// Length of the serialized field
pub const fn len() -> usize {
2
}
}
//TODO: make pub?
#[derive(Debug, Copy, Clone)]
pub(crate) struct SequenceStart(pub(crate) Sequence);
impl SequenceStart {
pub(crate) fn plus_u32(&self, other: u32) -> Sequence {
self.0.plus_u32(other)
}
pub(crate) fn offset(&self, seq: Sequence) -> usize {
if self.0 .0 <= seq.0 {
(seq.0 - self.0 .0).0 as usize
} else {
(seq.0 + (Sequence::max().0 - self.0 .0)).0 as usize
}
}
}
// SequenceEnd is INCLUSIVE
#[derive(Debug, Copy, Clone)]
pub(crate) struct SequenceEnd(pub(crate) Sequence);
impl SequenceEnd {
pub(crate) fn plus_u32(&self, other: u32) -> Sequence {
self.0.plus_u32(other)
}
}
/// Sequence number to rebuild the stream correctly
#[derive(Debug, Copy, Clone)]
pub struct Sequence(pub ::core::num::Wrapping<u32>);
impl Sequence {
const SEQ_NOFLAG: u32 = 0x3FFFFFFF;
/// return a new sequence number, starting at random
pub fn new(rand: &Random) -> Self {
let mut raw_seq: [u8; 4] = [0; 4];
rand.fill(&mut raw_seq);
let seq = u32::from_le_bytes(raw_seq);
Self(::core::num::Wrapping(seq & Self::SEQ_NOFLAG))
}
/// Length of the serialized field
pub const fn len() -> usize {
4
}
/// Maximum possible sequence
pub const fn max() -> Self {
Self(::core::num::Wrapping(Self::SEQ_NOFLAG))
}
pub(crate) fn is_between(
&self,
start: SequenceStart,
end: SequenceEnd,
) -> bool {
if start.0 .0 < end.0 .0 {
start.0 .0 <= self.0 && self.0 <= end.0 .0
} else {
start.0 .0 <= self.0 || self.0 <= end.0 .0
}
}
pub(crate) fn remaining_window(&self, end: SequenceEnd) -> u32 {
if self.0 <= end.0 .0 {
(end.0 .0 .0 - self.0 .0) + 1
} else {
end.0 .0 .0 + 1 + (Self::max().0 - self.0).0
}
}
pub(crate) fn plus_u32(self, other: u32) -> Self {
Self(::core::num::Wrapping(
(self.0 .0 + other) & Self::SEQ_NOFLAG,
))
}
}
impl ::core::ops::Add for Sequence {
type Output = Self;
fn add(self, other: Self) -> Self {
Self(::core::num::Wrapping(
(self.0 + other.0).0 & Self::SEQ_NOFLAG,
))
}
}
/// Chunk of data representing a stream
/// Every chunk is as follows:
/// | id (2 bytes) | length (2 bytes) |
/// | flag_start (1 BIT) | flag_end (1 BIT) | sequence (30 bits) |
#[derive(Debug, Clone)]
pub struct Chunk<'a> {
/// Id of the stream this chunk is part of
pub id: ID,
/// Is this the beginning of a message?
pub flag_start: bool,
/// Is this the end of a message?
pub flag_end: bool,
/// Sequence number to reconstruct the Stream
pub sequence: Sequence,
data: &'a [u8],
}
impl<'a> Chunk<'a> {
const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F;
const FLAG_START_BITMASK: u8 = 0x80;
const FLAG_END_BITMASK: u8 = 0x40;
/// Return the length of the header of a Chunk
pub const fn headers_len() -> usize {
ID::len() + ChunkLen::len() + Sequence::len()
}
/// Returns the total length of the chunk, including headers
pub fn len(&self) -> usize {
ID::len() + ChunkLen::len() + Sequence::len() + self.data.len()
}
/// deserialize a chunk of a stream
pub fn deserialize(raw: &'a [u8]) -> Result<Self, Error> {
if raw.len() <= ID::len() + ChunkLen::len() + Sequence::len() {
return Err(Error::NotEnoughData(0));
}
let id = ID(u16::from_le_bytes(raw[0..ID::len()].try_into().unwrap()));
let mut bytes_next = ID::len() + ChunkLen::len();
let length = ChunkLen(u16::from_le_bytes(
raw[ID::len()..bytes_next].try_into().unwrap(),
));
if ID::len() + ChunkLen::len() + Sequence::len() + length.0 as usize
> raw.len()
{
return Err(Error::NotEnoughData(4));
}
let flag_start = (raw[bytes_next] & Self::FLAG_START_BITMASK) != 0;
let flag_end = (raw[bytes_next] & Self::FLAG_END_BITMASK) != 0;
let bytes = bytes_next + 1;
bytes_next = bytes + Sequence::len();
let mut sequence_bytes: [u8; Sequence::len()] =
raw[bytes..bytes_next].try_into().unwrap();
sequence_bytes[0] = sequence_bytes[0] & Self::FLAGS_EXCLUDED_BITMASK;
let sequence =
Sequence(::core::num::Wrapping(u32::from_le_bytes(sequence_bytes)));
Ok(Self {
id,
flag_start,
flag_end,
sequence,
data: &raw[bytes_next..(bytes_next + length.0 as usize)],
})
}
/// serialize a chunk of a stream
pub fn serialize(&self, raw_out: &mut [u8]) {
raw_out[0..ID::len()].copy_from_slice(&self.id.0.to_le_bytes());
let mut bytes_next = ID::len() + ChunkLen::len();
raw_out[ID::len()..bytes_next]
.copy_from_slice(&(self.data.len() as u16).to_le_bytes());
let bytes = bytes_next;
bytes_next = bytes_next + Sequence::len();
raw_out[bytes..bytes_next]
.copy_from_slice(&self.sequence.0 .0.to_le_bytes());
let mut flag_byte = raw_out[bytes] & Self::FLAGS_EXCLUDED_BITMASK;
if self.flag_start {
flag_byte = flag_byte | Self::FLAG_START_BITMASK;
}
if self.flag_end {
flag_byte = flag_byte | Self::FLAG_END_BITMASK;
}
raw_out[bytes] = flag_byte;
let bytes = bytes_next;
bytes_next = bytes_next + self.data.len();
raw_out[bytes..bytes_next].copy_from_slice(&self.data);
}
}
/// Kind of stream. any combination of:
/// reliable/unreliable ordered/unordered, bytestream/datagram
/// differences from Kind:
/// * not public
/// * has actual data
#[derive(Debug, Clone)]
pub(crate) enum Tracker {
/// ROB: Reliable, Ordered, Bytestream
/// AKA: TCP-like
ROB(ReliableOrderedBytestream),
}
impl Tracker {
pub(crate) fn new(kind: Kind, rand: &Random) -> Self {
match kind {
Kind::ROB => Tracker::ROB(ReliableOrderedBytestream::new(rand)),
}
}
}
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum StreamData {
/// not enough data to return somthing to the user
NotReady = 0,
/// we can return something to the user
Ready,
}
impl ::core::ops::BitOr for StreamData {
type Output = Self;
// Required method
fn bitor(self, other: Self) -> Self::Output {
if self == StreamData::Ready || other == StreamData::Ready {
StreamData::Ready
} else {
StreamData::NotReady
}
}
}
/// Actual stream-tracking structure
#[derive(Debug, Clone)]
pub(crate) struct Stream {
id: ID,
data: Tracker,
}
impl Stream {
pub(crate) fn new(kind: Kind, rand: &Random) -> Self {
let id: u16 = 0;
rand.fill(&mut id.to_le_bytes());
Self {
id: ID(id),
data: Tracker::new(kind, rand),
}
}
pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<StreamData, Error> {
match &mut self.data {
Tracker::ROB(tracker) => tracker.recv(chunk),
}
}
}
/// Track what has been sent and what has been ACK'd from a stream
#[derive(Debug)]
pub(crate) struct SendTracker {
queue: Vec<Vec<u8>>,
sent: Vec<usize>,
ackd: Vec<usize>,
chunk_started: bool,
is_datagram: bool,
next_sequence: Sequence,
}
impl SendTracker {
pub(crate) fn new(rand: &Random) -> Self {
Self {
queue: Vec::with_capacity(4),
sent: Vec::with_capacity(4),
ackd: Vec::with_capacity(4),
chunk_started: false,
is_datagram: false,
next_sequence: Sequence::new(rand),
}
}
/// Enqueue user data to be sent
pub(crate) fn enqueue(&mut self, data: Vec<u8>) {
self.queue.push(data);
self.sent.push(0);
self.ackd.push(0);
}
/// Write the user data to the buffer and mark it as sent
pub(crate) fn get(&mut self, out: &mut [u8]) -> usize {
let data = match self.queue.get(0) {
Some(data) => data,
None => return 0,
};
let len = ::std::cmp::min(out.len(), data.len());
out[..len].copy_from_slice(&data[self.sent[0]..len]);
self.sent[0] = self.sent[0] + len;
len
}
/// Mark the sent data as successfully received from the receiver
pub(crate) fn ack(&mut self, size: usize) {
todo!()
}
pub(crate) fn serialize(&mut self, id: ID, raw: &mut [u8]) -> usize {
let max_data_len = raw.len() - Chunk::headers_len();
let data_len = ::std::cmp::min(max_data_len, self.queue[0].len());
let flag_start = !self.chunk_started;
let flag_end = self.is_datagram && data_len == self.queue[0].len();
let chunk = Chunk {
id,
flag_start,
flag_end,
sequence: self.next_sequence,
data: &self.queue[0][..data_len],
};
self.next_sequence = Sequence(
self.next_sequence.0 + ::core::num::Wrapping(data_len as u32),
);
if chunk.flag_end {
self.chunk_started = false;
}
chunk.serialize(raw);
data_len
}
}

View File

@ -0,0 +1,210 @@
//! Implementation of the Reliable, Ordered, Bytestream transmission model
//! AKA: TCP-like
use crate::{
connection::stream::{
Chunk, Error, Sequence, SequenceEnd, SequenceStart, StreamData,
},
enc::Random,
};
#[cfg(test)]
mod tests;
/// Reliable, Ordered, Bytestream stream tracker
/// AKA: TCP-like
#[derive(Debug, Clone)]
pub(crate) struct ReliableOrderedBytestream {
pub(crate) window_start: SequenceStart,
window_end: SequenceEnd,
pivot: u32,
data: Vec<u8>,
missing: Vec<(Sequence, Sequence)>,
}
impl ReliableOrderedBytestream {
pub(crate) fn new(rand: &Random) -> Self {
let window_len = 1048576; // 1MB. should be enough for anybody. (lol)
let window_start = SequenceStart(Sequence::new(rand));
let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1));
let mut data = Vec::with_capacity(window_len as usize);
data.resize(data.capacity(), 0);
Self {
window_start,
window_end,
pivot: window_len,
data,
missing: [(window_start.0, window_end.0)].to_vec(),
}
}
pub(crate) fn with_window_size(rand: &Random, size: u32) -> Self {
assert!(
size < Sequence::max().0 .0,
"Max window size is {}",
Sequence::max().0 .0
);
let window_len = size; // 1MB. should be enough for anybody. (lol)
let window_start = SequenceStart(Sequence::new(rand));
let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1));
let mut data = Vec::with_capacity(window_len as usize);
data.resize(data.capacity(), 0);
Self {
window_start,
window_end,
pivot: window_len,
data,
missing: [(window_start.0, window_end.0)].to_vec(),
}
}
pub(crate) fn window_size(&self) -> u32 {
self.data.len() as u32
}
pub(crate) fn get(&mut self) -> Vec<u8> {
if self.missing.len() == 0 {
let (first, second) = self.data.split_at(self.pivot as usize);
let mut ret = Vec::with_capacity(self.data.len());
ret.extend_from_slice(first);
ret.extend_from_slice(second);
self.window_start =
SequenceStart(self.window_start.plus_u32(ret.len() as u32));
self.window_end =
SequenceEnd(self.window_end.plus_u32(ret.len() as u32));
self.data.clear();
return ret;
}
let data_len = self.window_start.offset(self.missing[0].0);
let last_missing_idx = self.missing.len() - 1;
let mut last_missing = &mut self.missing[last_missing_idx];
last_missing.1 = last_missing.1.plus_u32(data_len as u32);
self.window_start =
SequenceStart(self.window_start.plus_u32(data_len as u32));
self.window_end =
SequenceEnd(self.window_end.plus_u32(data_len as u32));
let mut ret = Vec::with_capacity(data_len);
let (first, second) = self.data[..].split_at(self.pivot as usize);
let first_len = ::core::cmp::min(data_len, first.len());
let second_len = data_len - first_len;
ret.extend_from_slice(&first[..first_len]);
ret.extend_from_slice(&second[..second_len]);
self.pivot =
((self.pivot as usize + data_len) % self.data.len()) as u32;
ret
}
pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<StreamData, Error> {
if !chunk
.sequence
.is_between(self.window_start, self.window_end)
{
return Err(Error::OutOfWindow);
}
// make sure we consider only the bytes inside the sliding window
let maxlen = ::std::cmp::min(
chunk.sequence.remaining_window(self.window_end) as usize,
chunk.data.len(),
);
if maxlen == 0 {
// empty window or empty chunk, but we don't care
return Err(Error::OutOfWindow);
}
// translate Sequences to offsets in self.data
let data = &chunk.data[..maxlen];
let offset = self.window_start.offset(chunk.sequence);
let offset_end = offset + chunk.data.len() - 1;
// Find the chunks we are missing that we can copy,
// and fix the missing tracker
let mut copy_ranges = Vec::new();
let mut to_delete = Vec::new();
let mut to_add = Vec::new();
// note: the ranges are (INCLUSIVE, INCLUSIVE)
for (idx, el) in self.missing.iter_mut().enumerate() {
let missing_from = self.window_start.offset(el.0);
if missing_from > offset_end {
break;
}
let missing_to = self.window_start.offset(el.1);
if missing_to < offset {
continue;
}
if missing_from >= offset && missing_from <= offset_end {
if missing_to <= offset_end {
// [.....chunk.....]
// [..missing..]
to_delete.push(idx);
copy_ranges.push((missing_from, missing_to));
} else {
// [....chunk....]
// [...missing...]
copy_ranges.push((missing_from, offset_end));
el.0 =
el.0.plus_u32(((offset_end - missing_from) + 1) as u32);
}
} else if missing_from < offset {
if missing_to > offset_end {
// [..chunk..]
// [....missing....]
to_add.push((
el.0.plus_u32(((offset_end - missing_from) + 1) as u32),
el.1,
));
el.1 = el.0.plus_u32(((offset - missing_from) - 1) as u32);
copy_ranges.push((offset, offset_end));
} else if offset <= missing_to {
// [....chunk....]
// [...missing...]
copy_ranges.push((offset, (missing_to - 0)));
el.1 =
el.0.plus_u32(((offset_end - missing_from) - 1) as u32);
}
}
}
{
let mut deleted = 0;
for idx in to_delete.into_iter() {
self.missing.remove(idx + deleted);
deleted = deleted + 1;
}
}
self.missing.append(&mut to_add);
self.missing
.sort_by(|(from_a, _), (from_b, _)| from_a.0 .0.cmp(&from_b.0 .0));
// copy only the missing data
let (first, second) = self.data[..].split_at_mut(self.pivot as usize);
for (from, to) in copy_ranges.into_iter() {
let to = to + 1;
if from <= first.len() {
let first_from = from;
let first_to = ::core::cmp::min(first.len(), to);
let data_first_from = from - offset;
let data_first_to = first_to - offset;
first[first_from..first_to]
.copy_from_slice(&data[data_first_from..data_first_to]);
let second_to = to - first_to;
let data_second_to = data_first_to + second_to;
second[..second_to]
.copy_from_slice(&data[data_first_to..data_second_to]);
} else {
let second_from = from - first.len();
let second_to = to - first.len();
let data_from = from - offset;
let data_to = to - offset;
second[second_from..second_to]
.copy_from_slice(&data[data_from..data_to]);
}
}
if self.missing.len() == 0
|| self.window_start.offset(self.missing[0].0) == 0
{
Ok(StreamData::Ready)
} else {
Ok(StreamData::NotReady)
}
}
}

View File

@ -0,0 +1,249 @@
use crate::{
connection::stream::{self, rob::*, Chunk},
enc::Random,
};
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_sequential() {
let rand = Random::new();
let mut rob = ReliableOrderedBytestream::with_window_size(&rand, 1048576);
let mut data = Vec::with_capacity(1024);
data.resize(data.capacity(), 0);
rand.fill(&mut data[..]);
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..512],
};
let got = rob.get();
assert!(&got[..] == &[], "rob: got data?");
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..512] == &got[..],
"ROB1: DIFF: {:?} {:?}",
&data[..512].len(),
&got[..].len()
);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: true,
sequence: start.plus_u32(512),
data: &data[512..],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[512..] == &got[..],
"ROB2: DIFF: {:?} {:?}",
&data[512..].len(),
&got[..].len()
);
}
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_retransmit() {
let rand = Random::new();
let max_window: usize = 100;
let mut rob =
ReliableOrderedBytestream::with_window_size(&rand, max_window as u32);
let mut data = Vec::with_capacity(120);
data.resize(data.capacity(), 0);
for i in 0..data.len() {
data[i] = i as u8;
}
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..40],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..60],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(40),
data: &data[40..60],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(80),
data: &data[80..],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..90],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: false,
sequence: start.plus_u32(max_window as u32),
data: &data[max_window..],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: false,
flag_end: true,
sequence: start.plus_u32(90),
data: &data[90..max_window],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..max_window] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[..max_window],
&got[..],
);
}
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_rolling() {
let rand = Random::new();
let max_window: usize = 100;
let mut rob =
ReliableOrderedBytestream::with_window_size(&rand, max_window as u32);
let mut data = Vec::with_capacity(120);
data.resize(data.capacity(), 0);
for i in 0..data.len() {
data[i] = i as u8;
}
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..40],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..100],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..40] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[..40],
&got[..],
);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(40),
data: &data[40..],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[40..] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[40..],
&got[..],
);
}
#[::tracing_test::traced_test]
#[test]
fn test_stream_rob_rolling_second_case() {
let rand = Random::new();
let max_window: usize = 100;
let mut rob =
ReliableOrderedBytestream::with_window_size(&rand, max_window as u32);
let mut data = Vec::with_capacity(120);
data.resize(data.capacity(), 0);
for i in 0..data.len() {
data[i] = i as u8;
}
let start = rob.window_start.0;
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start,
data: &data[..40],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(50),
data: &data[50..100],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[..40] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[..40],
&got[..],
);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(40),
data: &data[40..100],
};
let _ = rob.recv(chunk);
let chunk = Chunk {
id: stream::ID(42),
flag_start: true,
flag_end: false,
sequence: start.plus_u32(100),
data: &data[100..],
};
let _ = rob.recv(chunk);
let got = rob.get();
assert!(
&data[40..] == &got[..],
"DIFF:\n {:?}\n {:?}",
&data[40..],
&got[..],
);
}

View File

@ -43,12 +43,11 @@
//! ]
use crate::{
connection::handshake::HandshakeID,
connection::handshake,
enc::{
self,
asym::{KeyExchangeKind, KeyID, PubKey},
hkdf::HkdfKind,
sym::CipherKind,
hkdf, sym,
},
};
use ::core::num::NonZeroU16;
@ -180,7 +179,7 @@ pub struct Address {
/// Weight of this address in the priority group
pub weight: AddressWeight,
/// List of supported handshakes
pub handshake_ids: Vec<HandshakeID>,
pub handshake_ids: Vec<handshake::ID>,
/// Public key IDs used by this address
pub public_key_idx: Vec<PubKeyIdx>,
}
@ -331,7 +330,7 @@ impl Address {
for raw_handshake_id in
raw[bytes_parsed..(bytes_parsed + num_handshake_ids)].iter()
{
match HandshakeID::from_u8(*raw_handshake_id) {
match handshake::ID::from_u8(*raw_handshake_id) {
Some(h_id) => handshake_ids.push(h_id),
None => {
::tracing::warn!(
@ -392,9 +391,9 @@ pub struct Record {
/// List of supported key exchanges
pub key_exchanges: Vec<KeyExchangeKind>,
/// List of supported key exchanges
pub hkdfs: Vec<HkdfKind>,
pub hkdfs: Vec<hkdf::Kind>,
/// List of supported ciphers
pub ciphers: Vec<CipherKind>,
pub ciphers: Vec<sym::Kind>,
}
impl Record {
@ -597,7 +596,7 @@ impl Record {
num_key_exchanges = num_key_exchanges - 1;
}
while num_hkdfs > 0 {
let hkdf = match HkdfKind::from_u8(raw[bytes_parsed]) {
let hkdf = match hkdf::Kind::from_u8(raw[bytes_parsed]) {
Some(hkdf) => hkdf,
None => {
// continue parsing. This could be a new hkdf type
@ -615,7 +614,7 @@ impl Record {
num_hkdfs = num_hkdfs - 1;
}
while num_ciphers > 0 {
let cipher = match CipherKind::from_u8(raw[bytes_parsed]) {
let cipher = match sym::Kind::from_u8(raw[bytes_parsed]) {
Some(cipher) => cipher,
None => {
// continue parsing. This could be a new cipher type

View File

@ -12,7 +12,7 @@ fn test_dnssec_serialization() {
return;
}
};
use crate::{connection::handshake::HandshakeID, enc};
use crate::{connection::handshake, enc};
let record = Record {
public_keys: [(
@ -25,14 +25,14 @@ fn test_dnssec_serialization() {
port: Some(::core::num::NonZeroU16::new(31337).unwrap()),
priority: record::AddressPriority::P1,
weight: record::AddressWeight::W1,
handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(),
handshake_ids: [handshake::ID::DirectorySynchronized].to_vec(),
public_key_idx: [record::PubKeyIdx(0)].to_vec(),
}]
.to_vec(),
key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman]
.to_vec(),
hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(),
ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
hkdfs: [enc::hkdf::Kind::Sha3].to_vec(),
ciphers: [enc::sym::Kind::XChaCha20Poly1305].to_vec(),
};
let encoded = match record.encode() {
Ok(encoded) => encoded,

View File

@ -45,7 +45,7 @@ impl ::std::fmt::Display for KeyID {
/// Capabilities of each key
#[derive(Debug, Clone, Copy)]
pub enum KeyCapabilities {
pub enum Capabilities {
/// signing *only*
Sign,
/// encrypt *only*
@ -61,13 +61,13 @@ pub enum KeyCapabilities {
/// All: sign, encrypt, Key Exchange
SignEncryptExchage,
}
impl KeyCapabilities {
impl Capabilities {
/// Check if this key supports eky exchage
pub fn has_exchange(&self) -> bool {
match self {
KeyCapabilities::Exchange
| KeyCapabilities::SignExchange
| KeyCapabilities::SignEncryptExchage => true,
Capabilities::Exchange
| Capabilities::SignExchange
| Capabilities::SignEncryptExchage => true,
_ => false,
}
}
@ -85,7 +85,7 @@ impl KeyCapabilities {
)]
#[non_exhaustive]
#[repr(u8)]
pub enum KeyKind {
pub enum Kind {
/// Ed25519 Public key (sign only)
#[strum(serialize = "ed25519")]
Ed25519 = 0,
@ -93,25 +93,25 @@ pub enum KeyKind {
#[strum(serialize = "x25519")]
X25519,
}
impl KeyKind {
impl Kind {
/// Length of the serialized field
pub const fn len() -> usize {
1
}
/// return the expected length of the public key
pub fn pub_len(&self) -> usize {
KeyKind::len()
Kind::len()
+ match self {
// FIXME: 99% wrong size
KeyKind::Ed25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN,
KeyKind::X25519 => 32,
Kind::Ed25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN,
Kind::X25519 => 32,
}
}
/// Get the capabilities of this key type
pub fn capabilities(&self) -> KeyCapabilities {
pub fn capabilities(&self) -> Capabilities {
match self {
KeyKind::Ed25519 => KeyCapabilities::Sign,
KeyKind::X25519 => KeyCapabilities::Exchange,
Kind::Ed25519 => Capabilities::Sign,
Kind::X25519 => Capabilities::Exchange,
}
}
/// Returns the key exchanges supported by this key
@ -120,8 +120,8 @@ impl KeyKind {
const X25519_KEY_EXCHANGES: [KeyExchangeKind; 1] =
[KeyExchangeKind::X25519DiffieHellman];
match self {
KeyKind::Ed25519 => &EMPTY,
KeyKind::X25519 => &X25519_KEY_EXCHANGES,
Kind::Ed25519 => &EMPTY,
Kind::X25519 => &X25519_KEY_EXCHANGES,
}
}
/// generate new keypair
@ -193,21 +193,21 @@ impl PubKey {
}
}
/// return the kind of public key
pub fn kind(&self) -> KeyKind {
pub fn kind(&self) -> Kind {
match self {
// FIXME: lie, we don't fully support this
PubKey::Signing => KeyKind::Ed25519,
PubKey::Signing => Kind::Ed25519,
PubKey::Exchange(ex) => ex.kind(),
}
}
/// generate new keypair
fn new_keypair(
kind: KeyKind,
kind: Kind,
rnd: &Random,
) -> Result<(PrivKey, PubKey), Error> {
match kind {
KeyKind::Ed25519 => todo!(),
KeyKind::X25519 => {
Kind::Ed25519 => todo!(),
Kind::X25519 => {
let (priv_key, pub_key) =
KeyExchangeKind::X25519DiffieHellman.new_keypair(rnd)?;
Ok((PrivKey::Exchange(priv_key), PubKey::Exchange(pub_key)))
@ -231,7 +231,7 @@ impl PubKey {
if raw.len() < 1 {
return Err(Error::NotEnoughData(0));
}
let kind: KeyKind = match KeyKind::from_u8(raw[0]) {
let kind: Kind = match Kind::from_u8(raw[0]) {
Some(kind) => kind,
None => return Err(Error::UnsupportedKey(1)),
};
@ -239,11 +239,11 @@ impl PubKey {
return Err(Error::NotEnoughData(1));
}
match kind {
KeyKind::Ed25519 => {
Kind::Ed25519 => {
::tracing::error!("ed25519 keys are not yet supported");
return Err(Error::Parsing);
}
KeyKind::X25519 => {
Kind::X25519 => {
let pub_key: ::x25519_dalek::PublicKey =
//match ::bincode::deserialize(&raw[1..(1 + kind.pub_len())])
match ::bincode::deserialize(&raw[1..])
@ -284,7 +284,7 @@ impl PrivKey {
}
}
/// return the kind of public key
pub fn kind(&self) -> KeyKind {
pub fn kind(&self) -> Kind {
match self {
PrivKey::Signing => todo!(),
PrivKey::Exchange(ex) => ex.kind(),
@ -322,13 +322,13 @@ impl ExchangePrivKey {
/// Get the serialized key length
pub fn len(&self) -> usize {
match self {
ExchangePrivKey::X25519(_) => KeyKind::X25519.pub_len(),
ExchangePrivKey::X25519(_) => Kind::X25519.pub_len(),
}
}
/// Get the kind of key
pub fn kind(&self) -> KeyKind {
pub fn kind(&self) -> Kind {
match self {
ExchangePrivKey::X25519(_) => KeyKind::X25519,
ExchangePrivKey::X25519(_) => Kind::X25519,
}
}
/// Run the key exchange between two keys of the same kind
@ -372,13 +372,13 @@ impl ExchangePubKey {
/// Get the serialized key length
pub fn len(&self) -> usize {
match self {
ExchangePubKey::X25519(_) => KeyKind::X25519.pub_len(),
ExchangePubKey::X25519(_) => Kind::X25519.pub_len(),
}
}
/// Get the kind of key
pub fn kind(&self) -> KeyKind {
pub fn kind(&self) -> Kind {
match self {
ExchangePubKey::X25519(_) => KeyKind::X25519,
ExchangePubKey::X25519(_) => Kind::X25519,
}
}
/// serialize the key into the buffer
@ -396,13 +396,13 @@ impl ExchangePubKey {
/// The riesult is "unparsed" since we don't verify
/// the actual key
pub fn deserialize(raw: &[u8]) -> Result<(Self, usize), Error> {
match KeyKind::from_u8(raw[0]) {
match Kind::from_u8(raw[0]) {
Some(kind) => match kind {
KeyKind::Ed25519 => {
Kind::Ed25519 => {
::tracing::error!("ed25519 keys are not yet supported");
return Err(Error::Parsing);
}
KeyKind::X25519 => {
Kind::X25519 => {
let pub_key: ::x25519_dalek::PublicKey =
match ::bincode::deserialize(
&raw[1..(1 + kind.pub_len())],

View File

@ -18,12 +18,12 @@ use crate::{config::Config, enc::Secret};
)]
#[non_exhaustive]
#[repr(u8)]
pub enum HkdfKind {
pub enum Kind {
/// Sha3
#[strum(serialize = "sha3")]
Sha3 = 0,
}
impl HkdfKind {
impl Kind {
/// Length of the serialized type
pub const fn len() -> usize {
1
@ -34,7 +34,7 @@ impl HkdfKind {
#[derive(Clone)]
pub enum Hkdf {
/// Sha3 based
Sha3(HkdfSha3),
Sha3(Sha3),
}
// Fake debug implementation to avoid leaking secrets
@ -49,9 +49,9 @@ impl ::core::fmt::Debug for Hkdf {
impl Hkdf {
/// New Hkdf
pub fn new(kind: HkdfKind, salt: &[u8], key: Secret) -> Self {
pub fn new(kind: Kind, salt: &[u8], key: Secret) -> Self {
match kind {
HkdfKind::Sha3 => Self::Sha3(HkdfSha3::new(salt, key)),
Kind::Sha3 => Self::Sha3(Sha3::new(salt, key)),
}
}
/// Get a secret generated from the key and a given context
@ -61,9 +61,9 @@ impl Hkdf {
}
}
/// get the kind of this Hkdf
pub fn kind(&self) -> HkdfKind {
pub fn kind(&self) -> Kind {
match self {
Hkdf::Sha3(_) => HkdfKind::Sha3,
Hkdf::Sha3(_) => Kind::Sha3,
}
}
}
@ -106,11 +106,11 @@ impl Clone for HkdfInner {
/// Sha3 based HKDF
#[derive(Clone)]
pub struct HkdfSha3 {
pub struct Sha3 {
inner: HkdfInner,
}
impl HkdfSha3 {
impl Sha3 {
/// Instantiate a new HKDF with Sha3-256
pub(crate) fn new(salt: &[u8], key: Secret) -> Self {
let hkdf = ::hkdf::Hkdf::<Sha3_256>::new(Some(salt), key.as_ref());
@ -132,7 +132,7 @@ impl HkdfSha3 {
}
// Fake debug implementation to avoid leaking secrets
impl ::core::fmt::Debug for HkdfSha3 {
impl ::core::fmt::Debug for Sha3 {
fn fmt(
&self,
f: &mut core::fmt::Formatter<'_>,
@ -146,8 +146,8 @@ impl ::core::fmt::Debug for HkdfSha3 {
/// Give priority to our list
pub fn server_select_hkdf(
cfg: &Config,
client_supported: &Vec<HkdfKind>,
) -> Option<HkdfKind> {
client_supported: &Vec<Kind>,
) -> Option<Kind> {
cfg.hkdfs
.iter()
.find(|h| client_supported.contains(h))
@ -159,8 +159,8 @@ pub fn server_select_hkdf(
/// this is used only in the directory synchronized handshake
pub fn client_select_hkdf(
cfg: &Config,
server_supported: &Vec<HkdfKind>,
) -> Option<HkdfKind> {
server_supported: &Vec<Kind>,
) -> Option<Kind> {
server_supported
.iter()
.find(|h| cfg.hkdfs.contains(h))

View File

@ -17,20 +17,20 @@ use crate::{
::strum_macros::IntoStaticStr,
)]
#[repr(u8)]
pub enum CipherKind {
pub enum Kind {
/// XChaCha20_Poly1305
#[strum(serialize = "xchacha20poly1305")]
XChaCha20Poly1305 = 0,
}
impl CipherKind {
impl Kind {
/// length of the serialized id for the cipher kind field
pub const fn len() -> usize {
1
}
/// required length of the nonce
pub fn nonce_len(&self) -> HeadLen {
HeadLen(Nonce::len())
pub fn nonce_len(&self) -> NonceLen {
Nonce::len()
}
/// required length of the key
pub fn key_len(&self) -> usize {
@ -48,21 +48,10 @@ impl CipherKind {
#[derive(Debug)]
pub struct AAD<'a>(pub &'a [u8]);
/// Cipher direction, to make sure we don't reuse the same cipher
/// for both decrypting and encrypting
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
pub enum CipherDirection {
/// Receive, to decrypt only
Recv = 0,
/// Send, to encrypt only
Send,
}
/// strong typedef for header length
/// aka: nonce length in the encrypted data)
#[derive(Debug, Copy, Clone)]
pub struct HeadLen(pub usize);
pub struct NonceLen(pub usize);
/// strong typedef for the Tag length
/// aka: cryptographic authentication tag length at the end
/// of the encrypted data
@ -77,21 +66,21 @@ enum Cipher {
impl Cipher {
/// Build a new Cipher
fn new(kind: CipherKind, secret: Secret) -> Self {
fn new(kind: Kind, secret: Secret) -> Self {
match kind {
CipherKind::XChaCha20Poly1305 => {
Kind::XChaCha20Poly1305 => {
Self::XChaCha20Poly1305(XChaCha20Poly1305::new(secret))
}
}
}
pub fn kind(&self) -> CipherKind {
pub fn kind(&self) -> Kind {
match self {
Cipher::XChaCha20Poly1305(_) => CipherKind::XChaCha20Poly1305,
Cipher::XChaCha20Poly1305(_) => Kind::XChaCha20Poly1305,
}
}
fn nonce_len(&self) -> HeadLen {
fn nonce_len(&self) -> NonceLen {
match self {
Cipher::XChaCha20Poly1305(_) => HeadLen(Nonce::len()),
Cipher::XChaCha20Poly1305(_) => Nonce::len(),
}
}
fn tag_len(&self) -> TagLen {
@ -117,7 +106,7 @@ impl Cipher {
return Err(Error::NotEnoughData(raw_data.len()));
}
let (nonce_bytes, data_and_tag) =
raw_data.split_at_mut(Nonce::len());
raw_data.split_at_mut(Nonce::len().0);
let (data_notag, tag_bytes) = data_and_tag.split_at_mut(
data_and_tag.len()
- ::ring::aead::CHACHA20_POLY1305.tag_len(),
@ -137,20 +126,20 @@ impl Cipher {
};
//data.drain(..Nonce::len());
//data.truncate(final_len);
Ok(&raw_data[Nonce::len()..Nonce::len() + final_len])
Ok(&raw_data[Nonce::len().0..Nonce::len().0 + final_len])
}
}
}
fn overhead(&self) -> usize {
match self {
Cipher::XChaCha20Poly1305(_) => {
let cipher = CipherKind::XChaCha20Poly1305;
let cipher = Kind::XChaCha20Poly1305;
cipher.nonce_len().0 + cipher.tag_len().0
}
}
}
fn encrypt(
&self,
&mut self,
nonce: &Nonce,
aad: AAD,
data: &mut [u8],
@ -162,13 +151,13 @@ impl Cipher {
let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len();
let data_len_notag = data.len() - tag_len;
// write nonce
data[..Nonce::len()].copy_from_slice(nonce.as_bytes());
data[..Nonce::len().0].copy_from_slice(nonce.as_bytes());
// encrypt data
match cipher.cipher.encrypt_in_place_detached(
nonce.as_bytes().into(),
aad.0,
&mut data[Nonce::len()..data_len_notag],
&mut data[Nonce::len().0..data_len_notag],
) {
Ok(tag) => {
data[data_len_notag..].copy_from_slice(tag.as_slice());
@ -194,11 +183,11 @@ impl ::core::fmt::Debug for CipherRecv {
impl CipherRecv {
/// Build a new Cipher
pub fn new(kind: CipherKind, secret: Secret) -> Self {
pub fn new(kind: Kind, secret: Secret) -> Self {
Self(Cipher::new(kind, secret))
}
/// Get the length of the nonce for this cipher
pub fn nonce_len(&self) -> HeadLen {
pub fn nonce_len(&self) -> NonceLen {
self.0.nonce_len()
}
/// Get the length of the nonce for this cipher
@ -215,14 +204,14 @@ impl CipherRecv {
self.0.decrypt(aad, data)
}
/// return the underlying cipher id
pub fn kind(&self) -> CipherKind {
pub fn kind(&self) -> Kind {
self.0.kind()
}
}
/// Send only cipher
pub struct CipherSend {
nonce: NonceSync,
nonce: Nonce,
cipher: Cipher,
}
impl ::core::fmt::Debug for CipherSend {
@ -236,22 +225,30 @@ impl ::core::fmt::Debug for CipherSend {
impl CipherSend {
/// Build a new Cipher
pub fn new(kind: CipherKind, secret: Secret, rand: &Random) -> Self {
pub fn new(kind: Kind, secret: Secret, rand: &Random) -> Self {
Self {
nonce: NonceSync::new(rand),
nonce: Nonce::new(rand),
cipher: Cipher::new(kind, secret),
}
}
/// Encrypt the given data
pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> {
pub fn encrypt(&mut self, aad: AAD, data: &mut [u8]) -> Result<(), Error> {
let old_nonce = self.nonce.advance();
self.cipher.encrypt(&old_nonce, aad, data)?;
Ok(())
}
/// return the underlying cipher id
pub fn kind(&self) -> CipherKind {
pub fn kind(&self) -> Kind {
self.cipher.kind()
}
/// Get the length of the nonce for this cipher
pub fn nonce_len(&self) -> NonceLen {
self.cipher.nonce_len()
}
/// Get the length of the nonce for this cipher
pub fn tag_len(&self) -> TagLen {
self.cipher.tag_len()
}
}
/// XChaCha20Poly1305 cipher
@ -285,7 +282,7 @@ struct NonceNum {
#[repr(C)]
pub union Nonce {
num: NonceNum,
raw: [u8; Self::len()],
raw: [u8; Self::len().0],
}
impl ::core::fmt::Debug for Nonce {
@ -303,17 +300,17 @@ impl ::core::fmt::Debug for Nonce {
impl Nonce {
/// Generate a new random Nonce
pub fn new(rand: &Random) -> Self {
let mut raw = [0; Self::len()];
let mut raw = [0; Self::len().0];
rand.fill(&mut raw);
Self { raw }
}
/// Length of this nonce in bytes
pub const fn len() -> usize {
pub const fn len() -> NonceLen {
// FIXME: was:12. xchacha20poly1305 requires 24.
// but we should change keys much earlier than that, and our
// nonces are not random, but sequential.
// we should change keys every 2^30 bytes to be sure (stream max window)
return 24;
return NonceLen(24);
}
/// Get reference to the nonce bytes
pub fn as_bytes(&self) -> &[u8] {
@ -323,11 +320,12 @@ impl Nonce {
}
}
/// Create Nonce from array
pub fn from_slice(raw: [u8; Self::len()]) -> Self {
pub fn from_slice(raw: [u8; Self::len().0]) -> Self {
Self { raw }
}
/// Go to the next nonce
pub fn advance(&mut self) {
pub fn advance(&mut self) -> Self {
let old_nonce = self.clone();
#[allow(unsafe_code)]
unsafe {
let old_low = self.num.low;
@ -336,40 +334,17 @@ impl Nonce {
self.num.high = self.num.high;
}
}
}
}
/// Synchronize the mutex acess with a nonce for multithread safety
// TODO: remove mutex, not needed anymore
#[derive(Debug)]
pub struct NonceSync {
nonce: ::std::sync::Mutex<Nonce>,
}
impl NonceSync {
/// Create a new thread safe nonce
pub fn new(rand: &Random) -> Self {
Self {
nonce: ::std::sync::Mutex::new(Nonce::new(rand)),
}
}
/// Advance the nonce and return the *old* value
pub fn advance(&self) -> Nonce {
let old_nonce: Nonce;
{
let mut nonce = self.nonce.lock().unwrap();
old_nonce = *nonce;
nonce.advance();
}
old_nonce
}
}
/// Select the best cipher from our supported list
/// and the other endpoint supported list.
/// Give priority to our list
pub fn server_select_cipher(
cfg: &Config,
client_supported: &Vec<CipherKind>,
) -> Option<CipherKind> {
client_supported: &Vec<Kind>,
) -> Option<Kind> {
cfg.ciphers
.iter()
.find(|c| client_supported.contains(c))
@ -381,8 +356,8 @@ pub fn server_select_cipher(
/// This is used only in the Directory synchronized handshake
pub fn client_select_cipher(
cfg: &Config,
server_supported: &Vec<CipherKind>,
) -> Option<CipherKind> {
server_supported: &Vec<Kind>,
) -> Option<Kind> {
server_supported
.iter()
.find(|c| cfg.ciphers.contains(c))

View File

@ -1,24 +1,26 @@
use crate::{
auth,
connection::{handshake::*, ID},
connection::{
handshake::{self, *},
ID,
},
enc::{self, asym::KeyID},
};
#[test]
fn test_simple_encrypt_decrypt() {
let rand = enc::Random::new();
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
let cipher = enc::sym::Kind::XChaCha20Poly1305;
let secret = enc::Secret::new_rand(&rand);
let secret2 = secret.clone();
let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
let mut cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2);
let mut data = Vec::new();
let tot_len = cipher_recv.nonce_len().0 + 1234 + cipher_recv.tag_len().0;
data.resize(tot_len, 0);
rand.fill(&mut data);
data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]);
data[..enc::sym::Nonce::len().0].copy_from_slice(&[0; 24]);
let last = data.len() - cipher_recv.tag_len().0;
data[last..].copy_from_slice(&[0; 16]);
let orig = data.clone();
@ -31,7 +33,7 @@ fn test_simple_encrypt_decrypt() {
if cipher_recv.decrypt(aad2, &mut data).is_err() {
assert!(false, "Decrypt failed");
}
data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]);
data[..enc::sym::Nonce::len().0].copy_from_slice(&[0; 24]);
let last = data.len() - cipher_recv.tag_len().0;
data[last..].copy_from_slice(&[0; 16]);
assert!(orig == data, "DIFFERENT!\n{:?}\n{:?}\n", orig, data);
@ -40,18 +42,18 @@ fn test_simple_encrypt_decrypt() {
#[test]
fn test_encrypt_decrypt() {
let rand = enc::Random::new();
let cipher = enc::sym::CipherKind::XChaCha20Poly1305;
let cipher = enc::sym::Kind::XChaCha20Poly1305;
let secret = enc::Secret::new_rand(&rand);
let secret2 = secret.clone();
let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
let mut cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand);
let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2);
let nonce_len = cipher_recv.nonce_len();
let tag_len = cipher_recv.tag_len();
let service_key = enc::Secret::new_rand(&rand);
let data = dirsync::RespInner::ClearText(dirsync::RespData {
let data = dirsync::resp::State::ClearText(dirsync::resp::Data {
client_nonce: dirsync::Nonce::new(&rand),
id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()),
service_connection_id: ID::ID(
@ -60,7 +62,7 @@ fn test_encrypt_decrypt() {
service_key,
});
let resp = dirsync::Resp {
let resp = dirsync::resp::Resp {
client_key_id: KeyID(4444),
data,
};
@ -68,7 +70,7 @@ fn test_encrypt_decrypt() {
let encrypt_to = encrypt_from + resp.encrypted_length(nonce_len, tag_len);
let h_resp =
Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Resp(resp)));
Handshake::new(handshake::Data::DirSync(dirsync::DirSync::Resp(resp)));
let mut bytes = Vec::<u8>::with_capacity(
h_resp.len(cipher.nonce_len(), cipher.tag_len()),
@ -117,7 +119,7 @@ fn test_encrypt_decrypt() {
}
};
// reparse
if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) =
if let handshake::Data::DirSync(dirsync::DirSync::Resp(r_a)) =
&mut deserialized.data
{
let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0;

View File

@ -4,6 +4,10 @@
pub(crate) mod worker;
use crate::inner::worker::Work;
use ::std::{collections::BTreeMap, vec::Vec};
use ::tokio::time::Instant;
/// Track the total number of threads and our index
/// 65K cpus should be enough for anybody
#[derive(Debug, Clone, Copy)]
@ -12,3 +16,109 @@ pub(crate) struct ThreadTracker {
/// Note: starts from 1
pub id: u16,
}
pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration =
if cfg!(linux) || cfg!(macos) {
::std::time::Duration::from_millis(1)
} else {
// windows
::std::time::Duration::from_millis(16)
};
pub(crate) async fn set_minimum_sleep_resolution() {
let nanosleep = ::std::time::Duration::from_nanos(1);
let mut tests: usize = 3;
while tests > 0 {
let pre_sleep = ::std::time::Instant::now();
::tokio::time::sleep(nanosleep).await;
let post_sleep = ::std::time::Instant::now();
let slept_for = post_sleep - pre_sleep;
#[allow(unsafe_code)]
unsafe {
if slept_for < SLEEP_RESOLUTION {
SLEEP_RESOLUTION = slept_for;
}
}
tests = tests - 1;
}
}
/// Sleeping has a higher resolution that we would like for packet pacing.
/// So we sleep for however log we need, then chunk up all the work here
/// we will end up chunking the work in SLEEP_RESOLUTION, then we will busy wait
/// for more precise timing
pub(crate) struct Timers {
times: BTreeMap<Instant, Work>,
}
impl Timers {
pub(crate) fn new() -> Self {
Self {
times: BTreeMap::new(),
}
}
pub(crate) fn get_next(&self) -> ::tokio::time::Sleep {
match self.times.keys().next() {
Some(entry) => ::tokio::time::sleep_until((*entry).into()),
None => {
::tokio::time::sleep(::std::time::Duration::from_secs(3600))
}
}
}
pub(crate) fn add(
&mut self,
duration: ::tokio::time::Duration,
work: Work,
) -> ::tokio::time::Instant {
// the returned time is the key in the map.
// Make sure it is unique.
//
// We can be pretty sure we won't do a lot of stuff
// in a single nanosecond, so if we hit a time that is already present
// just add a nanosecond and retry
let mut time = ::tokio::time::Instant::now() + duration;
let mut work = work;
loop {
if let Some(old_val) = self.times.insert(time, work) {
work = self.times.insert(time, old_val).unwrap();
time = time + ::std::time::Duration::from_nanos(1);
} else {
break;
}
}
time
}
pub(crate) fn remove(&mut self, time: ::tokio::time::Instant) {
let _ = self.times.remove(&time);
}
/// Get all the work from now up until now + SLEEP_RESOLUTION
pub(crate) fn get_work(&mut self) -> Vec<Work> {
let now: ::tokio::time::Instant = ::std::time::Instant::now().into();
let mut ret = Vec::with_capacity(4);
let mut count_rm = 0;
#[allow(unsafe_code)]
let next_instant = unsafe { now + SLEEP_RESOLUTION };
let mut iter = self.times.iter_mut().peekable();
loop {
match iter.peek() {
None => break,
Some(next) => {
if *next.0 > next_instant {
break;
}
}
}
let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0));
let mut entry = iter.next().unwrap();
::core::mem::swap(&mut work, &mut entry.1);
ret.push(work);
count_rm = count_rm + 1;
}
while count_rm > 0 {
self.times.pop_first();
count_rm = count_rm - 1;
}
ret
}
}

View File

@ -7,16 +7,18 @@ use crate::{
handshake::{
self,
dirsync::{self, DirSync},
tracker::{HandshakeAction, HandshakeTracker},
Handshake, HandshakeData,
Handshake,
},
packet::{self, Packet},
socket::{UdpClient, UdpServer},
ConnList, Connection, IDSend, Packet,
stream, AuthSrvConn, ConnList, ConnTracker, Connection, IDSend,
ServiceConn,
},
dnssec,
enc::{
self,
asym::{self, KeyID, PrivKey, PubKey},
hkdf::{self, Hkdf, HkdfKind},
hkdf::{self, Hkdf},
sym, Random, Secret,
},
inner::ThreadTracker,
@ -44,6 +46,16 @@ pub(crate) struct ConnectInfo {
// TODO: UserID, Token information
}
/// Connection event. Mostly used to give the data to the user
#[derive(Debug, Eq, PartialEq, Clone)]
#[non_exhaustive]
pub enum Event {
/// Work loop has exited. nothing more to do
End,
/// Data from a connection
Data(Vec<u8>),
}
pub(crate) enum Work {
/// ask the thread to report to the main thread the total number of
/// connections present
@ -51,6 +63,8 @@ pub(crate) enum Work {
Connect(ConnectInfo),
DropHandshake(KeyID),
Recv(RawUdp),
UserSend((ConnTracker, stream::ID, Vec<u8>)),
SendData((ConnTracker, ::tokio::time::Instant)),
}
/// Actual worker implementation.
@ -64,11 +78,13 @@ pub struct Worker {
token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>,
queue_sender: ::async_channel::Sender<Work>,
queue_timeouts_recv: mpsc::UnboundedReceiver<Work>,
queue_timeouts_send: mpsc::UnboundedSender<Work>,
thread_channels: Vec<::async_channel::Sender<Work>>,
connections: ConnList,
handshakes: HandshakeTracker,
handshakes: handshake::Tracker,
work_timers: super::Timers,
}
#[allow(unsafe_code)]
@ -82,10 +98,11 @@ impl Worker {
token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>,
queue_sender: ::async_channel::Sender<Work>,
) -> ::std::io::Result<Self> {
let (queue_timeouts_send, queue_timeouts_recv) =
mpsc::unbounded_channel();
let mut handshakes = HandshakeTracker::new(
let mut handshakes = handshake::Tracker::new(
thread_id,
cfg.ciphers.clone(),
cfg.key_exchanges.clone(),
@ -118,16 +135,21 @@ impl Worker {
token_check,
sockets,
queue,
queue_sender,
queue_timeouts_recv,
queue_timeouts_send,
thread_channels: Vec::new(),
connections: ConnList::new(thread_id),
handshakes,
work_timers: super::Timers::new(),
})
}
/// Continuously loop and process work as needed
pub async fn work_loop(&mut self) {
pub async fn work_loop(&mut self) -> Result<Event, crate::Error> {
'mainloop: loop {
let next_timer = self.work_timers.get_next();
::tokio::pin!(next_timer);
let work = ::tokio::select! {
tell_stopped = self.stop_working.recv() => {
if let Ok(stop_ch) = tell_stopped {
@ -136,6 +158,13 @@ impl Worker {
}
break;
}
() = &mut next_timer => {
let work_list = self.work_timers.get_work();
for w in work_list.into_iter() {
let _ = self.queue_sender.send(w).await;
}
continue 'mainloop;
}
maybe_timeout = self.queue.recv() => {
match maybe_timeout {
Ok(work) => work,
@ -298,6 +327,8 @@ impl Worker {
connection::Role::Client,
&self.rand,
);
let dest = UdpClient(addr.as_sockaddr().unwrap());
conn.send_addr = dest;
let auth_recv_id = self.connections.reserve_first();
let service_conn_id = self.connections.reserve_first();
@ -325,39 +356,39 @@ impl Worker {
};
// build request
let auth_info = dirsync::AuthInfo {
let auth_info = dirsync::req::AuthInfo {
user: UserID::new_anonymous(),
token: Token::new_anonymous(&self.rand),
service_id: conn_info.service_id,
domain: conn_info.domain,
};
let req_data = dirsync::ReqData {
let req_data = dirsync::req::Data {
nonce: dirsync::Nonce::new(&self.rand),
client_key_id,
id: auth_recv_id.0, //FIXME: is zero
auth: auth_info,
};
let req = dirsync::Req {
let req = dirsync::req::Req {
key_id: key.0,
exchange,
hkdf: hkdf_selected,
cipher: cipher_selected,
exchange_key: pub_key,
data: dirsync::ReqInner::ClearText(req_data),
data: dirsync::req::State::ClearText(req_data),
};
let encrypt_start = ID::len() + req.encrypted_offset();
let encrypt_start =
connection::ID::len() + req.encrypted_offset();
let encrypt_end = encrypt_start
+ req.encrypted_length(
cipher_selected.nonce_len(),
cipher_selected.tag_len(),
);
let h_req = Handshake::new(HandshakeData::DirSync(
let h_req = Handshake::new(handshake::Data::DirSync(
DirSync::Req(req),
));
use connection::{PacketData, ID};
let packet = Packet {
id: ID::Handshake,
data: PacketData::Handshake(h_req),
id: connection::ID::Handshake,
data: packet::Data::Handshake(h_req),
};
let tot_len = packet.len(
@ -388,15 +419,13 @@ impl Worker {
// send always from the first socket
// FIXME: select based on routing table
let sender = self.sockets[0].local_addr().unwrap();
let dest = UdpClient(addr.as_sockaddr().unwrap());
// start the timeout right before sending the packet
hshake.timeout = Some(::tokio::task::spawn_local(
Self::handshake_timeout(
self.queue_timeouts_send.clone(),
client_key_id,
),
));
let time_drop = self.work_timers.add(
::tokio::time::Duration::from_secs(10),
Work::DropHandshake(client_key_id),
);
hshake.timeout = Some(time_drop);
// send packet
self.send_packet(raw, dest, UdpServer(sender)).await;
@ -415,15 +444,66 @@ impl Worker {
Work::Recv(pkt) => {
self.recv(pkt).await;
}
Work::UserSend((tracker, stream, data)) => {
let conn = match self.connections.get_mut(tracker) {
None => continue,
Some(conn) => conn,
};
use connection::Enqueue;
match conn.enqueue(stream, data) {
Enqueue::Immediate(instant) => {
let _ = self
.queue_sender
.send(Work::SendData((tracker, instant)))
.await;
}
Enqueue::TimerWait => {}
Enqueue::NoSuchStream => {
::tracing::error!(
"Trying to send on unknown stream"
);
}
}
}
Work::SendData((tracker, instant)) => {
// make sure we don't process events before they are
// actually needed.
// This is basically busy waiting with extra steps,
// but we don't want to spawn lots of timers and
// we don't really have a fine-grained sleep that is
// multiplatform
let now = ::tokio::time::Instant::now();
if instant <= now {
let _ = self
.queue_sender
.send(Work::SendData((tracker, instant)))
.await;
continue;
}
let mut raw: Vec<u8> = Vec::with_capacity(1200);
raw.resize(raw.capacity(), 0);
let conn = match self.connections.get_mut(tracker) {
None => continue,
Some(conn) => conn,
};
let pkt = match conn.write_pkt(&mut raw) {
Ok(pkt) => pkt,
Err(enc::Error::NotEnoughData(0)) => continue,
Err(e) => {
::tracing::error!("Packet generation: {:?}", e);
continue;
}
};
let dest = conn.send_addr;
let src = UdpServer(self.sockets[0].local_addr().unwrap());
let len = pkt.len();
raw.truncate(len);
let _ = self.send_packet(raw, dest, src);
}
}
}
}
async fn handshake_timeout(
timeout_queue: mpsc::UnboundedSender<Work>,
key_id: KeyID,
) {
::tokio::time::sleep(::std::time::Duration::from_secs(10)).await;
let _ = timeout_queue.send(Work::DropHandshake(key_id));
Ok(Event::End)
}
/// Read and do stuff with the raw udp packet
async fn recv(&mut self, mut udp: RawUdp) {
@ -447,188 +527,241 @@ impl Worker {
return;
}
};
match action {
HandshakeAction::AuthNeeded(authinfo) => {
let req;
if let HandshakeData::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
{
req = r;
} else {
::tracing::error!("AuthInfo on non DS::Req");
return;
}
use dirsync::ReqInner;
let req_data = match req.data {
ReqInner::ClearText(req_data) => req_data,
_ => {
::tracing::error!("AuthNeeded: expected ClearText");
assert!(false, "AuthNeeded: unreachable");
return;
}
};
// FIXME: This part can take a while,
// we should just spawn it probably
let maybe_auth_check = {
match &self.token_check {
None => {
if req_data.auth.user == auth::USERID_ANONYMOUS
{
Ok(true)
} else {
Ok(false)
}
}
Some(token_check) => {
let tk_check = token_check.lock().await;
tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
}
}
};
let is_authenticated = match maybe_auth_check {
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id = ID::new_rand(&self.rand);
let srv_secret = Secret::new_rand(&self.rand);
let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len();
let mut auth_conn = Connection::new(
authinfo.hkdf,
req.cipher,
connection::Role::Server,
&self.rand,
);
auth_conn.id_send = IDSend(req_data.id);
// track connection
let auth_id_recv = self.connections.reserve_first();
auth_conn.id_recv = auth_id_recv;
let resp_data = dirsync::RespData {
client_nonce: req_data.nonce,
id: auth_conn.id_recv.0,
service_connection_id: srv_conn_id,
service_key: srv_secret,
};
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
use dirsync::RespInner;
let resp = dirsync::Resp {
client_key_id: req_data.client_key_id,
data: RespInner::ClearText(resp_data),
};
let encrypt_from = ID::len() + resp.encrypted_offset();
let encrypt_until =
encrypt_from + resp.encrypted_length(head_len, tag_len);
let resp_handshake = Handshake::new(
HandshakeData::DirSync(DirSync::Resp(resp)),
);
use connection::{PacketData, ID};
let packet = Packet {
id: ID::new_handshake(),
data: PacketData::Handshake(resp_handshake),
};
let tot_len = packet.len(head_len, tag_len);
let mut raw_out = Vec::<u8>::with_capacity(tot_len);
raw_out.resize(tot_len, 0);
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn
.cipher_send
.encrypt(aad, &mut raw_out[encrypt_from..encrypt_until])
{
::tracing::error!("can't encrypt: {:?}", e);
return;
}
self.send_packet(raw_out, udp.src, udp.dst).await;
self.recv_handshake(udp, action).await;
} else {
self.recv_packet(udp);
}
}
/// Receive a non-handshake packet
fn recv_packet(&mut self, udp: RawUdp) {
let conn = match self.connections.get_id_mut(udp.packet.id) {
None => return,
Some(conn) => conn,
};
match conn.recv(udp) {
Ok(stream::StreamData::NotReady) => {}
Ok(stream::StreamData::Ready) => {
//
}
Err(e) => ::tracing::trace!("Conn Recv: {:?}", e.to_string()),
}
}
/// Receive an handshake packet
async fn recv_handshake(&mut self, udp: RawUdp, action: handshake::Action) {
match action {
handshake::Action::AuthNeeded(authinfo) => {
let req;
if let handshake::Data::DirSync(DirSync::Req(r)) =
authinfo.handshake.data
{
req = r;
} else {
::tracing::error!("AuthInfo on non DS::Req");
return;
}
HandshakeAction::ClientConnect(cci) => {
let ds_resp;
if let HandshakeData::DirSync(DirSync::Resp(resp)) =
cci.handshake.data
{
ds_resp = resp;
} else {
::tracing::error!("ClientConnect on non DS::Resp");
let req_data = match req.data {
dirsync::req::State::ClearText(req_data) => req_data,
_ => {
::tracing::error!("AuthNeeded: expected ClearText");
assert!(false, "AuthNeeded: unreachable");
return;
}
// track connection
let resp_data;
if let dirsync::RespInner::ClearText(r_data) = ds_resp.data
{
resp_data = r_data;
} else {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
unreachable!();
};
// FIXME: This part can take a while,
// we should just spawn it probably
let maybe_auth_check = {
match &self.token_check {
None => {
if req_data.auth.user == auth::USERID_ANONYMOUS {
Ok(true)
} else {
Ok(false)
}
}
Some(token_check) => {
let tk_check = token_check.lock().await;
tk_check(
req_data.auth.user,
req_data.auth.token,
req_data.auth.service_id,
req_data.auth.domain,
)
.await
}
}
let auth_srv_conn = IDSend(resp_data.id);
let mut conn = cci.connection;
conn.id_send = auth_srv_conn;
let id_recv = conn.id_recv;
let cipher = conn.cipher_recv.kind();
// track the connection to the authentication server
if self.connections.track(conn.into()).is_err() {
::tracing::error!("Could not track new connection");
};
let is_authenticated = match maybe_auth_check {
Ok(is_authenticated) => is_authenticated,
Err(_) => {
::tracing::error!("error in token auth");
// TODO: retry?
return;
}
};
if !is_authenticated {
::tracing::warn!(
"Wrong authentication for user {:?}",
req_data.auth.user
);
// TODO: error response
return;
}
// Client has correctly authenticated
// TODO: contact the service, get the key and
// connection ID
let srv_conn_id = connection::ID::new_rand(&self.rand);
let srv_secret = Secret::new_rand(&self.rand);
let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len();
let mut auth_conn = Connection::new(
authinfo.hkdf,
req.cipher,
connection::Role::Server,
&self.rand,
);
auth_conn.id_send = IDSend(req_data.id);
auth_conn.send_addr = udp.src;
// track connection
let auth_id_recv = self.connections.reserve_first();
auth_conn.id_recv = auth_id_recv;
let resp_data = dirsync::resp::Data {
client_nonce: req_data.nonce,
id: auth_conn.id_recv.0,
service_connection_id: srv_conn_id,
service_key: srv_secret,
};
use crate::enc::sym::AAD;
// no aad for now
let aad = AAD(&mut []);
let resp = dirsync::resp::Resp {
client_key_id: req_data.client_key_id,
data: dirsync::resp::State::ClearText(resp_data),
};
let encrypt_from =
connection::ID::len() + resp.encrypted_offset();
let encrypt_until =
encrypt_from + resp.encrypted_length(head_len, tag_len);
let resp_handshake = Handshake::new(handshake::Data::DirSync(
DirSync::Resp(resp),
));
let packet = Packet {
id: connection::ID::new_handshake(),
data: packet::Data::Handshake(resp_handshake),
};
let tot_len = packet.len(head_len, tag_len);
let mut raw_out = Vec::<u8>::with_capacity(tot_len);
raw_out.resize(tot_len, 0);
packet.serialize(head_len, tag_len, &mut raw_out);
if let Err(e) = auth_conn
.cipher_send
.encrypt(aad, &mut raw_out[encrypt_from..encrypt_until])
{
::tracing::error!("can't encrypt: {:?}", e);
return;
}
self.send_packet(raw_out, udp.src, udp.dst).await;
}
handshake::Action::ClientConnect(cci) => {
self.work_timers.remove(cci.old_timeout);
let ds_resp;
if let handshake::Data::DirSync(DirSync::Resp(resp)) =
cci.handshake.data
{
ds_resp = resp;
} else {
::tracing::error!("ClientConnect on non DS::Resp");
return;
}
// track connection
let resp_data;
if let dirsync::resp::State::ClearText(r_data) = ds_resp.data {
resp_data = r_data;
} else {
::tracing::error!(
"ClientConnect on non DS::Resp::ClearText"
);
unreachable!();
}
let auth_id_send = IDSend(resp_data.id);
let mut conn = cci.connection;
conn.id_send = auth_id_send;
let id_recv = conn.id_recv;
let cipher = conn.cipher_recv.kind();
// track the connection to the authentication server
let track_auth_conn = match self.connections.track(conn) {
Ok(track_auth_conn) => track_auth_conn,
Err(_) => {
::tracing::error!(
"Could not track new auth srv connection"
);
self.connections.remove(id_recv);
// FIXME: proper connection closing
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
}
if cci.service_id != auth::SERVICEID_AUTH {
// create and track the connection to the service
// SECURITY: xor with secrets
//FIXME: the Secret should be XORed with the client
// stored secret (if any)
let hkdf = Hkdf::new(
HkdfKind::Sha3,
cci.service_id.as_bytes(),
resp_data.service_key,
);
let mut service_connection = Connection::new(
hkdf,
cipher,
connection::Role::Client,
&self.rand,
);
service_connection.id_recv = cci.service_connection_id;
service_connection.id_send =
IDSend(resp_data.service_connection_id);
let _ =
self.connections.track(service_connection.into());
}
let _ =
cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn)));
};
let authsrv_conn = AuthSrvConn(connection::Conn {
queue: self.queue_sender.clone(),
fast: track_auth_conn,
});
let mut service_conn = None;
if cci.service_id != auth::SERVICEID_AUTH {
// create and track the connection to the service
// SECURITY: xor with secrets
//FIXME: the Secret should be XORed with the client
// stored secret (if any)
let hkdf = Hkdf::new(
hkdf::Kind::Sha3,
cci.service_id.as_bytes(),
resp_data.service_key,
);
let mut service_connection = Connection::new(
hkdf,
cipher,
connection::Role::Client,
&self.rand,
);
service_connection.id_recv = cci.service_connection_id;
service_connection.id_send =
IDSend(resp_data.service_connection_id);
let track_serv_conn =
match self.connections.track(service_connection) {
Ok(track_serv_conn) => track_serv_conn,
Err(_) => {
::tracing::error!(
"Could not track new service connection"
);
self.connections
.remove(cci.service_connection_id);
// FIXME: proper connection closing
// FIXME: drop auth srv connection if we just
// established it
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
}
};
service_conn = Some(ServiceConn(connection::Conn {
queue: self.queue_sender.clone(),
fast: track_serv_conn,
}));
}
HandshakeAction::Nothing => {}
};
}
let _ = cci.answer.send(Ok(handshake::tracker::ConnectOk {
auth_key_id: cci.srv_key_id,
auth_id_send,
authsrv_conn,
service_conn,
}));
}
handshake::Action::Nothing => {}
};
}
async fn send_packet(
&self,

View File

@ -34,11 +34,12 @@ use crate::{
AuthServerConnections, Packet,
},
inner::{
worker::{ConnectInfo, RawUdp, Work, Worker},
worker::{ConnectInfo, Event, RawUdp, Work, Worker},
ThreadTracker,
},
};
pub use config::Config;
pub use connection::{AuthSrvConn, ServiceConn};
/// Main fenrir library errors
#[derive(::thiserror::Error, Debug)]
@ -175,6 +176,7 @@ impl Fenrir {
config: &Config,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<Self, Error> {
inner::set_minimum_sleep_resolution().await;
let (sender, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random)
@ -213,6 +215,7 @@ impl Fenrir {
pub async fn with_workers(
config: &Config,
) -> Result<(Self, Vec<Worker>), Error> {
inner::set_minimum_sleep_resolution().await;
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
// bind sockets early so we can change "port 0" (aka: random)
@ -332,7 +335,7 @@ impl Fenrir {
let data: Vec<u8> = buffer[..bytes].to_vec();
// we very likely have multiple threads, pinned to different cpus.
// use the ConnectionID to send the same connection
// use the connection::ID to send the same connection
// to the same thread.
// Handshakes have connection ID 0, so we use the sender's UDP port
@ -341,13 +344,12 @@ impl Fenrir {
Err(_) => continue, // packet way too short, ignore.
};
let thread_idx: usize = {
use connection::packet::ConnectionID;
match packet.id {
ConnectionID::Handshake => {
connection::ID::Handshake => {
let send_port = sock_sender.0.port() as u64;
(send_port % queues_num) as usize
}
ConnectionID::ID(id) => (id.get() % queues_num) as usize,
connection::ID::ID(id) => (id.get() % queues_num) as usize,
}
};
let _ = work_queues[thread_idx]
@ -382,7 +384,7 @@ impl Fenrir {
&self,
domain: &Domain,
service: ServiceID,
) -> Result<(), Error> {
) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
let resolved = self.resolv(domain).await?;
self.connect_resolved(resolved, domain, service).await
}
@ -392,7 +394,7 @@ impl Fenrir {
resolved: dnssec::Record,
domain: &Domain,
service: ServiceID,
) -> Result<(), Error> {
) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
loop {
// check if we already have a connection to that auth. srv
let is_reserved = {
@ -460,29 +462,28 @@ impl Fenrir {
.await;
match recv.await {
Ok(res) => {
match res {
Err(e) => {
let mut conn_auth_lock =
self.conn_auth_srv.lock().await;
conn_auth_lock.remove_reserved(&resolved);
Err(e)
}
Ok((key_id, id_send)) => {
let key = resolved
.public_keys
.iter()
.find(|k| k.0 == key_id)
.unwrap();
let mut conn_auth_lock =
self.conn_auth_srv.lock().await;
conn_auth_lock.add(&key.1, id_send, &resolved);
//FIXME: user needs to somehow track the connection
Ok(())
}
Ok(res) => match res {
Err(e) => {
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.remove_reserved(&resolved);
Err(e)
}
}
Ok(connections) => {
let key = resolved
.public_keys
.iter()
.find(|k| k.0 == connections.auth_key_id)
.unwrap();
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
conn_auth_lock.add(
&key.1,
connections.auth_id_send,
&resolved,
);
Ok((connections.authsrv_conn, connections.service_conn))
}
},
Err(e) => {
// Thread dropped the sender. no more thread?
let mut conn_auth_lock = self.conn_auth_srv.lock().await;
@ -524,6 +525,7 @@ impl Fenrir {
self.token_check.clone(),
socks,
work_recv,
work_send.clone(),
)
.await?;
// don't keep around private keys too much
@ -547,7 +549,6 @@ impl Fenrir {
}
Ok(worker)
}
// needs to be called before add_sockets
/// Start one working thread for each physical cpu
/// threads are pinned to each cpu core.
@ -589,6 +590,7 @@ impl Fenrir {
let th_tokio_rt = tokio_rt.clone();
let th_config = self.cfg.clone();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let th_work_send = work_send.clone();
let th_stop_working = self.stop_working.subscribe();
let th_token_check = self.token_check.clone();
let th_sockets = sockets.clone();
@ -629,13 +631,23 @@ impl Fenrir {
th_token_check,
th_sockets,
work_recv,
th_work_send,
)
.await
{
Ok(worker) => worker,
Err(_) => return,
};
worker.work_loop().await
loop {
match worker.work_loop().await {
Ok(_) => continue,
Ok(Event::End) => break,
Err(e) => {
::tracing::error!("Worker: {:?}", e);
break;
}
}
}
});
});
loop {

View File

@ -62,10 +62,7 @@ async fn test_connection_dirsync() {
rt.block_on(local_thread);
});
use crate::{
connection::handshake::HandshakeID,
dnssec::{record, Record},
};
use crate::dnssec::{record, Record};
let port: u16 = server.addresses()[0].port();
@ -76,14 +73,14 @@ async fn test_connection_dirsync() {
port: Some(::core::num::NonZeroU16::new(port).unwrap()),
priority: record::AddressPriority::P1,
weight: record::AddressWeight::W1,
handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(),
handshake_ids: [handshake::ID::DirectorySynchronized].to_vec(),
public_key_idx: [record::PubKeyIdx(0)].to_vec(),
}]
.to_vec(),
key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman]
.to_vec(),
hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(),
ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
hkdfs: [enc::hkdf::Kind::Sha3].to_vec(),
ciphers: [enc::sym::Kind::XChaCha20Poly1305].to_vec(),
};
::tokio::time::sleep(::std::time::Duration::from_millis(500)).await;
@ -91,7 +88,7 @@ async fn test_connection_dirsync() {
.connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH)
.await
{
Ok(()) => {}
Ok((_, _)) => {}
Err(e) => {
assert!(false, "Err on client connection: {:?}", e);
}