diff --git a/Cargo.toml b/Cargo.toml index 906053a..0544e30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,9 +35,11 @@ bitmaps = { version = "3.2" } chacha20poly1305 = { version = "0.10" } futures = { version = "0.3" } hkdf = { version = "0.12" } +hwloc2 = {version = "2.2" } libc = { version = "0.2" } num-traits = { version = "0.2" } num-derive = { version = "0.3" } +rand_core = {version = "0.6" } ring = { version = "0.16" } bincode = { version = "1.3" } sha3 = { version = "0.10" } @@ -48,6 +50,7 @@ tokio = { version = "1", features = ["full"] } # PERF: todo linux-only, behind "iouring" feature #tokio-uring = { version = "0.4" } tracing = { version = "0.1" } +tracing-test = { version = "0.2" } trust-dns-resolver = { version = "0.22", features = [ "dnssec-ring" ] } trust-dns-client = { version = "0.22", features = [ "dnssec" ] } trust-dns-proto = { version = "0.22" } @@ -70,3 +73,14 @@ incremental = true codegen-units = 256 rpath = false +[profile.test] +opt-level = 0 +debug = true +debug-assertions = true +overflow-checks = true +lto = false +panic = 'unwind' +incremental = true +codegen-units = 256 +rpath = false + diff --git a/Readme.md b/Readme.md index 7a2f74b..d67fe77 100644 --- a/Readme.md +++ b/Readme.md @@ -15,3 +15,27 @@ you will find the result in `./target/release` If you want to build the `Hati` server, you don't need to build this library separately. Just build the server and it will automatically include this lib +# Developing + +we recommend to use the nix environment, so that you will have +exactly the same environment as the developers. + +just enter the repository directory and run + +``` +nix develop +``` + +and everything should be done for you. + +## Git + +Please configure a pre-commit hook like the one in `var/git-pre-commit` + +``` +cp var/git-pre-commit .git/hooks/pre-commit +``` + +This will run `cargo test --offline` right before your commit, +to make sure that everything compiles and that the test pass + diff --git a/TODO b/TODO new file mode 100644 index 0000000..9531367 --- /dev/null +++ b/TODO @@ -0,0 +1 @@ +* Wrapping for everything that wraps (sigh) diff --git a/flake.lock b/flake.lock index 2920402..7a5770d 100644 --- a/flake.lock +++ b/flake.lock @@ -1,12 +1,15 @@ { "nodes": { "flake-utils": { + "inputs": { + "systems": "systems" + }, "locked": { - "lastModified": 1676283394, - "narHash": "sha256-XX2f9c3iySLCw54rJ/CZs+ZK6IQy7GXNY4nSOyu2QG4=", + "lastModified": 1681202837, + "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", "owner": "numtide", "repo": "flake-utils", - "rev": "3db36a8b464d0c4532ba1c7dda728f4576d6d073", + "rev": "cfacdce06f30d2b68473a46042957675eebb3401", "type": "github" }, "original": { @@ -16,12 +19,15 @@ } }, "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, "locked": { - "lastModified": 1659877975, - "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "lastModified": 1681202837, + "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", "owner": "numtide", "repo": "flake-utils", - "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "rev": "cfacdce06f30d2b68473a46042957675eebb3401", "type": "github" }, "original": { @@ -32,27 +38,27 @@ }, "nixpkgs": { "locked": { - "lastModified": 1677624842, - "narHash": "sha256-4DF9DbDuK4/+KYx0L6XcPBeDHUFVCtzok2fWtwXtb5w=", + "lastModified": 1684922889, + "narHash": "sha256-l0WZAmln8959O7RdYUJ3gnAIM9OPKFLKHKGX4q+Blrk=", "owner": "nixos", "repo": "nixpkgs", - "rev": "d70f5cd5c3bef45f7f52698f39e7cc7a89daa7f0", + "rev": "04aaf8511678a0d0f347fdf1e8072fe01e4a509e", "type": "github" }, "original": { "owner": "nixos", - "ref": "nixos-22.11", + "ref": "nixos-23.05", "repo": "nixpkgs", "type": "github" } }, "nixpkgs-unstable": { "locked": { - "lastModified": 1677407201, - "narHash": "sha256-3blwdI9o1BAprkvlByHvtEm5HAIRn/XPjtcfiunpY7s=", + "lastModified": 1684844536, + "narHash": "sha256-M7HhXYVqAuNb25r/d3FOO0z4GxPqDIZp5UjHFbBgw0Q=", "owner": "nixos", "repo": "nixpkgs", - "rev": "7f5639fa3b68054ca0b062866dc62b22c3f11505", + "rev": "d30264c2691128adc261d7c9388033645f0e742b", "type": "github" }, "original": { @@ -64,11 +70,11 @@ }, "nixpkgs_2": { "locked": { - "lastModified": 1665296151, - "narHash": "sha256-uOB0oxqxN9K7XGF1hcnY+PQnlQJ+3bP2vCn/+Ru/bbc=", + "lastModified": 1681358109, + "narHash": "sha256-eKyxW4OohHQx9Urxi7TQlFBTDWII+F+x2hklDOQPB50=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "14ccaaedd95a488dd7ae142757884d8e125b3363", + "rev": "96ba1c52e54e74c3197f4d43026b3f3d92e83ff9", "type": "github" }, "original": { @@ -92,11 +98,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1677638104, - "narHash": "sha256-vbdOoDYnQ1QYSchMb3fYGCLYeta3XwmGvMrlXchST5s=", + "lastModified": 1684894917, + "narHash": "sha256-kwKCfmliHIxKuIjnM95TRcQxM/4AAEIZ+4A9nDJ6cJs=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "f388187efb41ce4195b2f4de0b6bb463d3cd0a76", + "rev": "9ea38d547100edcf0da19aaebbdffa2810585495", "type": "github" }, "original": { @@ -104,6 +110,36 @@ "repo": "rust-overlay", "type": "github" } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 89bfdb9..21e2442 100644 --- a/flake.nix +++ b/flake.nix @@ -2,7 +2,7 @@ description = "libFenrir"; inputs = { - nixpkgs.url = "github:nixos/nixpkgs/nixos-22.11"; + nixpkgs.url = "github:nixos/nixpkgs/nixos-23.05"; nixpkgs-unstable.url = "github:nixos/nixpkgs/nixos-unstable"; rust-overlay.url = "github:oxalica/rust-overlay"; flake-utils.url = "github:numtide/flake-utils"; @@ -15,6 +15,10 @@ pkgs = import nixpkgs { inherit system overlays; }; + pkgs-unstable = import nixpkgs-unstable { + inherit system overlays; + }; + RUST_VERSION="1.69.0"; in { devShells.default = pkgs.mkShell { @@ -34,16 +38,19 @@ #}) clippy cargo-watch + cargo-flamegraph cargo-license lld - rust-bin.stable.latest.default + rust-bin.stable.${RUST_VERSION}.default rustfmt rust-analyzer + # fenrir deps + hwloc ]; shellHook = '' # use zsh or other custom shell USER_SHELL="$(grep $USER /etc/passwd | cut -d ':' -f 7)" - if [ -n "$USER_SHELL" ] && [ "$USER_SHELL" != "$SHELL" ]; then + if [ -n "$USER_SHELL" ]; then exec $USER_SHELL fi ''; diff --git a/src/Architecture.md b/src/Architecture.md index b9f0c66..9e547d0 100644 --- a/src/Architecture.md +++ b/src/Architecture.md @@ -1,8 +1,17 @@ # Architecture -For now we will keep things as easy as possible, not caring about the performance. +The current architecture is based on tokio. +In the future we might want something more general purpose (and no-std), +but good enough for now. -This means we will use only ::tokio and spawn one job per connection +Tokio has its own thread pool, and we spawn one async loop per listening socket. + +Then we spawn our own thread pool. +These threads are pinned to the cpu cores. +We spawn one async worker per thread, making sure it remains pinned to that core. +This is done to avoid *a lot* of locking and multithread syncronizations. + +We do connection sharding on the connection id. # Future @@ -12,8 +21,6 @@ scaling and work queues: https://tokio.rs/blog/2019-10-scheduler What we want to do is: -* one thread per core -* thread pinning * ebpf BBR packet pacing * need to support non-ebpf BBR for mobile ios, too * connection sharding diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 17f43b9..085816b 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,11 +1,13 @@ -//! Authentication reslated struct definitions +//! Authentication related struct definitions -use ::ring::rand::SecureRandom; +use crate::enc::Random; use ::zeroize::Zeroize; +/// Anonymous user id +pub const USERID_ANONYMOUS: UserID = UserID([0; UserID::len()]); /// User identifier. 16 bytes for easy uuid conversion -#[derive(Debug, Copy, Clone)] -pub struct UserID([u8; 16]); +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct UserID(pub [u8; 16]); impl From<[u8; 16]> for UserID { fn from(raw: [u8; 16]) -> Self { @@ -15,10 +17,18 @@ impl From<[u8; 16]> for UserID { impl UserID { /// New random user id - pub fn new(rand: &::ring::rand::SystemRandom) -> Self { - let mut ret = Self([0; 16]); - rand.fill(&mut ret.0); - ret + pub fn new(rand: &Random) -> Self { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 16]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = rand.fill(out.assume_init_mut()); + Self(out.assume_init()) + } + } + /// Anonymous user id + pub const fn new_anonymous() -> Self { + USERID_ANONYMOUS } /// length of the User ID in bytes pub const fn len() -> usize { @@ -26,11 +36,21 @@ impl UserID { } } /// Authentication Token, basically just 32 random bytes -#[derive(Clone, Zeroize)] +#[derive(Clone, Zeroize, PartialEq)] #[zeroize(drop)] -pub struct Token([u8; 32]); +pub struct Token(pub [u8; 32]); impl Token { + /// New random token, anonymous should not check this anyway + pub fn new_anonymous(rand: &Random) -> Self { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 32]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = rand.fill(out.assume_init_mut()); + Self(out.assume_init()) + } + } /// length of the token in bytes pub const fn len() -> usize { 32 @@ -53,12 +73,22 @@ impl ::core::fmt::Debug for Token { } } +/// Type of the function used to check the validity of the tokens +/// Reimplement this to use whatever database you want +pub type TokenChecker = + fn( + user: UserID, + token: Token, + service_id: ServiceID, + domain: Domain, + ) -> ::futures::future::BoxFuture<'static, Result>; + /// domain representation /// Security notice: internal representation is utf8, but we will /// further limit to a "safe" subset of utf8 // SECURITY: TODO: limit to a subset of utf8 #[derive(Debug, Clone, PartialEq)] -pub struct Domain(String); +pub struct Domain(pub String); impl TryFrom<&[u8]> for Domain { type Error = (); @@ -70,6 +100,16 @@ impl TryFrom<&[u8]> for Domain { Ok(Domain(domain_string)) } } +impl From for Domain { + fn from(raw: String) -> Self { + Self(raw) + } +} +impl From<&str> for Domain { + fn from(raw: &str) -> Self { + Self(raw.to_owned()) + } +} impl Domain { /// length of the User ID in bytes @@ -78,9 +118,11 @@ impl Domain { } } +/// Reserve the first service ID for the authentication service +pub const SERVICEID_AUTH: ServiceID = ServiceID([0; 16]); /// The Service ID is a UUID associated with the service. -#[derive(Debug, Copy, Clone)] -pub struct ServiceID([u8; 16]); +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct ServiceID(pub [u8; 16]); impl From<[u8; 16]> for ServiceID { fn from(raw: [u8; 16]) -> Self { @@ -93,4 +135,8 @@ impl ServiceID { pub const fn len() -> usize { 16 } + /// read the service id as bytes + pub fn as_bytes(&self) -> &[u8; 16] { + &self.0 + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index e3fd4c8..f196ac9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,6 +1,14 @@ //! //! Configuration to initialize the Fenrir networking library +use crate::{ + connection::handshake::HandshakeID, + enc::{ + asym::{KeyExchangeKind, KeyID, PrivKey, PubKey}, + hkdf::HkdfKind, + sym::CipherKind, + }, +}; use ::std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, num::NonZeroUsize, @@ -8,6 +16,23 @@ use ::std::{ vec, }; +/// Key used by a server during the handshake +#[derive(Clone, Debug)] +pub struct ServerKey { + pub id: KeyID, + pub priv_key: PrivKey, + pub pub_key: PubKey, +} + +/// Authentication Server information and keys +#[derive(Clone, Debug)] +pub struct AuthServer { + /// fqdn of the authentication server + pub fqdn: crate::auth::Domain, + /// list of key ids enabled for this domain + pub keys: Vec, +} + /// Main config for libFenrir #[derive(Clone, Debug)] pub struct Config { @@ -18,6 +43,20 @@ pub struct Config { pub listen: Vec, /// List of DNS resolvers to use pub resolvers: Vec, + /// Supported handshakes + pub handshakes: Vec, + /// Supported key exchanges + pub key_exchanges: Vec, + /// Supported Hkdfs + pub hkdfs: Vec, + /// Supported Ciphers + pub ciphers: Vec, + /// list of authentication servers + /// clients will have this empty + pub servers: Vec, + /// list of public/private keys + /// clients should have this empty + pub server_keys: Vec, } impl Default for Config { @@ -34,6 +73,12 @@ impl Default for Config { ), ], resolvers: Vec::new(), + handshakes: [HandshakeID::DirectorySynchronized].to_vec(), + key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(), + hkdfs: [HkdfKind::Sha3].to_vec(), + ciphers: [CipherKind::XChaCha20Poly1305].to_vec(), + servers: Vec::new(), + server_keys: Vec::new(), } } } diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 9c0a299..0be2dd4 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -11,20 +11,45 @@ use super::{Error, HandshakeData}; use crate::{ auth, - connection::ID, + connection::{ProtocolVersion, ID}, enc::{ - asym::{ExchangePubKey, KeyExchange, KeyID}, - sym::{CipherKind, Secret}, + asym::{ExchangePubKey, KeyExchangeKind, KeyID}, + hkdf::HkdfKind, + sym::{CipherKind, HeadLen, TagLen}, + Random, Secret, }, }; -use ::arrayref::array_mut_ref; -use ::std::{collections::VecDeque, num::NonZeroU64, vec::Vec}; +// TODO: merge with crate::enc::sym::Nonce +/// random nonce +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Nonce(pub(crate) [u8; 16]); -type Nonce = [u8; 16]; +impl Nonce { + /// Create a new random Nonce + pub fn new(rnd: &Random) -> Self { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 16]>; + #[allow(unsafe_code)] + unsafe { + out = MaybeUninit::uninit(); + let _ = rnd.fill(out.assume_init_mut()); + Self(out.assume_init()) + } + } + /// Length of the serialized Nonce + pub const fn len() -> usize { + 16 + } +} +impl From<&[u8; 16]> for Nonce { + fn from(raw: &[u8; 16]) -> Self { + Self(raw.clone()) + } +} /// Parsed handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum DirSync { /// Directory synchronized handshake: client request Req(Req), @@ -34,55 +59,104 @@ pub enum DirSync { impl DirSync { /// actual length of the dirsync handshake data - pub fn len(&self) -> usize { + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { match self { DirSync::Req(req) => req.len(), - DirSync::Resp(resp) => resp.len(), + DirSync::Resp(resp) => resp.len(head_len, tag_len), } } /// Serialize into raw bytes /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { match self { - DirSync::Req(req) => req.serialize(out), - DirSync::Resp(resp) => resp.serialize(out), + DirSync::Req(req) => req.serialize(head_len, tag_len, out), + DirSync::Resp(resp) => resp.serialize(head_len, tag_len, out), } } } /// Client request of a directory synchronized handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Req { /// Id of the server key used for the key exchange pub key_id: KeyID, /// Selected key exchange - pub exchange: KeyExchange, + pub exchange: KeyExchangeKind, + /// Selected hkdf + pub hkdf: HkdfKind, /// Selected cipher pub cipher: CipherKind, /// Client ephemeral public key used for key exchanges pub exchange_key: ExchangePubKey, /// encrypted data pub data: ReqInner, + // SECURITY: TODO: Add padding to min: 1200 bytes + // to avoid amplification attaks + // also: 1200 < 1280 to allow better vpn compatibility } impl Req { - /// Set the cleartext data after it was parsed - pub fn set_data(&mut self, data: ReqData) { - self.data = ReqInner::Data(data); + /// return the offset of the encrypted data + /// NOTE: starts from the beginning of the fenrir packet + pub fn encrypted_offset(&self) -> usize { + ProtocolVersion::len() + + crate::handshake::HandshakeID::len() + + KeyID::len() + + KeyExchangeKind::len() + + HkdfKind::len() + + CipherKind::len() + + self.exchange_key.kind().pub_len() + } + /// return the total length of the cleartext data + pub fn encrypted_length( + &self, + head_len: HeadLen, + tag_len: TagLen, + ) -> usize { + match &self.data { + ReqInner::ClearText(data) => data.len() + head_len.0 + tag_len.0, + ReqInner::CipherText(length) => *length, + } } /// actual length of the directory synchronized request pub fn len(&self) -> usize { KeyID::len() - + KeyExchange::len() + + KeyExchangeKind::len() + + HkdfKind::len() + CipherKind::len() - + self.exchange_key.len() + + self.exchange_key.kind().pub_len() + + self.cipher.nonce_len().0 + self.data.len() + + self.cipher.tag_len().0 } /// Serialize into raw bytes - /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { - //assert!(out.len() > , ": not enough buffer to serialize"); - todo!() + /// NOTE: assumes that there is exactly as much buffer as needed + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { + out[0..2].copy_from_slice(&self.key_id.0.to_le_bytes()); + out[2] = self.exchange as u8; + out[3] = self.hkdf as u8; + out[4] = self.cipher as u8; + let key_len = self.exchange_key.len(); + let written_next = 5 + key_len; + self.exchange_key.serialize_into(&mut out[5..written_next]); + let written = written_next; + if let ReqInner::ClearText(data) = &self.data { + let from = written + head_len.0; + let to = out.len() - tag_len.0; + data.serialize(&mut out[from..to]); + } else { + unreachable!(); + } } } @@ -93,27 +167,34 @@ impl super::HandshakeParsing for Req { return Err(Error::NotEnoughData); } let key_id: KeyID = - KeyID(u16::from_le_bytes(raw[0..1].try_into().unwrap())); + KeyID(u16::from_le_bytes(raw[0..2].try_into().unwrap())); use ::num_traits::FromPrimitive; - let exchange: KeyExchange = match KeyExchange::from_u8(raw[2]) { + let exchange: KeyExchangeKind = match KeyExchangeKind::from_u8(raw[2]) { Some(exchange) => exchange, None => return Err(Error::Parsing), }; - let cipher: CipherKind = match CipherKind::from_u8(raw[3]) { + let hkdf: HkdfKind = match HkdfKind::from_u8(raw[3]) { + Some(exchange) => exchange, + None => return Err(Error::Parsing), + }; + let cipher: CipherKind = match CipherKind::from_u8(raw[4]) { Some(cipher) => cipher, None => return Err(Error::Parsing), }; - let (exchange_key, len) = match ExchangePubKey::from_slice(&raw[4..]) { - Ok(exchange_key) => exchange_key, - Err(e) => return Err(e.into()), - }; - let mut vec = VecDeque::with_capacity(raw.len() - (4 + len)); - vec.extend(raw[(4 + len)..].iter().copied()); - let _ = vec.make_contiguous(); - let data = ReqInner::CipherText(vec); + const CURR_SIZE: usize = KeyID::len() + + KeyExchangeKind::len() + + HkdfKind::len() + + CipherKind::len(); + let (exchange_key, len) = + match ExchangePubKey::deserialize(&raw[CURR_SIZE..]) { + Ok(exchange_key) => exchange_key, + Err(e) => return Err(e.into()), + }; + let data = ReqInner::CipherText(raw.len() - (CURR_SIZE + len)); Ok(HandshakeData::DirSync(DirSync::Req(Self { key_id, exchange, + hkdf, cipher, exchange_key, data, @@ -122,47 +203,46 @@ impl super::HandshakeParsing for Req { } /// Quick way to avoid mixing cipher and clear text -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum ReqInner { - /// Client data, still in ciphertext - CipherText(VecDeque), - /// Client data, decrypted but unprocessed - ClearText(VecDeque), - /// Parsed client data - Data(ReqData), + /// Data is still encrytped, we only keep the length + CipherText(usize), + /// Client data, decrypted and parsed + ClearText(ReqData), } impl ReqInner { /// The length of the data pub fn len(&self) -> usize { match self { - ReqInner::CipherText(c) => c.len(), - ReqInner::ClearText(c) => c.len(), - ReqInner::Data(d) => d.len(), + ReqInner::CipherText(len) => *len, + ReqInner::ClearText(data) => data.len(), } } - /// Get the ciptertext, or panic - pub fn ciphertext<'a>(&'a mut self) -> &'a mut VecDeque { - match self { - ReqInner::CipherText(data) => data, - _ => panic!(), - } - } - /// switch from ciphertext to cleartext - pub fn mark_as_cleartext(&mut self) { - let mut newdata: VecDeque; - match self { - ReqInner::CipherText(data) => { - newdata = VecDeque::new(); - ::core::mem::swap(&mut newdata, data); + /// parse the cleartext + pub fn deserialize_as_cleartext( + &mut self, + raw: &[u8], + ) -> Result<(), Error> { + let clear = match self { + ReqInner::CipherText(len) => { + assert!( + *len > raw.len(), + "DirSync::ReqInner::CipherText length mismatch" + ); + match ReqData::deserialize(raw) { + Ok(clear) => clear, + Err(e) => return Err(e), + } } - _ => return, - } - *self = ReqInner::ClearText(newdata); + _ => return Err(Error::Parsing), + }; + *self = ReqInner::ClearText(clear); + Ok(()) } } /// Informations needed for authentication -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct AuthInfo { /// User of the domain pub user: auth::UserID, @@ -184,6 +264,21 @@ impl AuthInfo { pub fn len(&self) -> usize { Self::MIN_PKT_LEN + self.domain.len() } + /// serialize into a buffer + /// Note: assumes there is enough space + pub fn serialize(&self, out: &mut [u8]) { + out[..auth::UserID::len()].copy_from_slice(&self.user.0); + const WRITTEN_TOKEN: usize = auth::UserID::len() + auth::Token::len(); + out[auth::UserID::len()..WRITTEN_TOKEN].copy_from_slice(&self.token.0); + const WRITTEN_SERVICE_ID: usize = + WRITTEN_TOKEN + auth::ServiceID::len(); + out[WRITTEN_TOKEN..WRITTEN_SERVICE_ID] + .copy_from_slice(&self.service_id.0); + let domain_len = self.domain.0.as_bytes().len() as u8; + out[WRITTEN_SERVICE_ID] = domain_len; + const WRITTEN_DOMAIN_LEN: usize = WRITTEN_SERVICE_ID + 1; + out[WRITTEN_DOMAIN_LEN..].copy_from_slice(&self.domain.0.as_bytes()); + } /// deserialize from raw bytes pub fn deserialize(raw: &[u8]) -> Result { if raw.len() < Self::MIN_PKT_LEN { @@ -225,7 +320,7 @@ impl AuthInfo { } /// Decrypted request data -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ReqData { /// Random nonce, the client can use this to track multiple key exchanges pub nonce: Nonce, @@ -239,24 +334,33 @@ pub struct ReqData { impl ReqData { /// actual length of the request data pub fn len(&self) -> usize { - self.nonce.len() + KeyID::len() + ID::len() + self.auth.len() + Nonce::len() + KeyID::len() + ID::len() + self.auth.len() } /// Minimum byte length of the request data pub const MIN_PKT_LEN: usize = 16 + KeyID::len() + ID::len() + AuthInfo::MIN_PKT_LEN; + /// serialize into a buffer + /// Note: assumes there is enough space + pub fn serialize(&self, out: &mut [u8]) { + out[..Nonce::len()].copy_from_slice(&self.nonce.0); + const WRITTEN_KEY: usize = Nonce::len() + KeyID::len(); + out[Nonce::len()..WRITTEN_KEY] + .copy_from_slice(&self.client_key_id.0.to_le_bytes()); + const WRITTEN: usize = WRITTEN_KEY; + const WRITTEN_ID: usize = WRITTEN + 8; + out[WRITTEN..WRITTEN_ID] + .copy_from_slice(&self.id.as_u64().to_le_bytes()); + self.auth.serialize(&mut out[WRITTEN_ID..]); + } /// Parse the cleartext raw data - pub fn deserialize(raw: &ReqInner) -> Result { - let raw = match raw { - // raw is VecDeque, assume everything is on the first slice - ReqInner::ClearText(raw) => raw.as_slices().0, - _ => return Err(Error::Parsing), - }; + pub fn deserialize(raw: &[u8]) -> Result { if raw.len() < Self::MIN_PKT_LEN { return Err(Error::NotEnoughData); } let mut start = 0; let mut end = 16; - let nonce: Nonce = raw[start..end].try_into().unwrap(); + let raw_sized: &[u8; 16] = raw[start..end].try_into().unwrap(); + let nonce: Nonce = raw_sized.into(); start = end; end = end + KeyID::len(); let client_key_id = @@ -280,72 +384,178 @@ impl ReqData { } } +/// Quick way to avoid mixing cipher and clear text +#[derive(Debug, Clone, PartialEq)] +pub enum RespInner { + /// Server data, still in ciphertext + CipherText(usize), + /// Parsed, cleartext server data + ClearText(RespData), +} +impl RespInner { + /// The length of the data + pub fn len(&self) -> usize { + match self { + RespInner::CipherText(len) => *len, + RespInner::ClearText(_) => RespData::len(), + } + } + /// parse the cleartext + pub fn deserialize_as_cleartext( + &mut self, + raw: &[u8], + ) -> Result<(), Error> { + let clear = match self { + RespInner::CipherText(len) => { + assert!( + *len > raw.len(), + "DirSync::RespInner::CipherText length mismatch" + ); + match RespData::deserialize(raw) { + Ok(clear) => clear, + Err(e) => return Err(e), + } + } + _ => return Err(Error::Parsing), + }; + *self = RespInner::ClearText(clear); + Ok(()) + } + /// Serialize the still cleartext data + pub fn serialize(&self, out: &mut [u8]) { + if let RespInner::ClearText(clear) = &self { + clear.serialize(out); + } + } +} + /// Server response in a directory synchronized handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Resp { /// Tells the client with which key the exchange was done pub client_key_id: KeyID, - /// encrypted data - pub enc: Vec, + /// actual response data, might be encrypted + pub data: RespInner, } impl super::HandshakeParsing for Resp { fn deserialize(raw: &[u8]) -> Result { - todo!() + const MIN_PKT_LEN: usize = 68; + if raw.len() < MIN_PKT_LEN { + return Err(Error::NotEnoughData); + } + let client_key_id: KeyID = + KeyID(u16::from_le_bytes(raw[0..KeyID::len()].try_into().unwrap())); + Ok(HandshakeData::DirSync(DirSync::Resp(Self { + client_key_id, + data: RespInner::CipherText(raw[KeyID::len()..].len()), + }))) } } impl Resp { + /// return the offset of the encrypted data + /// NOTE: starts from the beginning of the fenrir packet + pub fn encrypted_offset(&self) -> usize { + ProtocolVersion::len() + + crate::connection::handshake::HandshakeID::len() + + KeyID::len() + } + /// return the total length of the cleartext data + pub fn encrypted_length( + &self, + head_len: HeadLen, + tag_len: TagLen, + ) -> usize { + match &self.data { + RespInner::ClearText(_data) => { + RespData::len() + head_len.0 + tag_len.0 + } + RespInner::CipherText(len) => *len, + } + } /// Total length of the response handshake - pub fn len(&self) -> usize { - KeyID::len() + self.enc.len() + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { + KeyID::len() + head_len.0 + self.data.len() + tag_len.0 } /// Serialize into raw bytes - /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { - assert!( - out.len() == KeyID::len() + self.enc.len(), - "DirSync Resp: not enough buffer to serialize" - ); - self.client_key_id.serialize(array_mut_ref![out, 0, 2]); - out[2..].copy_from_slice(&self.enc[..]); + /// NOTE: assumes that there is exactly as much buffer as needed + pub fn serialize( + &self, + head_len: HeadLen, + _tag_len: TagLen, + out: &mut [u8], + ) { + out[0..KeyID::len()] + .copy_from_slice(&self.client_key_id.0.to_le_bytes()); + let start_data = KeyID::len() + head_len.0; + let end_data = start_data + self.data.len(); + self.data.serialize(&mut out[start_data..end_data]); } } /// Decrypted response data -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct RespData { /// Client nonce, copied from the request pub client_nonce: Nonce, /// Server Connection ID pub id: ID, /// Service Connection ID - pub service_id: ID, + pub service_connection_id: ID, /// Service encryption key pub service_key: Secret, } impl RespData { - const NONCE_LEN: usize = ::core::mem::size_of::(); /// Return the expected length for buffer allocation pub fn len() -> usize { - Self::NONCE_LEN + ID::len() + ID::len() + 32 + Nonce::len() + ID::len() + ID::len() + Secret::len() } /// Serialize the data into a buffer /// NOTE: assumes that there is exactly asa much buffer as needed pub fn serialize(&self, out: &mut [u8]) { - assert!(out.len() == Self::len(), "wrong buffer size"); let mut start = 0; - let mut end = Self::NONCE_LEN; - out[start..end].copy_from_slice(&self.client_nonce); + let mut end = Nonce::len(); + out[start..end].copy_from_slice(&self.client_nonce.0); start = end; - end = end + Self::NONCE_LEN; + end = end + ID::len(); self.id.serialize(&mut out[start..end]); start = end; - end = end + Self::NONCE_LEN; - self.service_id.serialize(&mut out[start..end]); + end = end + ID::len(); + self.service_connection_id.serialize(&mut out[start..end]); start = end; - end = end + Self::NONCE_LEN; + end = end + Secret::len(); out[start..end].copy_from_slice(self.service_key.as_ref()); } + /// Parse the cleartext raw data + pub fn deserialize(raw: &[u8]) -> Result { + let raw_sized: &[u8; 16] = raw[..Nonce::len()].try_into().unwrap(); + let client_nonce: Nonce = raw_sized.into(); + let end = Nonce::len() + ID::len(); + let id: ID = + u64::from_le_bytes(raw[Nonce::len()..end].try_into().unwrap()) + .into(); + if id.is_handshake() { + return Err(Error::Parsing); + } + let parsed = end; + let end = parsed + ID::len(); + let service_connection_id: ID = + u64::from_le_bytes(raw[parsed..end].try_into().unwrap()).into(); + if service_connection_id.is_handshake() { + return Err(Error::Parsing); + } + let parsed = end; + let end = parsed + Secret::len(); + let raw_secret: &[u8; 32] = raw[parsed..end].try_into().unwrap(); + let service_key = raw_secret.into(); + + Ok(Self { + client_nonce, + id, + service_connection_id, + service_key, + }) + } } diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 763191d..b5204a1 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -1,13 +1,19 @@ //! Handhsake handling pub mod dirsync; +#[cfg(test)] +mod tests; +pub(crate) mod tracker; +use crate::{ + connection::ProtocolVersion, + enc::sym::{HeadLen, TagLen}, +}; use ::num_traits::FromPrimitive; -use crate::connection::{self, ProtocolVersion}; - /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] +#[non_exhaustive] pub enum Error { /// Error while parsing the handshake packet /// TODO: more detailed parsing errors @@ -22,15 +28,53 @@ pub enum Error { /// Not enough data #[error("not enough data")] NotEnoughData, + /// Could not find common cryptography + #[error("Negotiation of keys/hkdfs/ciphers failed")] + Negotiation, + /// Could not generate Keys + #[error("Key generation failed")] + KeyGeneration, + /// Too many client handshakes currently running + #[error("Too many client handshakes")] + TooManyClientHandshakes, + /// generic internal error + #[error("Internal tracking error")] + InternalTracking, + /// Handshake Timeout + #[error("Handshake timeout")] + Timeout, } -pub(crate) struct HandshakeKey { - pub id: crate::enc::asym::KeyID, - pub key: crate::enc::asym::PrivKey, +/// List of possible handshakes +#[derive( + ::num_derive::FromPrimitive, + Debug, + Clone, + Copy, + PartialEq, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] +#[repr(u8)] +pub enum HandshakeID { + /// 1-RTT Directory synchronized handshake. Fast, no forward secrecy + #[strum(serialize = "directory_synchronized")] + DirectorySynchronized = 0, + /// 2-RTT Stateful exchange. Little DDos protection + #[strum(serialize = "stateful")] + Stateful, + /// 3-RTT stateless exchange. Forward secrecy and ddos protection + #[strum(serialize = "stateless")] + Stateless, +} +impl HandshakeID { + /// The length of the serialized field + pub const fn len() -> usize { + 1 + } } - /// Parsed handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum HandshakeData { /// Directory synchronized handhsake DirSync(dirsync::DirSync), @@ -38,16 +82,21 @@ pub enum HandshakeData { impl HandshakeData { /// actual length of the handshake data - pub fn len(&self) -> usize { + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { match self { - HandshakeData::DirSync(d) => d.len(), + HandshakeData::DirSync(d) => d.len(head_len, tag_len), } } /// Serialize into raw bytes /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { match self { - HandshakeData::DirSync(d) => d.serialize(out), + HandshakeData::DirSync(d) => d.serialize(head_len, tag_len, out), } } } @@ -55,7 +104,7 @@ impl HandshakeData { /// Kind of handshake #[derive(::num_derive::FromPrimitive, Debug, Clone, Copy)] #[repr(u8)] -pub enum Kind { +pub enum HandshakeKind { /// 1-RTT, Directory synchronized handshake /// Request DirSyncReq = 0, @@ -71,9 +120,15 @@ pub enum Kind { .... */ } +impl HandshakeKind { + /// Length of the serialized field + pub const fn len() -> usize { + 1 + } +} /// Parsed handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Handshake { /// Fenrir Protocol version pub fenrir_version: ProtocolVersion, @@ -90,8 +145,10 @@ impl Handshake { } } /// return the total length of the handshake - pub fn len(&self) -> usize { - ProtocolVersion::len() + self.data.len() + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { + ProtocolVersion::len() + + HandshakeKind::len() + + self.data.len(head_len, tag_len) } const MIN_PKT_LEN: usize = 8; /// Parse the packet and return the parsed handshake @@ -103,13 +160,15 @@ impl Handshake { Some(fenrir_version) => fenrir_version, None => return Err(Error::Parsing), }; - let handshake_kind = match Kind::from_u8(raw[1]) { + let handshake_kind = match HandshakeKind::from_u8(raw[1]) { Some(handshake_kind) => handshake_kind, None => return Err(Error::Parsing), }; let data = match handshake_kind { - Kind::DirSyncReq => dirsync::Req::deserialize(&raw[2..])?, - Kind::DirSyncResp => dirsync::Resp::deserialize(&raw[2..])?, + HandshakeKind::DirSyncReq => dirsync::Req::deserialize(&raw[2..])?, + HandshakeKind::DirSyncResp => { + dirsync::Resp::deserialize(&raw[2..])? + } }; Ok(Self { fenrir_version, @@ -117,15 +176,21 @@ impl Handshake { }) } /// serialize the handshake into bytes - /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { - assert!(out.len() > 1, "Handshake: not enough buffer to serialize"); - self.fenrir_version.serialize(&mut out[0]); - self.data.serialize(&mut out[1..]); - } - - pub(crate) fn work(&self, keys: &[HandshakeKey]) -> Result<(), Error> { - todo!() + /// NOTE: assumes that there is exactly as much buffer as needed + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { + out[0] = self.fenrir_version as u8; + out[1] = match &self.data { + HandshakeData::DirSync(d) => match d { + dirsync::DirSync::Req(_) => HandshakeKind::DirSyncReq, + dirsync::DirSync::Resp(_) => HandshakeKind::DirSyncResp, + }, + } as u8; + self.data.serialize(head_len, tag_len, &mut out[2..]); } } diff --git a/src/connection/handshake/tests.rs b/src/connection/handshake/tests.rs new file mode 100644 index 0000000..83cacfd --- /dev/null +++ b/src/connection/handshake/tests.rs @@ -0,0 +1,125 @@ +use crate::{ + auth, + connection::{handshake::*, ID}, + enc::{self, asym::KeyID}, +}; + +#[test] +fn test_handshake_dirsync_req() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + + let (_, exchange_key) = + match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand) + { + Ok(pair) => pair, + Err(_) => { + assert!(false, "Can't generate random keypair"); + return; + } + }; + + let data = dirsync::ReqInner::ClearText(dirsync::ReqData { + nonce: dirsync::Nonce::new(&rand), + client_key_id: KeyID(2424), + id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()), + auth: dirsync::AuthInfo { + user: auth::UserID::new(&rand), + token: auth::Token::new_anonymous(&rand), + service_id: auth::SERVICEID_AUTH, + domain: auth::Domain("example.com".to_owned()), + }, + }); + + let h_req = Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Req( + dirsync::Req { + key_id: KeyID(4224), + exchange: enc::asym::KeyExchangeKind::X25519DiffieHellman, + hkdf: enc::hkdf::HkdfKind::Sha3, + cipher: enc::sym::CipherKind::XChaCha20Poly1305, + exchange_key, + data, + }, + ))); + + let mut bytes = Vec::::with_capacity( + h_req.len(cipher.nonce_len(), cipher.tag_len()), + ); + bytes.resize(h_req.len(cipher.nonce_len(), cipher.tag_len()), 0); + h_req.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); + + let mut deserialized = match Handshake::deserialize(&bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + if let HandshakeData::DirSync(dirsync::DirSync::Req(r_a)) = + &mut deserialized.data + { + let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; + if let Err(e) = r_a.data.deserialize_as_cleartext( + &bytes[enc_start..(bytes.len() - cipher.tag_len().0)], + ) { + assert!(false, "DirSync Req Inner serialize: {}", e.to_string()); + } + }; + + assert!( + deserialized == h_req, + "DirSync Req (de)serialization not working", + ); +} +#[test] +fn test_handshake_dirsync_reqsp() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + + let service_key = enc::Secret::new_rand(&rand); + + let data = dirsync::RespInner::ClearText(dirsync::RespData { + client_nonce: dirsync::Nonce::new(&rand), + id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()), + service_connection_id: ID::ID( + ::core::num::NonZeroU64::new(434343).unwrap(), + ), + service_key, + }); + + let h_resp = Handshake::new(HandshakeData::DirSync( + dirsync::DirSync::Resp(dirsync::Resp { + client_key_id: KeyID(4444), + data, + }), + )); + + let mut bytes = Vec::::with_capacity( + h_resp.len(cipher.nonce_len(), cipher.tag_len()), + ); + bytes.resize(h_resp.len(cipher.nonce_len(), cipher.tag_len()), 0); + h_resp.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); + + let mut deserialized = match Handshake::deserialize(&bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) = + &mut deserialized.data + { + let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; + if let Err(e) = r_a.data.deserialize_as_cleartext( + &bytes[enc_start..(bytes.len() - cipher.tag_len().0)], + ) { + assert!(false, "DirSync Resp Inner serialize: {}", e.to_string()); + } + }; + + assert!( + deserialized == h_resp, + "DirSync Resp (de)serialization not working", + ); +} diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs new file mode 100644 index 0000000..aaa1163 --- /dev/null +++ b/src/connection/handshake/tracker.rs @@ -0,0 +1,391 @@ +//! Handhsake handling + +use crate::{ + auth::{Domain, ServiceID}, + connection::{ + self, + handshake::{self, Error, Handshake}, + Connection, IDRecv, IDSend, + }, + enc::{ + self, + asym::{self, KeyID, PrivKey, PubKey}, + hkdf::{Hkdf, HkdfKind}, + sym::{CipherKind, CipherRecv}, + }, + inner::ThreadTracker, +}; + +use ::tokio::sync::oneshot; + +pub(crate) struct HandshakeServer { + pub id: KeyID, + pub key: PrivKey, + pub domains: Vec, +} + +pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>; + +pub(crate) struct HandshakeClient { + pub service_id: ServiceID, + pub service_conn_id: IDRecv, + pub connection: Connection, + pub timeout: Option<::tokio::task::JoinHandle<()>>, + pub answer: oneshot::Sender, + pub srv_key_id: KeyID, +} + +/// Tracks the keys used by the client and the handshake +/// they are associated with +pub(crate) struct HandshakeClientList { + used: Vec<::bitmaps::Bitmap<1024>>, // index = KeyID + keys: Vec>, + list: Vec>, +} + +impl HandshakeClientList { + pub(crate) fn new() -> Self { + Self { + used: [::bitmaps::Bitmap::<1024>::new()].to_vec(), + keys: Vec::with_capacity(16), + list: Vec::with_capacity(16), + } + } + pub(crate) fn get(&self, id: KeyID) -> Option<&HandshakeClient> { + if id.0 as usize >= self.list.len() { + return None; + } + self.list[id.0 as usize].as_ref() + } + pub(crate) fn remove(&mut self, id: KeyID) -> Option { + if id.0 as usize >= self.list.len() { + return None; + } + let used_vec_idx = id.0 as usize / 1024; + let used_bitmap_idx = id.0 as usize % 1024; + let used_iter = match self.used.get_mut(used_vec_idx) { + Some(used_iter) => used_iter, + None => return None, + }; + used_iter.set(used_bitmap_idx, false); + self.keys[id.0 as usize] = None; + let mut owned = None; + ::core::mem::swap(&mut self.list[id.0 as usize], &mut owned); + owned + } + pub(crate) fn add( + &mut self, + priv_key: PrivKey, + pub_key: PubKey, + service_id: ServiceID, + service_conn_id: IDRecv, + connection: Connection, + answer: oneshot::Sender, + srv_key_id: KeyID, + ) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender> + { + let maybe_free_key_idx = + self.used.iter().enumerate().find_map(|(idx, bmap)| { + match bmap.first_false_index() { + Some(false_idx) => Some(((idx * 1024), false_idx)), + None => None, + } + }); + let free_key_idx = match maybe_free_key_idx { + Some((idx, false_idx)) => { + let free_key_idx = idx * 1024 + false_idx; + if free_key_idx > KeyID::MAX as usize { + return Err(answer); + } + self.used[idx].set(false_idx, true); + free_key_idx + } + None => { + let mut bmap = ::bitmaps::Bitmap::<1024>::new(); + bmap.set(0, true); + self.used.push(bmap); + self.used.len() * 1024 + } + }; + if self.keys.len() >= free_key_idx { + self.keys.push(None); + self.list.push(None); + } + self.keys[free_key_idx] = Some((priv_key, pub_key)); + self.list[free_key_idx] = Some(HandshakeClient { + service_id, + service_conn_id, + connection, + timeout: None, + answer, + srv_key_id, + }); + Ok(( + KeyID(free_key_idx as u16), + self.list[free_key_idx].as_mut().unwrap(), + )) + } +} +/// Information needed to reply after the key exchange +#[derive(Debug, Clone)] +pub(crate) struct AuthNeededInfo { + /// Parsed handshake packet + pub handshake: Handshake, + /// hkdf generated from the handshake + pub hkdf: Hkdf, +} + +/// Client information needed to fully establish the conenction +#[derive(Debug)] +pub(crate) struct ClientConnectInfo { + /// The service ID that we are connecting to + pub service_id: ServiceID, + /// The service ID that we are connecting to + pub service_connection_id: IDRecv, + /// Parsed handshake packet + pub handshake: Handshake, + /// Connection + pub connection: Connection, + /// where to wake up the waiting client + pub answer: oneshot::Sender, + /// server public key id that we used on the handshake + pub srv_key_id: KeyID, +} +/// Intermediate actions to be taken while parsing the handshake +#[derive(Debug)] +pub(crate) enum HandshakeAction { + /// Parsing finished, all ok, nothing to do + Nothing, + /// Packet parsed, now go perform authentication + AuthNeeded(AuthNeededInfo), + /// the client can fully establish a connection with this info + ClientConnect(ClientConnectInfo), +} + +/// Tracking of handhsakes and conenctions +/// Note that we have multiple Handshake trackers, pinned to different cores +/// Each of them will handle a subset of all handshakes. +/// Each handshake is routed to a different tracker by checking +/// core = (udp_src_sender_port % total_threads) - 1 +pub(crate) struct HandshakeTracker { + thread_id: ThreadTracker, + key_exchanges: Vec, + ciphers: Vec, + /// ephemeral keys used server side in key exchange + keys_srv: Vec, + /// ephemeral keys used client side in key exchange + hshake_cli: HandshakeClientList, +} + +impl HandshakeTracker { + pub(crate) fn new( + thread_id: ThreadTracker, + ciphers: Vec, + key_exchanges: Vec, + ) -> Self { + Self { + thread_id, + ciphers, + key_exchanges, + keys_srv: Vec::new(), + hshake_cli: HandshakeClientList::new(), + } + } + pub(crate) fn add_server_key( + &mut self, + id: KeyID, + key: PrivKey, + ) -> Result<(), ()> { + if self.keys_srv.iter().find(|&k| k.id == id).is_some() { + return Err(()); + } + self.keys_srv.push(HandshakeServer { + id, + key, + domains: Vec::new(), + }); + self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0)); + Ok(()) + } + pub(crate) fn add_server_domain( + &mut self, + domain: &Domain, + key_ids: &[KeyID], + ) -> Result<(), ()> { + // check that all the key ids are present + for id in key_ids.iter() { + if self.keys_srv.iter().find(|k| k.id == *id).is_none() { + return Err(()); + } + } + // add the domain to those keys + for id in key_ids.iter() { + if let Some(srv) = self.keys_srv.iter_mut().find(|k| k.id == *id) { + srv.domains.push(domain.clone()); + } + } + Ok(()) + } + + pub(crate) fn add_client( + &mut self, + priv_key: PrivKey, + pub_key: PubKey, + service_id: ServiceID, + service_conn_id: IDRecv, + connection: Connection, + answer: oneshot::Sender, + srv_key_id: KeyID, + ) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender> + { + self.hshake_cli.add( + priv_key, + pub_key, + service_id, + service_conn_id, + connection, + answer, + srv_key_id, + ) + } + pub(crate) fn remove_client( + &mut self, + key_id: KeyID, + ) -> Option { + self.hshake_cli.remove(key_id) + } + pub(crate) fn timeout_client( + &mut self, + key_id: KeyID, + ) -> Option<[IDRecv; 2]> { + if let Some(hshake) = self.hshake_cli.remove(key_id) { + let _ = hshake.answer.send(Err(Error::Timeout.into())); + Some([hshake.connection.id_recv, hshake.service_conn_id]) + } else { + None + } + } + pub(crate) fn recv_handshake( + &mut self, + mut handshake: Handshake, + handshake_raw: &mut [u8], + ) -> Result { + use connection::handshake::{dirsync::DirSync, HandshakeData}; + match handshake.data { + HandshakeData::DirSync(ref mut ds) => match ds { + DirSync::Req(ref mut req) => { + if !self.key_exchanges.contains(&req.exchange) { + return Err(enc::Error::UnsupportedKeyExchange.into()); + } + if !self.ciphers.contains(&req.cipher) { + return Err(enc::Error::UnsupportedCipher.into()); + } + let has_key = self.keys_srv.iter().find(|k| { + if k.id == req.key_id { + // Directory synchronized can only use keys + // for key exchange, not signing keys + if let PrivKey::Exchange(_) = k.key { + return true; + } + } + false + }); + + let ephemeral_key; + match has_key { + Some(s_k) => { + if let PrivKey::Exchange(ref k) = &s_k.key { + ephemeral_key = k; + } else { + unreachable!(); + } + } + None => { + return Err(handshake::Error::UnknownKeyID.into()) + } + } + let shared_key = match ephemeral_key + .key_exchange(req.exchange, req.exchange_key) + { + Ok(shared_key) => shared_key, + Err(e) => return Err(handshake::Error::Key(e).into()), + }; + let hkdf = Hkdf::new(HkdfKind::Sha3, b"fenrir", shared_key); + let secret_recv = hkdf.get_secret(b"to_server"); + let cipher_recv = CipherRecv::new(req.cipher, secret_recv); + use crate::enc::sym::AAD; + let aad = AAD(&mut []); // no aad for now + + let encrypt_from = req.encrypted_offset(); + let encrypt_to = encrypt_from + + req.encrypted_length( + cipher_recv.nonce_len(), + cipher_recv.tag_len(), + ); + match cipher_recv.decrypt( + aad, + &mut handshake_raw[encrypt_from..encrypt_to], + ) { + Ok(cleartext) => { + req.data.deserialize_as_cleartext(cleartext)?; + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + + return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { + handshake, + hkdf, + })); + } + DirSync::Resp(resp) => { + let hshake = match self.hshake_cli.get(resp.client_key_id) { + Some(hshake) => hshake, + None => { + ::tracing::debug!( + "No such client key id: {:?}", + resp.client_key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + }; + let cipher_recv = &hshake.connection.cipher_recv; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + let data_from = resp.encrypted_offset(); + let data_to = data_from + + resp.encrypted_length( + cipher_recv.nonce_len(), + cipher_recv.tag_len(), + ); + let mut raw_data = &mut handshake_raw[data_from..data_to]; + match cipher_recv.decrypt(aad, &mut raw_data) { + Ok(cleartext) => { + resp.data.deserialize_as_cleartext(&cleartext)?; + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + let hshake = + self.hshake_cli.remove(resp.client_key_id).unwrap(); + if let Some(timeout) = hshake.timeout { + timeout.abort(); + } + return Ok(HandshakeAction::ClientConnect( + ClientConnectInfo { + service_id: hshake.service_id, + service_connection_id: hshake.service_conn_id, + handshake, + connection: hshake.connection, + answer: hshake.answer, + srv_key_id: hshake.srv_key_id, + }, + )); + } + }, + } + } +} diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 8db39c4..45a50e9 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -1,21 +1,36 @@ //! Connection handling and send/receive queues pub mod handshake; -mod packet; +pub mod packet; +pub mod socket; -use ::std::vec::Vec; +use ::std::{rc::Rc, vec::Vec}; -pub use handshake::Handshake; -pub use packet::ConnectionID as ID; -pub use packet::{Packet, PacketData}; - -use crate::enc::{ - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend}, +pub use crate::connection::{ + handshake::Handshake, + packet::{ConnectionID as ID, Packet, PacketData}, }; +use crate::{ + dnssec, + enc::{ + asym::PubKey, + hkdf::Hkdf, + sym::{CipherKind, CipherRecv, CipherSend}, + Random, + }, + inner::ThreadTracker, +}; + +/// strong typedef for receiving connection id +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct IDRecv(pub ID); +/// strong typedef for sending connection id +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct IDSend(pub ID); + /// Version of the fenrir protocol in use -#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)] +#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone, PartialEq)] #[repr(u8)] pub enum ProtocolVersion { /// First Fenrir Protocol Version @@ -35,10 +50,12 @@ impl ProtocolVersion { /// A single connection and its data #[derive(Debug)] pub struct Connection { - /// Connection ID - pub id: ID, + /// Receiving Connection ID + pub id_recv: IDRecv, + /// Sending Connection ID + pub id_send: IDSend, /// The main hkdf used for all secrets in this connection - pub hkdf: HkdfSha3, + pub hkdf: Hkdf, /// Cipher for decrypting data pub cipher_recv: CipherRecv, /// Cipher for encrypting data @@ -59,10 +76,10 @@ pub enum Role { impl Connection { pub(crate) fn new( - hkdf: HkdfSha3, + hkdf: Hkdf, cipher: CipherKind, role: Role, - rand: &::ring::rand::SystemRandom, + rand: &Random, ) -> Self { let (secret_recv, secret_send) = match role { Role::Server => { @@ -72,14 +89,207 @@ impl Connection { (hkdf.get_secret(b"to_client"), hkdf.get_secret(b"to_server")) } }; - let mut cipher_recv = CipherRecv::new(cipher, secret_recv); - let mut cipher_send = CipherSend::new(cipher, secret_send, rand); + let cipher_recv = CipherRecv::new(cipher, secret_recv); + let cipher_send = CipherSend::new(cipher, secret_send, rand); Self { - id: ID::Handshake, + id_recv: IDRecv(ID::Handshake), + id_send: IDSend(ID::Handshake), hkdf, cipher_recv, cipher_send, } } } + +// PERF: Arc> loks a bit too much, need to find +// faster ways to do this +pub(crate) struct ConnList { + thread_id: ThreadTracker, + connections: Vec>>, + /// Bitmap to track which connection ids are used or free + ids_used: Vec<::bitmaps::Bitmap<1024>>, +} + +impl ConnList { + pub(crate) fn new(thread_id: ThreadTracker) -> Self { + let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + if thread_id.id == 0 { + // make sure we don't count the Handshake ID + bitmap_id.set(0, true); + } + const INITIAL_CAP: usize = 128; + let mut ret = Self { + thread_id, + connections: Vec::with_capacity(INITIAL_CAP), + ids_used: vec![bitmap_id], + }; + ret.connections.resize_with(INITIAL_CAP, || None); + ret + } + pub fn len(&self) -> usize { + let mut total: usize = 0; + for bitmap in self.ids_used.iter() { + total = total + bitmap.len() + } + total + } + /// Only *Reserve* a connection, + /// without actually tracking it in self.connections + pub(crate) fn reserve_first(&mut self) -> IDRecv { + // uhm... bad things are going on here: + // * id must be initialized, but only because: + // * rust does not understand that after the `!found` id is always + // initialized + // * `ID::new_u64` is really safe only with >0, but here it always is + // ...we should probably rewrite it in better, safer rust + let mut id_in_thread: usize = 0; + let mut found = false; + for (i, b) in self.ids_used.iter_mut().enumerate() { + match b.first_false_index() { + Some(idx) => { + b.set(idx, true); + id_in_thread = (i * 1024) + idx; + found = true; + break; + } + None => {} + } + } + if !found { + let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); + new_bitmap.set(0, true); + id_in_thread = self.ids_used.len() * 1024; + self.ids_used.push(new_bitmap); + } + // make sure we have enough space in self.connections + let curr_capacity = self.connections.capacity(); + if self.connections.capacity() <= id_in_thread { + // Fill with "None", assure 64 connections without reallocations + let multiple = 64 + curr_capacity - 1; + let new_capacity = multiple - (multiple % curr_capacity); + self.connections.resize_with(new_capacity, || None); + } + // calculate the actual connection ID + let actual_id = ((id_in_thread as u64) * (self.thread_id.total as u64)) + + (self.thread_id.id as u64); + let new_id = IDRecv(ID::new_u64(actual_id)); + new_id + } + /// NOTE: does NOT check if the connection has been previously reserved! + pub(crate) fn track(&mut self, conn: Rc) -> Result<(), ()> { + let conn_id = match conn.id_recv { + IDRecv(ID::Handshake) => { + return Err(()); + } + IDRecv(ID::ID(conn_id)) => conn_id, + }; + let id_in_thread: usize = + (conn_id.get() / (self.thread_id.total as u64)) as usize; + self.connections[id_in_thread] = Some(conn); + Ok(()) + } + pub(crate) fn remove(&mut self, id: IDRecv) { + if let IDRecv(ID::ID(raw_id)) = id { + let id_in_thread: usize = + (raw_id.get() / (self.thread_id.total as u64)) as usize; + let vec_index = id_in_thread / 1024; + let bitmask_index = id_in_thread % 1024; + if let Some(bitmask) = self.ids_used.get_mut(vec_index) { + bitmask.set(bitmask_index, false); + self.connections[id_in_thread] = None; + } + } + } +} + +/// return wether we already have a connection, we are waiting for one, or you +/// can start one +#[derive(Debug, Clone, Copy)] +pub(crate) enum Reservation { + /// we already have a connection. use this ID. + Present(IDSend), + /// we don't have a connection, but we are waiting for one to be established. + Waiting, + /// we have reserved a spot for your connection. + Reserved, +} + +enum MapEntry { + Present(IDSend), + Reserved, +} +use ::std::collections::HashMap; + +/// Link the public key of the authentication server to a connection id +/// so that we can reuse that connection to ask for more authentications +/// +/// Note that a server can have multiple public keys, +/// and the handshake will only ever verify one. +/// To avoid malicious publication fo keys that are not yours, +/// on connection we: +/// * reserve all public keys of the server +/// * wait for the connection to finish +/// * remove all those reservations, exept the one key that actually succeded +/// While searching, we return a connection ID if just one key is a match +// TODO: can we shard this per-core by hashing the pubkey? or domain? or...??? +// This needs a mutex and it will be our goeal to avoid any synchronization +pub(crate) struct AuthServerConnections { + conn_map: HashMap, +} + +impl AuthServerConnections { + pub(crate) fn new() -> Self { + Self { + conn_map: HashMap::with_capacity(32), + } + } + /// add an ID to the reserved spot, + /// and unlock the other pubkeys which have not been verified + pub(crate) fn add( + &mut self, + pubkey: &PubKey, + id: IDSend, + record: &dnssec::Record, + ) { + let _ = self.conn_map.insert(*pubkey, MapEntry::Present(id)); + for (_, pk) in record.public_keys.iter() { + if pk == pubkey { + continue; + } + let _ = self.conn_map.remove(pk); + } + } + /// remove a dropped connection + pub(crate) fn remove_reserved(&mut self, record: &dnssec::Record) { + for (_, pk) in record.public_keys.iter() { + let _ = self.conn_map.remove(pk); + } + } + /// remove a dropped connection + pub(crate) fn remove_conn(&mut self, pubkey: &PubKey) { + let _ = self.conn_map.remove(pubkey); + } + + /// each dnssec::Record has multiple Pubkeys. reserve and ID for them all. + /// later on, when `add` is called we will delete + /// those that have not actually benn used + pub(crate) fn get_or_reserve( + &mut self, + record: &dnssec::Record, + ) -> Reservation { + for (_, pk) in record.public_keys.iter() { + match self.conn_map.get(pk) { + None => {} + Some(MapEntry::Reserved) => return Reservation::Waiting, + Some(MapEntry::Present(id)) => { + return Reservation::Present(id.clone()) + } + } + } + for (_, pk) in record.public_keys.iter() { + let _ = self.conn_map.insert(*pk, MapEntry::Reserved); + } + Reservation::Reserved + } +} diff --git a/src/connection/packet.rs b/src/connection/packet.rs index b59e02a..925460c 100644 --- a/src/connection/packet.rs +++ b/src/connection/packet.rs @@ -1,6 +1,11 @@ // //! Raw packet handling, encryption, decryption, parsing +use crate::enc::{ + sym::{HeadLen, TagLen}, + Random, +}; + /// Fenrir Connection id /// 0 is special as it represents the handshake /// Connection IDs are to be considered u64 little endian @@ -31,8 +36,7 @@ impl ConnectionID { } } /// New random service ID - pub fn new_rand(rand: &::ring::rand::SystemRandom) -> Self { - use ::ring::rand::SecureRandom; + pub fn new_rand(rand: &Random) -> Self { let mut raw = [0; 8]; let mut num = 0; while num == 0 { @@ -54,11 +58,10 @@ impl ConnectionID { } /// write the ID to a buffer pub fn serialize(&self, out: &mut [u8]) { - assert!(out.len() == 8, "out buffer must be 8 bytes"); match self { - ConnectionID::Handshake => out[..].copy_from_slice(&[0; 8]), + ConnectionID::Handshake => out[..8].copy_from_slice(&[0; 8]), ConnectionID::ID(id) => { - out[..].copy_from_slice(&id.get().to_le_bytes()) + out[..8].copy_from_slice(&id.get().to_le_bytes()) } } } @@ -89,25 +92,41 @@ impl From<[u8; 8]> for ConnectionID { pub enum PacketData { /// A parsed handshake packet Handshake(super::Handshake), + /// Raw packet. we only have the connection ID and packet length + Raw(usize), } impl PacketData { /// total length of the data in bytes - pub fn len(&self) -> usize { + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { match self { - PacketData::Handshake(h) => h.len(), + PacketData::Handshake(h) => h.len(head_len, tag_len), + PacketData::Raw(len) => *len, } } /// serialize data into bytes /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { - assert!(self.len() == out.len(), "PacketData: wrong buffer length"); + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { + assert!( + self.len(head_len, tag_len) == out.len(), + "PacketData: wrong buffer length" + ); match self { - PacketData::Handshake(h) => h.serialize(out), + PacketData::Handshake(h) => h.serialize(head_len, tag_len, out), + PacketData::Raw(_) => { + ::tracing::error!("Tried to serialize a raw PacketData!"); + } } } } +const MIN_PACKET_BYTES: usize = 16; + /// Fenrir packet structure #[derive(Debug, Clone)] pub struct Packet { @@ -118,18 +137,36 @@ pub struct Packet { } impl Packet { + /// New recevied packet, yet unparsed + pub fn deserialize_id(raw: &[u8]) -> Result { + // TODO: proper min_packet length. 16 is too conservative. + if raw.len() < MIN_PACKET_BYTES { + return Err(()); + } + let raw_id: [u8; 8] = (raw[..8]).try_into().expect("unreachable"); + Ok(Packet { + id: raw_id.into(), + data: PacketData::Raw(raw.len()), + }) + } /// get the total length of the packet - pub fn len(&self) -> usize { - ConnectionID::len() + self.data.len() + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { + ConnectionID::len() + self.data.len(head_len, tag_len) } /// serialize packet into buffer /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { assert!( out.len() > ConnectionID::len(), "Packet: not enough buffer to serialize" ); self.id.serialize(&mut out[0..ConnectionID::len()]); - self.data.serialize(&mut out[ConnectionID::len()..]); + self.data + .serialize(head_len, tag_len, &mut out[ConnectionID::len()..]); } } diff --git a/src/connection/socket.rs b/src/connection/socket.rs new file mode 100644 index 0000000..abfc106 --- /dev/null +++ b/src/connection/socket.rs @@ -0,0 +1,140 @@ +//! Socket related types and functions + +use ::std::net::SocketAddr; +use ::tokio::{net::UdpSocket, task::JoinHandle}; + +/// Pair to easily track the socket and its async listening handle +pub type SocketTracker = (SocketAddr, JoinHandle<::std::io::Result<()>>); + +/// Strong typedef for a client socket address +#[derive(Debug, Copy, Clone)] +pub(crate) struct UdpClient(pub SocketAddr); +/// Strong typedef for a server socket address +#[derive(Debug, Copy, Clone)] +pub(crate) struct UdpServer(pub SocketAddr); + +/// Enable some common socket options. This is just the unsafe part +fn enable_sock_opt( + fd: ::std::os::fd::RawFd, + option: ::libc::c_int, + value: ::libc::c_int, +) -> ::std::io::Result<()> { + #[allow(unsafe_code)] + unsafe { + #[allow(trivial_casts)] + let val = &value as *const _ as *const ::libc::c_void; + let size = ::core::mem::size_of_val(&value) as ::libc::socklen_t; + // always clear the error bit before doing a new syscall + let _ = ::std::io::Error::last_os_error(); + let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); + if ret != 0 { + return Err(::std::io::Error::last_os_error()); + } + } + Ok(()) +} +/// Add an async udp listener +pub async fn bind_udp(addr: SocketAddr) -> ::std::io::Result { + // I know, kind of a mess. but I really wanted SO_REUSE{ADDR,PORT} and + // no-fragmenting stuff. + // I also did not want to load another library for this. + // feel free to simplify, + // especially if we can avoid libc and other libraries + // we currently use libc because it's a dependency of many other deps + + let fd: ::std::os::fd::RawFd = { + let domain = if addr.is_ipv6() { + ::libc::AF_INET6 + } else { + ::libc::AF_INET + }; + #[allow(unsafe_code)] + let tmp = unsafe { ::libc::socket(domain, ::libc::SOCK_DGRAM, 0) }; + let lasterr = ::std::io::Error::last_os_error(); + if tmp == -1 { + return Err(lasterr); + } + tmp.into() + }; + + if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1) { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } + if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1) { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } + + // We will do path MTU discovery by ourselves, + // always set the "don't fragment" bit + let res = if addr.is_ipv6() { + enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1) + } else { + // FIXME: linux only + enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO) + }; + if let Err(e) = res { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } + // manually convert rust SockAddr to C sockaddr + #[allow(unsafe_code, trivial_casts, trivial_numeric_casts)] + { + let bind_ret = match addr { + SocketAddr::V4(s4) => { + let ip4: u32 = (*s4.ip()).into(); + let bind_addr = ::libc::sockaddr_in { + sin_family: ::libc::AF_INET as u16, + sin_port: s4.port().to_be(), + sin_addr: ::libc::in_addr { s_addr: ip4 }, + sin_zero: [0; 8], + }; + unsafe { + let c_addr = + &bind_addr as *const _ as *const ::libc::sockaddr; + ::libc::bind(fd, c_addr, 16) + } + } + SocketAddr::V6(s6) => { + let ip6: [u8; 16] = (*s6.ip()).octets(); + let bind_addr = ::libc::sockaddr_in6 { + sin6_family: ::libc::AF_INET6 as u16, + sin6_port: s6.port().to_be(), + sin6_flowinfo: 0, + sin6_addr: ::libc::in6_addr { s6_addr: ip6 }, + sin6_scope_id: 0, + }; + unsafe { + let c_addr = + &bind_addr as *const _ as *const ::libc::sockaddr; + ::libc::bind(fd, c_addr, 24) + } + } + }; + let lasterr = ::std::io::Error::last_os_error(); + if bind_ret != 0 { + unsafe { + ::libc::close(fd); + } + return Err(lasterr); + } + } + + use ::std::os::fd::FromRawFd; + #[allow(unsafe_code)] + let std_sock = unsafe { ::std::net::UdpSocket::from_raw_fd(fd) }; + std_sock.set_nonblocking(true)?; + ::tracing::debug!("Listening udp sock: {}", std_sock.local_addr().unwrap()); + + Ok(UdpSocket::from_std(std_sock)?) +} diff --git a/src/dnssec/mod.rs b/src/dnssec/mod.rs index 912321d..d1128c1 100644 --- a/src/dnssec/mod.rs +++ b/src/dnssec/mod.rs @@ -7,6 +7,11 @@ use ::trust_dns_resolver::TokioAsyncResolver; pub mod record; pub use record::Record; +use crate::auth::Domain; + +#[cfg(test)] +mod tests; + /// Common errors for Dnssec setup and usage #[derive(::thiserror::Error, Debug)] pub enum Error { @@ -19,6 +24,9 @@ pub enum Error { /// Errors in establishing connection or connection drops #[error("nameserver connection: {0:?}")] NameserverConnection(String), + /// record is not valid base85 + #[error("invalid base85")] + InvalidBase85, } #[cfg(any( @@ -39,7 +47,7 @@ pub struct Dnssec { impl Dnssec { /// Spawn connections to DNS via TCP - pub async fn new(resolvers: &Vec) -> Result { + pub fn new(resolvers: &Vec) -> Result { // use a TCP connection to the DNS. // the records we need are big, will not fit in a UDP packet let resolv_conf_resolvers: Vec; @@ -88,10 +96,10 @@ impl Dnssec { } const TXT_RECORD_START: &str = "v=Fenrir1 "; /// Get the fenrir data for a domain - pub async fn resolv(&self, domain: &str) -> ::std::io::Result { + pub async fn resolv(&self, domain: &Domain) -> ::std::io::Result { use ::trust_dns_client::rr::Name; - let fqdn_str = "_fenrir.".to_owned() + domain; + let fqdn_str = "_fenrir.".to_owned() + &domain.0; ::tracing::debug!("Resolving: {}", fqdn_str); let fqdn = Name::from_utf8(&fqdn_str)?; let answers = self.resolver.txt_lookup(fqdn).await?; @@ -135,7 +143,7 @@ impl Dnssec { ::tracing::error!("Can't parse record: {}", e); Err(::std::io::Error::new( ::std::io::ErrorKind::InvalidData, - "Can't parse the record", + "Can't parse the record: ".to_owned() + &e.to_string(), )) } }; diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 5f5e09f..a995cea 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -3,136 +3,67 @@ //! //! Encoding and decoding in base85, RFC1924 //! -//! //! Basic encoding idea: -//! * 1 byte: half-bytes +//! * 1 byte: divided in two: //! * half: num of addresses //! * half: num of pubkeys +//! * 1 byte: divided in half: +//! * half: number of key exchanges +//! * half: number of Hkdfs +//! * 1 byte: divided in half: +//! * half: number of ciphers +//! * half: nothing +//! [ # list of pubkeys (max: 16) +//! * 2 byte: pubkey id +//! * 1 byte: pubkey length +//! * 1 byte: pubkey type +//! * Y bytes: pubkey +//! ] //! [ # list of addresses //! * 1 byte: bitfield //! * 0..1 ipv4/ipv6 //! * 2..4 priority (for failover) //! * 5..7 weight between priority -//! * 1 byte: public key id +//! * 1 byte: divided in half: +//! * half: num of public key indexes +//! * half: num of handshake ids //! * 2 bytes: UDP port +//! * [ HALF byte per public key idx ] (index on the list of public keys) +//! * [ 1 byte per handshake id ] //! * X bytes: IP //! ] -//! [ # list of pubkeys -//! * 1 byte: pubkey type -//! * 1 byte: pubkey id -//! * Y bytes: pubkey +//! [ # list of supported key exchanges +//! * 1 byte for each cipher +//! ] +//! [ # list of supported HDKFs +//! * 1 byte for each hkdf +//! ] +//! [ # list of supported ciphers +//! * 1 byte for each cipher //! ] +use crate::{ + connection::handshake::HandshakeID, + enc::{ + self, + asym::{KeyExchangeKind, KeyID, PubKey}, + hkdf::HkdfKind, + sym::CipherKind, + }, +}; use ::core::num::NonZeroU16; use ::num_traits::FromPrimitive; use ::std::{net::IpAddr, vec::Vec}; - /* * Public key data */ -/// Public Key ID -#[derive(Debug, Copy, Clone)] -pub struct PublicKeyID(u8); - -impl TryFrom<&str> for PublicKeyID { - type Error = ::std::io::Error; - fn try_from(raw: &str) -> Result { - if let Ok(id_u8) = raw.parse::() { - return Ok(PublicKeyID(id_u8)); - } - return Err(::std::io::Error::new( - ::std::io::ErrorKind::InvalidData, - "Public Key ID must be between 0 and 256", - )); - } -} - -/// Public Key Type -#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)] -// public enum: use non_exhaustive to force users to add a default case -// so in the future we can expand this easily -#[non_exhaustive] -#[repr(u8)] -pub enum PublicKeyType { - /// ed25519 asymmetric key - Ed25519 = 0, - /// Ephemeral X25519 (Curve25519) key. - /// Used in the directory synchronized handshake - X25519, -} - -impl PublicKeyType { - /// Get the size of a public key of this kind - pub fn key_len(&self) -> usize { - match &self { - PublicKeyType::Ed25519 => 32, - PublicKeyType::X25519 => 32, // FIXME: hopefully... - } - } -} - -impl TryFrom<&str> for PublicKeyType { - type Error = ::std::io::Error; - fn try_from(raw: &str) -> Result { - if let Ok(type_u8) = raw.parse::() { - if let Some(kind) = PublicKeyType::from_u8(type_u8) { - return Ok(kind); - } - } - return Err(::std::io::Error::new( - ::std::io::ErrorKind::InvalidData, - "Public Key Type 0 is the only one supported", - )); - } -} - -/// Public Key, with its type and id -#[derive(Debug, Clone)] -pub struct PublicKey { - /// public key raw data - pub raw: Vec, - /// type of public key - pub kind: PublicKeyType, - /// id of public key - pub id: PublicKeyID, -} - -impl PublicKey { - fn raw_len(&self) -> usize { - let size = 2; // Public Key Type + ID - size + self.raw.len() - } - fn encode_into(&self, raw: &mut Vec) { - raw.push(self.kind as u8); - raw.push(self.id.0); - raw.extend_from_slice(&self.raw); - } - fn decode_raw(raw: &[u8]) -> Result<(Self, usize), Error> { - if raw.len() < 4 { - return Err(Error::NotEnoughData(0)); - } - - let kind = PublicKeyType::from_u8(raw[0]).unwrap(); - let id = PublicKeyID(raw[1]); - if raw.len() < 2 + kind.key_len() { - return Err(Error::NotEnoughData(2)); - } - - let mut raw_key = Vec::with_capacity(kind.key_len()); - let total_length = 2 + kind.key_len(); - raw_key.extend_from_slice(&raw[2..total_length]); - - Ok(( - Self { - raw: raw_key, - kind, - id, - }, - total_length, - )) - } -} +/// Public Key Index. +/// this points to the index of the public key in the array of public keys. +/// needed to have one byte less in the list of public keys +/// supported by and address +#[derive(Debug, Copy, Clone, PartialEq)] +pub struct PubKeyIdx(pub u8); /* * Address data @@ -230,44 +161,13 @@ impl TryFrom<&str> for AddressWeight { } } -/// List of possible handshakes -#[derive(::num_derive::FromPrimitive, Debug, Clone, Copy)] -#[repr(u8)] -pub enum HandshakeID { - /// 1-RTT Directory synchronized handshake. Fast, no forward secrecy - DirectorySynchronized = 0, - /// 2-RTT Stateful exchange. Little DDos protection - Stateful, - /// 3-RTT stateless exchange. Forward secrecy and ddos protection - Stateless, -} - -impl TryFrom<&str> for HandshakeID { - type Error = ::std::io::Error; - // TODO: from actual names, not only numeric - fn try_from(raw: &str) -> Result { - if let Ok(handshake_u8) = raw.parse::() { - if handshake_u8 >= 1 { - if let Some(handshake) = HandshakeID::from_u8(handshake_u8 - 1) - { - return Ok(handshake); - } - } - } - return Err(::std::io::Error::new( - ::std::io::ErrorKind::InvalidData, - "Unknown handshake ID", - )); - } -} - /// Authentication server address information: /// * ip /// * udp port /// * priority /// * weight within priority /// * list of supported handshakes IDs -/// * list of public keys IDs +/// * list of public keys. Indexes in the Record.public_keys #[derive(Debug, Clone)] pub struct Address { /// Ip address of server, v4 or v6 @@ -282,20 +182,32 @@ pub struct Address { /// List of supported handshakes pub handshake_ids: Vec, /// Public key IDs used by this address - pub public_key_ids: Vec, + pub public_key_idx: Vec, } impl Address { - fn raw_len(&self) -> usize { - // UDP port + Priority + Weight + pubkey_len + handshake_len - let mut size = 6; - size = size + self.public_key_ids.len() + self.handshake_ids.len(); - size + match self.ip { - IpAddr::V4(_) => size + 4, - IpAddr::V6(_) => size + 16, + /// return this Address as a socket address + /// Note that since Fenrir can work on top of IP, without ports, + /// this is not guaranteed to return a SocketAddr + pub fn as_sockaddr(&self) -> Option<::std::net::SocketAddr> { + match self.port { + Some(port) => { + Some(::std::net::SocketAddr::new(self.ip, port.get())) + } + None => None, } } - fn encode_into(&self, raw: &mut Vec) { + fn len(&self) -> usize { + let mut size = 4; + let num_pubkey_idx = self.public_key_idx.len(); + let idx_bytes = (num_pubkey_idx / 2) + (num_pubkey_idx % 2); + size = size + idx_bytes + self.handshake_ids.len(); + size + match self.ip { + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 16, + } + } + fn serialize_into(&self, raw: &mut [u8]) { let mut bitfield: u8 = match self.ip { IpAddr::V4(_) => 0, IpAddr::V6(_) => 1, @@ -305,49 +217,75 @@ impl Address { bitfield <<= 3; bitfield |= self.weight as u8; - raw.push(bitfield); + raw[0] = bitfield; + let len_combined: u8 = self.public_key_idx.len() as u8; + let len_combined = len_combined << 4; + let len_combined = len_combined | self.handshake_ids.len() as u8; + raw[1] = len_combined; - raw.extend_from_slice( + raw[2..4].copy_from_slice( &(match self.port { Some(port) => port.get().to_le_bytes(), None => [0, 0], // oh noez, which zero goes first? }), ); + let mut written: usize = 4; - raw.push(self.public_key_ids.len() as u8); - for id in self.public_key_ids.iter() { - raw.push(id.0); + // pair every idx, since the max is 16 + for chunk in self.public_key_idx.chunks(2) { + let second = { + if chunk.len() == 2 { + chunk[1].0 + } else { + 0 + } + }; + let tmp = chunk[0].0 << 4; + let tmp = tmp | second; + raw[written] = tmp; + written = written + 1; + } + for id in self.handshake_ids.iter() { + raw[written] = *id as u8; + written = written + 1; } + let next_written; match self.ip { IpAddr::V4(ip) => { + next_written = written + 4; let raw_ip = ip.octets(); - raw.extend_from_slice(&raw_ip); + raw[written..next_written].copy_from_slice(&raw_ip); } IpAddr::V6(ip) => { + next_written = written + 16; let raw_ip = ip.octets(); - raw.extend_from_slice(&raw_ip); + raw[written..next_written].copy_from_slice(&raw_ip); } }; + assert!( + next_written == raw.len(), + "write how much? {} - {}", + next_written, + raw.len() + ); } fn decode_raw(raw: &[u8]) -> Result<(Self, usize), Error> { - if raw.len() < 10 { + if raw.len() < 9 { return Err(Error::NotEnoughData(0)); } + // 3-byte bitfield let ip_type = raw[0] >> 6; let is_ipv6: bool; - let total_length: usize; + let ip_len: usize; match ip_type { 0 => { is_ipv6 = false; - total_length = 8; + ip_len = 4; } 1 => { - total_length = 20; - if raw.len() < total_length { - return Err(Error::NotEnoughData(1)); - } - is_ipv6 = true + is_ipv6 = true; + ip_len = 16; } _ => return Err(Error::UnsupportedData(0)), } @@ -356,28 +294,42 @@ impl Address { let priority = AddressPriority::from_u8(raw_priority).unwrap(); let weight = AddressWeight::from_u8(raw_weight).unwrap(); - let raw_port = u16::from_le_bytes([raw[1], raw[2]]); + // Add publickey ids + let num_pubkey_idx = (raw[1] >> 4) as usize; + let num_handshake_ids = (raw[1] & 0x0F) as usize; - // Add publi key ids - let num_pubkey_ids = raw[3] as usize; - if raw.len() < 3 + num_pubkey_ids { + // UDP port + let raw_port = u16::from_le_bytes([raw[2], raw[3]]); + let port = if raw_port == 0 { + None + } else { + Some(NonZeroU16::new(raw_port).unwrap()) + }; + + if raw.len() <= 3 + num_pubkey_idx + num_handshake_ids + ip_len { return Err(Error::NotEnoughData(3)); } - let mut public_key_ids = Vec::with_capacity(num_pubkey_ids); - - for raw_pubkey_id in raw[4..num_pubkey_ids].iter() { - public_key_ids.push(PublicKeyID(*raw_pubkey_id)); + let mut bytes_parsed = 4; + let mut public_key_idx = Vec::with_capacity(num_pubkey_idx); + let idx_bytes = (num_pubkey_idx / 2) + (num_pubkey_idx % 2); + let mut idx_added = 0; + for raw_pubkey_idx_pair in + raw[bytes_parsed..(bytes_parsed + idx_bytes)].iter() + { + let first = PubKeyIdx(raw_pubkey_idx_pair >> 4); + let second = PubKeyIdx(raw_pubkey_idx_pair & 0x0F); + public_key_idx.push(first); + if num_pubkey_idx - idx_added >= 2 { + public_key_idx.push(second); + } + idx_added = idx_added + 2; } + bytes_parsed = bytes_parsed + idx_bytes; // add handshake ids - let next_ptr = 3 + num_pubkey_ids; - let num_handshake_ids = raw[next_ptr] as usize; - if raw.len() < next_ptr + num_handshake_ids { - return Err(Error::NotEnoughData(next_ptr)); - } let mut handshake_ids = Vec::with_capacity(num_handshake_ids); for raw_handshake_id in - raw[next_ptr..(next_ptr + num_pubkey_ids)].iter() + raw[bytes_parsed..(bytes_parsed + num_handshake_ids)].iter() { match HandshakeID::from_u8(*raw_handshake_id) { Some(h_id) => handshake_ids.push(h_id), @@ -390,26 +342,24 @@ impl Address { } } } - let next_ptr = next_ptr + num_pubkey_ids; + bytes_parsed = bytes_parsed + num_handshake_ids; - let port = if raw_port == 0 { - None - } else { - Some(NonZeroU16::new(raw_port).unwrap()) - }; let ip = if is_ipv6 { - let ip_end = next_ptr + 16; + let ip_end = bytes_parsed + 16; if raw.len() < ip_end { - return Err(Error::NotEnoughData(next_ptr)); + return Err(Error::NotEnoughData(bytes_parsed)); } - let raw_ip: [u8; 16] = raw[next_ptr..ip_end].try_into().unwrap(); + let raw_ip: [u8; 16] = + raw[bytes_parsed..ip_end].try_into().unwrap(); + bytes_parsed = bytes_parsed + 16; IpAddr::from(raw_ip) } else { - let ip_end = next_ptr + 4; + let ip_end = bytes_parsed + 4; if raw.len() < ip_end { - return Err(Error::NotEnoughData(next_ptr)); + return Err(Error::NotEnoughData(bytes_parsed)); } - let raw_ip: [u8; 4] = raw[next_ptr..ip_end].try_into().unwrap(); + let raw_ip: [u8; 4] = raw[bytes_parsed..ip_end].try_into().unwrap(); + bytes_parsed = bytes_parsed + 4; IpAddr::from(raw_ip) }; @@ -419,26 +369,32 @@ impl Address { port, priority, weight, - public_key_ids, + public_key_idx, handshake_ids, }, - total_length, + bytes_parsed, )) } } /* - * Actual record puuting it all toghether + * Actual record putting it all toghether */ /// All informations found in the DNSSEC record #[derive(Debug, Clone)] pub struct Record { /// Public keys used by any authentication server - pub public_keys: Vec, + pub public_keys: Vec<(KeyID, PubKey)>, /// List of all authentication servers' addresses. /// Multiple ones can point to the same authentication server pub addresses: Vec
, + /// List of supported key exchanges + pub key_exchanges: Vec, + /// List of supported key exchanges + pub hkdfs: Vec, + /// List of supported ciphers + pub ciphers: Vec, } impl Record { @@ -457,46 +413,134 @@ impl Record { if self.addresses.len() > 16 { return Err(Error::Max16Addresses); } + if self.key_exchanges.len() > 16 { + return Err(Error::Max16KeyExchanges); + } + if self.hkdfs.len() > 16 { + return Err(Error::Max16Hkdfs); + } + if self.ciphers.len() > 16 { + return Err(Error::Max16Ciphers); + } // everything else is all good - let total_size: usize = 1 - + self.addresses.iter().map(|a| a.raw_len()).sum::() - + self.public_keys.iter().map(|a| a.raw_len()).sum::(); + let total_size: usize = 3 + + self.addresses.iter().map(|a| a.len()).sum::() + + self + .public_keys + .iter() + .map(|(_, key)| 3 + key.kind().pub_len()) + .sum::() + + self.key_exchanges.len() + + self.hkdfs.len() + + self.ciphers.len(); let mut raw = Vec::with_capacity(total_size); + raw.resize(total_size, 0); // amount of data. addresses, then pubkeys. 4 bits each let len_combined: u8 = self.addresses.len() as u8; let len_combined = len_combined << 4; let len_combined = len_combined | self.public_keys.len() as u8; + raw[0] = len_combined; + // number of key exchanges and hkdfs + let len_combined: u8 = self.key_exchanges.len() as u8; + let len_combined = len_combined << 4; + let len_combined = len_combined | self.hkdfs.len() as u8; + raw[1] = len_combined; + let num_of_ciphers: u8 = (self.ciphers.len() as u8) << 4; + raw[2] = num_of_ciphers; - raw.push(len_combined); - - for address in self.addresses.iter() { - address.encode_into(&mut raw); + let mut written: usize = 3; + for (public_key_id, public_key) in self.public_keys.iter() { + let key_id_bytes = public_key_id.0.to_le_bytes(); + let written_next = written + KeyID::len(); + raw[written..written_next].copy_from_slice(&key_id_bytes); + written = written_next; + raw[written] = public_key.len() as u8; + written = written + 1; + let written_next = written + public_key.len(); + public_key.serialize_into(&mut raw[written..written_next]); + written = written_next; } - for public_key in self.public_keys.iter() { - public_key.encode_into(&mut raw); + for address in self.addresses.iter() { + let len = address.len(); + let written_next = written + len; + address.serialize_into(&mut raw[written..written_next]); + written = written_next; + } + for k_x in self.key_exchanges.iter() { + raw[written] = *k_x as u8; + written = written + 1; + } + for h in self.hkdfs.iter() { + raw[written] = *h as u8; + written = written + 1; + } + for c in self.ciphers.iter() { + raw[written] = *c as u8; + written = written + 1; } Ok(::base85::encode(&raw)) } /// Decode from base85 to the actual object pub fn decode(raw: &[u8]) -> Result { - // bare minimum for 1 address and key - const MIN_RAW_LENGTH: usize = 1 + 8 + 8; + // bare minimum for lengths, (1 address), (1 key),no cipher negotiation + const MIN_RAW_LENGTH: usize = 3 + (6 + 4) + (4 + 32); if raw.len() <= MIN_RAW_LENGTH { return Err(Error::NotEnoughData(0)); } let mut num_addresses = (raw[0] >> 4) as usize; let mut num_public_keys = (raw[0] & 0x0F) as usize; - let mut bytes_parsed = 1; + let mut num_key_exchanges = (raw[1] >> 4) as usize; + let mut num_hkdfs = (raw[1] & 0x0F) as usize; + let mut num_ciphers = (raw[2] >> 4) as usize; + let mut bytes_parsed = 3; let mut result = Self { addresses: Vec::with_capacity(num_addresses), public_keys: Vec::with_capacity(num_public_keys), + key_exchanges: Vec::with_capacity(num_key_exchanges), + hkdfs: Vec::with_capacity(num_hkdfs), + ciphers: Vec::with_capacity(num_ciphers), }; + while num_public_keys > 0 { + if bytes_parsed + 3 >= raw.len() { + return Err(Error::NotEnoughData(bytes_parsed)); + } + + let raw_key_id = + u16::from_le_bytes([raw[bytes_parsed], raw[bytes_parsed + 1]]); + let id = KeyID(raw_key_id); + bytes_parsed = bytes_parsed + KeyID::len(); + let pubkey_length = raw[bytes_parsed] as usize; + bytes_parsed = bytes_parsed + 1; + let bytes_next_key = bytes_parsed + pubkey_length; + if bytes_next_key > raw.len() { + return Err(Error::NotEnoughData(bytes_parsed)); + } + let (public_key, bytes) = + match PubKey::deserialize(&raw[bytes_parsed..bytes_next_key]) { + Ok((public_key, bytes)) => (public_key, bytes), + Err(enc::Error::UnsupportedKey(_)) => { + // continue parsing. This could be a new pubkey type + // that is not supported by an older client + bytes_parsed = bytes_parsed + pubkey_length; + continue; + } + Err(_) => { + return Err(Error::UnsupportedData(bytes_parsed)); + } + }; + if bytes != pubkey_length { + return Err(Error::UnsupportedData(bytes_parsed)); + } + bytes_parsed = bytes_parsed + bytes; + result.public_keys.push((id, public_key)); + num_public_keys = num_public_keys - 1; + } while num_addresses > 0 { let (address, bytes) = match Address::decode_raw(&raw[bytes_parsed..]) { @@ -513,21 +557,80 @@ impl Record { result.addresses.push(address); num_addresses = num_addresses - 1; } - while num_public_keys > 0 { - let (public_key, bytes) = - match PublicKey::decode_raw(&raw[bytes_parsed..]) { - Ok(public_key) => public_key, - Err(Error::UnsupportedData(b)) => { - return Err(Error::UnsupportedData(bytes_parsed + b)) - } - Err(Error::NotEnoughData(b)) => { - return Err(Error::NotEnoughData(bytes_parsed + b)) - } - Err(e) => return Err(e), - }; - bytes_parsed = bytes_parsed + bytes; - result.public_keys.push(public_key); - num_public_keys = num_public_keys - 1; + for addr in result.addresses.iter() { + for idx in addr.public_key_idx.iter() { + if idx.0 as usize >= result.public_keys.len() { + return Err(Error::Max16PublicKeys); + } + if !result.public_keys[idx.0 as usize] + .1 + .kind() + .capabilities() + .has_exchange() + { + return Err(Error::UnsupportedData(bytes_parsed)); + } + } + } + if bytes_parsed + num_key_exchanges + num_hkdfs + num_ciphers + != raw.len() + { + return Err(Error::NotEnoughData(bytes_parsed)); + } + while num_key_exchanges > 0 { + let key_exchange = match KeyExchangeKind::from_u8(raw[bytes_parsed]) + { + Some(key_exchange) => key_exchange, + None => { + // continue parsing. This could be a new key exchange type + // that is not supported by an older client + ::tracing::warn!( + "Unknown Key exchange {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.key_exchanges.push(key_exchange); + num_key_exchanges = num_key_exchanges - 1; + } + while num_hkdfs > 0 { + let hkdf = match HkdfKind::from_u8(raw[bytes_parsed]) { + Some(hkdf) => hkdf, + None => { + // continue parsing. This could be a new hkdf type + // that is not supported by an older client + ::tracing::warn!( + "Unknown hkdf {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.hkdfs.push(hkdf); + num_hkdfs = num_hkdfs - 1; + } + while num_ciphers > 0 { + let cipher = match CipherKind::from_u8(raw[bytes_parsed]) { + Some(cipher) => cipher, + None => { + // continue parsing. This could be a new cipher type + // that is not supported by an older client + ::tracing::warn!( + "Unknown Cipher {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.ciphers.push(cipher); + num_ciphers = num_ciphers - 1; } if bytes_parsed != raw.len() { Err(Error::UnknownData(bytes_parsed)) @@ -549,6 +652,15 @@ pub enum Error { /// Too many addresses (max 16) #[error("can't encode more than 16 addresses")] Max16Addresses, + /// Too many key exchanges (max 16) + #[error("can't encode more than 16 key exchanges")] + Max16KeyExchanges, + /// Too many Hkdfs (max 16) + #[error("can't encode more than 16 Hkdfs")] + Max16Hkdfs, + /// Too many ciphers (max 16) + #[error("can't encode more than 16 Ciphers")] + Max16Ciphers, /// We need at least one public key #[error("no public keys found")] NoPublicKeyFound, diff --git a/src/dnssec/tests.rs b/src/dnssec/tests.rs new file mode 100644 index 0000000..ff450ae --- /dev/null +++ b/src/dnssec/tests.rs @@ -0,0 +1,59 @@ +use super::*; + +#[test] +fn test_dnssec_serialization() { + let rand = enc::Random::new(); + let (_, exchange_key) = + match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand) + { + Ok(pair) => pair, + Err(_) => { + assert!(false, "Can't generate random keypair"); + return; + } + }; + use crate::{connection::handshake::HandshakeID, enc}; + + let record = Record { + public_keys: [( + enc::asym::KeyID(42), + enc::asym::PubKey::Exchange(exchange_key), + )] + .to_vec(), + addresses: [record::Address { + ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)), + port: Some(::core::num::NonZeroU16::new(31337).unwrap()), + priority: record::AddressPriority::P1, + weight: record::AddressWeight::W1, + handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(), + public_key_idx: [record::PubKeyIdx(0)].to_vec(), + }] + .to_vec(), + key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman] + .to_vec(), + hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), + ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), + }; + let encoded = match record.encode() { + Ok(encoded) => encoded, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + let full_record = "v=Fenrir1 ".to_string() + &encoded; + let record = match Dnssec::parse_txt_record(&full_record) { + Ok(record) => record, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + let _re_encoded = match record.encode() { + Ok(re_encoded) => re_encoded, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; +} diff --git a/src/enc/asym.rs b/src/enc/asym.rs index c705a5f..47a3b5a 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -1,10 +1,12 @@ //! Asymmetric key handling and wrappers use ::num_traits::FromPrimitive; -use ::std::vec::Vec; use super::Error; -use crate::enc::sym::Secret; +use crate::{ + config::Config, + enc::{Random, Secret}, +}; /// Public key ID #[derive(Debug, Copy, Clone, PartialEq)] @@ -15,115 +17,392 @@ impl KeyID { pub const fn len() -> usize { 2 } + /// Maximum possible KeyID + pub const MAX: u16 = u16::MAX; /// Serialize into raw bytes pub fn serialize(&self, out: &mut [u8; KeyID::len()]) { out.copy_from_slice(&self.0.to_le_bytes()); } } -/// Kind of key used in the handshake -#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] -#[repr(u8)] -pub enum Key { - /// X25519 Public key - X25519 = 0, +impl TryFrom<&str> for KeyID { + type Error = ::std::io::Error; + fn try_from(raw: &str) -> Result { + if let Ok(id_u16) = raw.parse::() { + return Ok(KeyID(id_u16)); + } + return Err(::std::io::Error::new( + ::std::io::ErrorKind::InvalidData, + "KeyID must be between 0 and 65535", + )); + } } -impl Key { - fn pub_len(&self) -> usize { +impl ::std::fmt::Display for KeyID { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Capabilities of each key +#[derive(Debug, Clone, Copy)] +pub enum KeyCapabilities { + /// signing *only* + Sign, + /// encrypt *only* + Encrypt, + /// key exchange *only* + Exchange, + /// both sign and encrypt + SignEncrypt, + /// both signing and key exchange + SignExchange, + /// both encrypt and key exchange + EncryptExchange, + /// All: sign, encrypt, Key Exchange + SignEncryptExchage, +} +impl KeyCapabilities { + /// Check if this key supports eky exchage + pub fn has_exchange(&self) -> bool { match self { - // FIXME: 99% wrong size - Key::X25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, + KeyCapabilities::Exchange + | KeyCapabilities::SignExchange + | KeyCapabilities::SignEncryptExchage => true, + _ => false, } } } -/// Kind of key exchange -#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +/// Kind of key used in the handshake +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] +#[non_exhaustive] #[repr(u8)] -pub enum KeyExchange { - /// X25519 Public key - X25519DiffieHellman = 0, +pub enum KeyKind { + /// Ed25519 Public key (sign only) + #[strum(serialize = "ed25519")] + Ed25519 = 0, + /// X25519 Public key (key exchange) + #[strum(serialize = "x25519")] + X25519, } -impl KeyExchange { - /// The serialize length of the field - pub fn len() -> usize { +impl KeyKind { + /// Length of the serialized field + pub const fn len() -> usize { 1 } + /// return the expected length of the public key + pub fn pub_len(&self) -> usize { + KeyKind::len() + + match self { + // FIXME: 99% wrong size + KeyKind::Ed25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, + KeyKind::X25519 => 32, + } + } + /// Get the capabilities of this key type + pub fn capabilities(&self) -> KeyCapabilities { + match self { + KeyKind::Ed25519 => KeyCapabilities::Sign, + KeyKind::X25519 => KeyCapabilities::Exchange, + } + } + /// Returns the key exchanges supported by this key + pub fn key_exchanges(&self) -> &'static [KeyExchangeKind] { + const EMPTY: [KeyExchangeKind; 0] = []; + const X25519_KEY_EXCHANGES: [KeyExchangeKind; 1] = + [KeyExchangeKind::X25519DiffieHellman]; + match self { + KeyKind::Ed25519 => &EMPTY, + KeyKind::X25519 => &X25519_KEY_EXCHANGES, + } + } + /// generate new keypair + pub fn new_keypair( + &self, + rnd: &Random, + ) -> Result<(PrivKey, PubKey), Error> { + PubKey::new_keypair(*self, rnd) + } } -/// Kind of key in the handshake +/// Kind of key exchange +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] +#[non_exhaustive] +#[repr(u8)] +pub enum KeyExchangeKind { + /// X25519 Public key + #[strum(serialize = "x25519diffiehellman")] + X25519DiffieHellman = 0, +} +impl KeyExchangeKind { + /// The serialize length of the field + pub const fn len() -> usize { + 1 + } + /// Build a new keypair for key exchange + pub fn new_keypair( + &self, + rnd: &Random, + ) -> Result<(ExchangePrivKey, ExchangePubKey), Error> { + match self { + KeyExchangeKind::X25519DiffieHellman => { + let raw_priv = ::x25519_dalek::StaticSecret::new(rnd); + let pub_key = ExchangePubKey::X25519( + ::x25519_dalek::PublicKey::from(&raw_priv), + ); + let priv_key = ExchangePrivKey::X25519(raw_priv); + Ok((priv_key, pub_key)) + } + } + } +} + +/// Kind of public key in the handshake +#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)] +#[allow(missing_debug_implementations)] +#[non_exhaustive] +pub enum PubKey { + /// Keys to be used only in key exchanges, not for signing + Exchange(ExchangePubKey), + /// Keys to be used only for signing + Signing, +} + +impl PubKey { + /// Get the serialized key length + pub fn len(&self) -> usize { + match self { + PubKey::Exchange(ex) => ex.len(), + PubKey::Signing => todo!(), + } + } + /// return the kind of public key + pub fn kind(&self) -> KeyKind { + match self { + // FIXME: lie, we don't fully support this + PubKey::Signing => KeyKind::Ed25519, + PubKey::Exchange(ex) => ex.kind(), + } + } + /// generate new keypair + fn new_keypair( + kind: KeyKind, + rnd: &Random, + ) -> Result<(PrivKey, PubKey), Error> { + match kind { + KeyKind::Ed25519 => todo!(), + KeyKind::X25519 => { + let (priv_key, pub_key) = + KeyExchangeKind::X25519DiffieHellman.new_keypair(rnd)?; + Ok((PrivKey::Exchange(priv_key), PubKey::Exchange(pub_key))) + } + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + pub fn serialize_into(&self, out: &mut [u8]) { + match self { + PubKey::Signing => { + ::tracing::error!("serializing ed25519 not supported"); + return; + } + PubKey::Exchange(ex) => ex.serialize_into(out), + } + } + /// Try to deserialize the pubkey from raw bytes + /// on success returns the public key and the number of parsed bytes + pub fn deserialize(raw: &[u8]) -> Result<(Self, usize), Error> { + if raw.len() < 1 { + return Err(Error::NotEnoughData(0)); + } + let kind: KeyKind = match KeyKind::from_u8(raw[0]) { + Some(kind) => kind, + None => return Err(Error::UnsupportedKey(1)), + }; + if raw.len() < kind.pub_len() { + return Err(Error::NotEnoughData(1)); + } + match kind { + KeyKind::Ed25519 => { + ::tracing::error!("ed25519 keys are not yet supported"); + return Err(Error::Parsing); + } + KeyKind::X25519 => { + let pub_key: ::x25519_dalek::PublicKey = + //match ::bincode::deserialize(&raw[1..(1 + kind.pub_len())]) + match ::bincode::deserialize(&raw[1..]) + { + Ok(pub_key) => pub_key, + Err(e) => { + ::tracing::error!("x25519 deserialize: {}", e); + return Err(Error::Parsing); + } + }; + Ok(( + PubKey::Exchange(ExchangePubKey::X25519(pub_key)), + kind.pub_len(), + )) + } + } + } +} + +/// Kind of private key in the handshake #[derive(Clone)] #[allow(missing_debug_implementations)] +#[non_exhaustive] pub enum PrivKey { /// Keys to be used only in key exchanges, not for signing Exchange(ExchangePrivKey), /// Keys to be used only for signing + // TODO: implement ed25519 Signing, } +impl PrivKey { + /// Get the serialized key length + pub fn len(&self) -> usize { + match self { + PrivKey::Exchange(ex) => ex.len(), + PrivKey::Signing => todo!(), + } + } + /// return the kind of public key + pub fn kind(&self) -> KeyKind { + match self { + PrivKey::Signing => todo!(), + PrivKey::Exchange(ex) => ex.kind(), + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + pub fn serialize_into(&self, out: &mut [u8]) { + match self { + PrivKey::Exchange(ex) => ex.serialize_into(out), + PrivKey::Signing => todo!(), + } + } +} +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for PrivKey { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden privkey]", f) + } +} + /// Ephemeral private keys #[derive(Clone)] #[allow(missing_debug_implementations)] +#[non_exhaustive] pub enum ExchangePrivKey { /// X25519(Curve25519) used for key exchange X25519(::x25519_dalek::StaticSecret), } impl ExchangePrivKey { - /// Get the kind of key - pub fn kind(&self) -> Key { + /// Get the serialized key length + pub fn len(&self) -> usize { match self { - ExchangePrivKey::X25519(_) => Key::X25519, + ExchangePrivKey::X25519(_) => KeyKind::X25519.pub_len(), + } + } + /// Get the kind of key + pub fn kind(&self) -> KeyKind { + match self { + ExchangePrivKey::X25519(_) => KeyKind::X25519, } } /// Run the key exchange between two keys of the same kind pub fn key_exchange( &self, - exchange: KeyExchange, + exchange: KeyExchangeKind, pub_key: ExchangePubKey, ) -> Result { match self { ExchangePrivKey::X25519(priv_key) => { - if exchange != KeyExchange::X25519DiffieHellman { + if exchange != KeyExchangeKind::X25519DiffieHellman { return Err(Error::UnsupportedKeyExchange); } - if let ExchangePubKey::X25519(inner_pub_key) = pub_key { - let shared_secret = priv_key.diffie_hellman(&inner_pub_key); - Ok(shared_secret.into()) - } else { - Err(Error::UnsupportedKeyExchange) - } + let ExchangePubKey::X25519(inner_pub_key) = pub_key; + let shared_secret = priv_key.diffie_hellman(&inner_pub_key); + Ok(shared_secret.into()) + } + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + pub fn serialize_into(&self, out: &mut [u8]) { + out[0] = self.kind() as u8; + match self { + ExchangePrivKey::X25519(key) => { + out[1..33].copy_from_slice(&key.to_bytes()); } } } } /// all Ephemeral Public keys -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)] +#[non_exhaustive] pub enum ExchangePubKey { /// X25519(Curve25519) used for key exchange X25519(::x25519_dalek::PublicKey), } impl ExchangePubKey { - /// length of the public key used for key exchange + /// Get the serialized key length pub fn len(&self) -> usize { match self { - ExchangePubKey::X25519(_) => 32, + ExchangePubKey::X25519(_) => KeyKind::X25519.pub_len(), + } + } + /// Get the kind of key + pub fn kind(&self) -> KeyKind { + match self { + ExchangePubKey::X25519(_) => KeyKind::X25519, + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + pub fn serialize_into(&self, out: &mut [u8]) { + out[0] = self.kind() as u8; + match self { + ExchangePubKey::X25519(pk) => { + let bytes = pk.as_bytes(); + out[1..33].copy_from_slice(bytes); + } } } /// Load public key used for key exchange from it raw bytes /// The riesult is "unparsed" since we don't verify /// the actual key - pub fn from_slice(raw: &[u8]) -> Result<(Self, usize), Error> { - // FIXME: get *real* minimum key size - const MIN_KEY_SIZE: usize = 32; - if raw.len() < 1 + MIN_KEY_SIZE { - return Err(Error::NotEnoughData); - } - match Key::from_u8(raw[0]) { + pub fn deserialize(raw: &[u8]) -> Result<(Self, usize), Error> { + match KeyKind::from_u8(raw[0]) { Some(kind) => match kind { - Key::X25519 => { + KeyKind::Ed25519 => { + ::tracing::error!("ed25519 keys are not yet supported"); + return Err(Error::Parsing); + } + KeyKind::X25519 => { let pub_key: ::x25519_dalek::PublicKey = match ::bincode::deserialize( &raw[1..(1 + kind.pub_len())], @@ -140,3 +419,29 @@ impl ExchangePubKey { } } } + +/// Select the best key exchange from our supported list +/// and the other endpoint supported list. +/// Give priority to our list +pub fn server_select_key_exchange( + cfg: &Config, + client_supported: &Vec, +) -> Option { + cfg.key_exchanges + .iter() + .find(|k| client_supported.contains(k)) + .copied() +} +/// Select the best key exchange from our supported list +/// and the other endpoint supported list. +/// Give priority to the server list +/// This is used only in the Directory Synchronized handshake +pub fn client_select_key_exchange( + cfg: &Config, + server_supported: &Vec, +) -> Option { + server_supported + .iter() + .find(|k| cfg.key_exchanges.contains(k)) + .copied() +} diff --git a/src/enc/errors.rs b/src/enc/errors.rs index ded1fc0..a394406 100644 --- a/src/enc/errors.rs +++ b/src/enc/errors.rs @@ -7,15 +7,14 @@ pub enum Error { #[error("can't parse key")] Parsing, /// Not enough data - #[error("not enough data")] - NotEnoughData, + #[error("not enough data: {0}")] + NotEnoughData(usize), /// buffer too small #[error("buffer too small")] InsufficientBuffer, - /// Wrong Key type found. - /// You might have passed rsa keys where x25519 was expected - #[error("wrong key type")] - WrongKey, + /// Unsupported Key type found. + #[error("unsupported key type: {0}")] + UnsupportedKey(usize), /// Unsupported key exchange for this key #[error("unsupported key exchange")] UnsupportedKeyExchange, diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index bb6ca59..e52e236 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -1,23 +1,86 @@ //! Hash-based Key Derivation Function //! We just repackage other crates -use ::hkdf::Hkdf; use ::sha3::Sha3_256; use ::zeroize::Zeroize; -use crate::enc::sym::Secret; +use crate::{config::Config, enc::Secret}; + +/// Kind of HKDF +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] +#[non_exhaustive] +#[repr(u8)] +pub enum HkdfKind { + /// Sha3 + #[strum(serialize = "sha3")] + Sha3 = 0, +} +impl HkdfKind { + /// Length of the serialized type + pub const fn len() -> usize { + 1 + } +} + +/// Generic wrapper on Hkdfs +#[derive(Clone)] +pub enum Hkdf { + /// Sha3 based + Sha3(HkdfSha3), +} + +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for Hkdf { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden hkdf]", f) + } +} + +impl Hkdf { + /// New Hkdf + pub fn new(kind: HkdfKind, salt: &[u8], key: Secret) -> Self { + match kind { + HkdfKind::Sha3 => Self::Sha3(HkdfSha3::new(salt, key)), + } + } + /// Get a secret generated from the key and a given context + pub fn get_secret(&self, context: &[u8]) -> Secret { + match self { + Hkdf::Sha3(sha3) => sha3.get_secret(context), + } + } + /// get the kind of this Hkdf + pub fn kind(&self) -> HkdfKind { + match self { + Hkdf::Sha3(_) => HkdfKind::Sha3, + } + } +} // Hack & tricks: // HKDF are pretty important, but this lib don't zero out the data. // we can't use #[derive(Zeroing)] either. // So we craete a union with a Zeroing object, and drop both manually. +// TODO: move this to Hkdf instead of Sha3 + #[derive(Zeroize)] #[zeroize(drop)] -struct Zeroable([u8; ::core::mem::size_of::>()]); +struct Zeroable([u8; ::core::mem::size_of::<::hkdf::Hkdf>()]); union HkdfInner { - hkdf: ::core::mem::ManuallyDrop>, + hkdf: ::core::mem::ManuallyDrop<::hkdf::Hkdf>, zeroable: ::core::mem::ManuallyDrop, } @@ -49,23 +112,20 @@ pub struct HkdfSha3 { impl HkdfSha3 { /// Instantiate a new HKDF with Sha3-256 - pub fn new(salt: &[u8], key: Secret) -> Self { - let hkdf = Hkdf::::new(Some(salt), key.as_ref()); - #[allow(unsafe_code)] - unsafe { - Self { - inner: HkdfInner { - hkdf: ::core::mem::ManuallyDrop::new(hkdf), - }, - } + pub(crate) fn new(salt: &[u8], key: Secret) -> Self { + let hkdf = ::hkdf::Hkdf::::new(Some(salt), key.as_ref()); + Self { + inner: HkdfInner { + hkdf: ::core::mem::ManuallyDrop::new(hkdf), + }, } } /// Get a secret generated from the key and a given context - pub fn get_secret(&self, context: &[u8]) -> Secret { + pub(crate) fn get_secret(&self, context: &[u8]) -> Secret { let mut out: [u8; 32] = [0; 32]; #[allow(unsafe_code)] unsafe { - self.inner.hkdf.expand(context, &mut out); + let _ = self.inner.hkdf.expand(context, &mut out); } out.into() } @@ -80,3 +140,29 @@ impl ::core::fmt::Debug for HkdfSha3 { ::core::fmt::Debug::fmt("[hidden hkdf]", f) } } + +/// Select the best hkdf from our supported list +/// and the other endpoint supported list. +/// Give priority to our list +pub fn server_select_hkdf( + cfg: &Config, + client_supported: &Vec, +) -> Option { + cfg.hkdfs + .iter() + .find(|h| client_supported.contains(h)) + .copied() +} +/// Select the best hkdf from our supported list +/// and the other endpoint supported list. +/// Give priority to the server list +/// this is used only in the directory synchronized handshake +pub fn client_select_hkdf( + cfg: &Config, + server_supported: &Vec, +) -> Option { + server_supported + .iter() + .find(|h| cfg.hkdfs.contains(h)) + .copied() +} diff --git a/src/enc/mod.rs b/src/enc/mod.rs index eda3385..663c72d 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -4,5 +4,126 @@ pub mod asym; mod errors; pub mod hkdf; pub mod sym; +#[cfg(test)] +mod tests; pub use errors::Error; + +use ::ring::rand::SecureRandom; +use ::zeroize::Zeroize; + +/// wrapper where we implement whatever random traint stuff each library needs +pub struct Random { + /// actual source of randomness + rnd: ::ring::rand::SystemRandom, +} + +impl Random { + /// Build a nre Random source + pub fn new() -> Self { + Self { + rnd: ::ring::rand::SystemRandom::new(), + } + } + /// Fill a buffer with randomness + pub fn fill(&self, out: &mut [u8]) { + let _ = self.rnd.fill(out); + } + /// return the underlying ring SystemRandom + pub fn ring_rnd(&self) -> &::ring::rand::SystemRandom { + &self.rnd + } +} + +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for Random { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden randomness]", f) + } +} + +// ::rand_core::{RngCore, CryptoRng} needed for ::x25519::dalek +impl ::rand_core::RngCore for &Random { + fn next_u32(&mut self) -> u32 { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 4]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = self.rnd.fill(out.assume_init_mut()); + u32::from_le_bytes(out.assume_init()) + } + } + fn next_u64(&mut self) -> u64 { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 8]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = self.rnd.fill(out.assume_init_mut()); + u64::from_le_bytes(out.assume_init()) + } + } + fn fill_bytes(&mut self, dest: &mut [u8]) { + let _ = self.rnd.fill(dest); + } + fn try_fill_bytes( + &mut self, + dest: &mut [u8], + ) -> Result<(), ::rand_core::Error> { + match self.rnd.fill(dest) { + Ok(()) => Ok(()), + Err(e) => Err(::rand_core::Error::new(e)), + } + } +} +impl ::rand_core::CryptoRng for &Random {} + +/// Secret, used for keys. +/// Grants that on drop() we will zero out memory +#[derive(Zeroize, Clone, PartialEq)] +#[zeroize(drop)] +pub struct Secret([u8; 32]); +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for Secret { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden secret]", f) + } +} + +impl Secret { + /// return the length of the serialized secret + pub const fn len() -> usize { + 32 + } + /// New randomly generated secret + pub fn new_rand(rand: &Random) -> Self { + let mut ret = Self([0; 32]); + rand.fill(&mut ret.0); + ret + } + /// return a reference to the secret + pub fn as_ref(&self) -> &[u8; 32] { + &self.0 + } +} +impl From<[u8; 32]> for Secret { + fn from(shared_secret: [u8; 32]) -> Self { + Self(shared_secret) + } +} +impl From<&[u8; 32]> for Secret { + fn from(shared_secret: &[u8; 32]) -> Self { + Self(*shared_secret) + } +} + +impl From<::x25519_dalek::SharedSecret> for Secret { + fn from(shared_secret: ::x25519_dalek::SharedSecret) -> Self { + Self(shared_secret.to_bytes()) + } +} diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 366f665..d4204e0 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -1,67 +1,36 @@ //! Symmetric cypher stuff use super::Error; -use ::std::collections::VecDeque; -use ::zeroize::Zeroize; - -/// Secret, used for keys. -/// Grants that on drop() we will zero out memory -#[derive(Zeroize, Clone)] -#[zeroize(drop)] -pub struct Secret([u8; 32]); -// Fake debug implementation to avoid leaking secrets -impl ::core::fmt::Debug for Secret { - fn fmt( - &self, - f: &mut core::fmt::Formatter<'_>, - ) -> Result<(), ::std::fmt::Error> { - ::core::fmt::Debug::fmt("[hidden secret]", f) - } -} - -impl Secret { - /// New randomly generated secret - pub fn new_rand(rand: &::ring::rand::SystemRandom) -> Self { - use ::ring::rand::SecureRandom; - let mut ret = Self([0; 32]); - rand.fill(&mut ret.0); - ret - } - /// return a reference to the secret - pub fn as_ref(&self) -> &[u8; 32] { - &self.0 - } -} - -impl From<[u8; 32]> for Secret { - fn from(shared_secret: [u8; 32]) -> Self { - Self(shared_secret) - } -} - -impl From<::x25519_dalek::SharedSecret> for Secret { - fn from(shared_secret: ::x25519_dalek::SharedSecret) -> Self { - Self(shared_secret.to_bytes()) - } -} +use crate::{ + config::Config, + enc::{Random, Secret}, +}; /// List of possible Ciphers -#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] #[repr(u8)] pub enum CipherKind { /// XChaCha20_Poly1305 + #[strum(serialize = "xchacha20poly1305")] XChaCha20Poly1305 = 0, } impl CipherKind { /// length of the serialized id for the cipher kind field - pub fn len() -> usize { + pub const fn len() -> usize { 1 } /// required length of the nonce - pub fn nonce_len(&self) -> usize { - // TODO: how the hell do I take this from ::chacha20poly1305? - Nonce::len() + pub fn nonce_len(&self) -> HeadLen { + HeadLen(Nonce::len()) } /// required length of the key pub fn key_len(&self) -> usize { @@ -69,15 +38,15 @@ impl CipherKind { XChaCha20Poly1305::key_size() } /// Length of the authentication tag - pub fn tag_len(&self) -> usize { + pub fn tag_len(&self) -> TagLen { // TODO: how the hell do I take this from ::chacha20poly1305? - ::ring::aead::CHACHA20_POLY1305.tag_len() + TagLen(::ring::aead::CHACHA20_POLY1305.tag_len()) } } /// Additional Authenticated Data #[derive(Debug)] -pub struct AAD<'a>(pub &'a mut [u8]); +pub struct AAD<'a>(pub &'a [u8]); /// Cipher direction, to make sure we don't reuse the same cipher /// for both decrypting and encrypting @@ -90,6 +59,16 @@ pub enum CipherDirection { Send, } +/// strong typedef for header length +/// aka: nonce length in the encrypted data) +#[derive(Debug, Copy, Clone)] +pub struct HeadLen(pub usize); +/// strong typedef for the Tag length +/// aka: cryptographic authentication tag length at the end +/// of the encrypted data +#[derive(Debug, Copy, Clone)] +pub struct TagLen(pub usize); + /// actual ciphers enum Cipher { /// Cipher XChaha20_Poly1305 @@ -105,35 +84,42 @@ impl Cipher { } } } - fn nonce_len(&self) -> usize { + pub fn kind(&self) -> CipherKind { + match self { + Cipher::XChaCha20Poly1305(_) => CipherKind::XChaCha20Poly1305, + } + } + fn nonce_len(&self) -> HeadLen { + match self { + Cipher::XChaCha20Poly1305(_) => HeadLen(Nonce::len()), + } + } + fn tag_len(&self) -> TagLen { match self { Cipher::XChaCha20Poly1305(_) => { // TODO: how the hell do I take this from ::chacha20poly1305? - ::ring::aead::CHACHA20_POLY1305.nonce_len() + TagLen(::ring::aead::CHACHA20_POLY1305.tag_len()) } } } - fn tag_len(&self) -> usize { - match self { - Cipher::XChaCha20Poly1305(_) => { - // TODO: how the hell do I take this from ::chacha20poly1305? - ::ring::aead::CHACHA20_POLY1305.tag_len() - } - } - } - fn decrypt(&self, aad: AAD, data: &mut VecDeque) -> Result<(), Error> { + fn decrypt<'a>( + &self, + aad: AAD, + raw_data: &'a mut [u8], + ) -> Result<&'a [u8], Error> { match self { Cipher::XChaCha20Poly1305(cipher) => { use ::chacha20poly1305::{ aead::generic_array::GenericArray, AeadInPlace, }; - let final_len: usize; - { - let raw_data = data.as_mut_slices().0; - // FIXME: check min data length - let (nonce_bytes, data_and_tag) = raw_data.split_at_mut(13); + let final_len: usize = { + if raw_data.len() <= self.overhead() { + return Err(Error::NotEnoughData(raw_data.len())); + } + let (nonce_bytes, data_and_tag) = + raw_data.split_at_mut(Nonce::len()); let (data_notag, tag_bytes) = data_and_tag.split_at_mut( - data_and_tag.len() + 1 + data_and_tag.len() - ::ring::aead::CHACHA20_POLY1305.tag_len(), ); let nonce = GenericArray::from_slice(nonce_bytes); @@ -147,19 +133,19 @@ impl Cipher { if maybe.is_err() { return Err(Error::Decrypt); } - final_len = data_notag.len(); - } - data.drain(..Nonce::len()); - data.truncate(final_len); - Ok(()) + data_notag.len() + }; + //data.drain(..Nonce::len()); + //data.truncate(final_len); + Ok(&raw_data[Nonce::len()..Nonce::len() + final_len]) } } } fn overhead(&self) -> usize { match self { - Cipher::XChaCha20Poly1305(cipher) => { + Cipher::XChaCha20Poly1305(_) => { let cipher = CipherKind::XChaCha20Poly1305; - cipher.nonce_len() + cipher.tag_len() + cipher.nonce_len().0 + cipher.tag_len().0 } } } @@ -167,35 +153,31 @@ impl Cipher { &self, nonce: &Nonce, aad: AAD, - data: &mut Data, + data: &mut [u8], ) -> Result<(), Error> { - // No need to check for minimum buffer size since `Data` assures we - // already went through that + // FIXME: check minimum buffer size match self { Cipher::XChaCha20Poly1305(cipher) => { - use ::chacha20poly1305::{ - aead::generic_array::GenericArray, AeadInPlace, - }; + use ::chacha20poly1305::AeadInPlace; + let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len(); + let data_len_notag = data.len() - tag_len; // write nonce - data.get_slice_full()[..Nonce::len()] - .copy_from_slice(nonce.as_bytes()); + data[..Nonce::len()].copy_from_slice(nonce.as_bytes()); // encrypt data match cipher.cipher.encrypt_in_place_detached( nonce.as_bytes().into(), aad.0, - data.get_slice(), + &mut data[Nonce::len()..data_len_notag], ) { Ok(tag) => { - // add tag - data.get_tag_slice().copy_from_slice(tag.as_slice()); + data[data_len_notag..].copy_from_slice(tag.as_slice()); Ok(()) } Err(_) => Err(Error::Encrypt), - }; + } } } - todo!() } } @@ -216,46 +198,25 @@ impl CipherRecv { Self(Cipher::new(kind, secret)) } /// Get the length of the nonce for this cipher - pub fn nonce_len(&self) -> usize { + pub fn nonce_len(&self) -> HeadLen { self.0.nonce_len() } + /// Get the length of the nonce for this cipher + pub fn tag_len(&self) -> TagLen { + self.0.tag_len() + } /// Decrypt a paket. Nonce and Tag are taken from the packet, /// while you need to provide AAD (Additional Authenticated Data) pub fn decrypt<'a>( &self, aad: AAD, - data: &mut VecDeque, - ) -> Result<(), Error> { + data: &'a mut [u8], + ) -> Result<&'a [u8], Error> { self.0.decrypt(aad, data) } -} - -/// Allocate some data, with additional indexes to track -/// where nonce and tags are -#[derive(Debug, Clone)] -pub struct Data { - data: Vec, - skip_start: usize, - skip_end: usize, -} - -impl Data { - /// Get the slice where you will write the actual data - /// this will skip the actual nonce and AEAD tag and give you - /// only the space for the data - pub fn get_slice(&mut self) -> &mut [u8] { - &mut self.data[self.skip_start..self.skip_end] - } - fn get_tag_slice(&mut self) -> &mut [u8] { - let start = self.data.len() - self.skip_end; - &mut self.data[start..] - } - fn get_slice_full(&mut self) -> &mut [u8] { - &mut self.data - } - /// Consume the data and return the whole raw vector - pub fn get_raw(self) -> Vec { - self.data + /// return the underlying cipher id + pub fn kind(&self) -> CipherKind { + self.0.kind() } } @@ -275,30 +236,22 @@ impl ::core::fmt::Debug for CipherSend { impl CipherSend { /// Build a new Cipher - pub fn new( - kind: CipherKind, - secret: Secret, - rand: &::ring::rand::SystemRandom, - ) -> Self { + pub fn new(kind: CipherKind, secret: Secret, rand: &Random) -> Self { Self { nonce: NonceSync::new(rand), cipher: Cipher::new(kind, secret), } } - /// Allocate the memory for the data that will be encrypted - pub fn make_data(&self, length: usize) -> Data { - Data { - data: Vec::with_capacity(length + self.cipher.overhead()), - skip_start: self.cipher.nonce_len(), - skip_end: self.cipher.tag_len(), - } - } /// Encrypt the given data - pub fn encrypt(&self, aad: AAD, data: &mut Data) -> Result<(), Error> { + pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> { let old_nonce = self.nonce.advance(); self.cipher.encrypt(&old_nonce, aad, data)?; Ok(()) } + /// return the underlying cipher id + pub fn kind(&self) -> CipherKind { + self.cipher.kind() + } } /// XChaCha20Poly1305 cipher @@ -317,7 +270,7 @@ impl XChaCha20Poly1305 { } // -// TODO: For efficiency "Nonce" should become a reference. +// TODO: Merge crate::{enc::sym::Nonce, connection::handshake::dirsync::Nonce} // #[derive(Debug, Copy, Clone)] @@ -332,7 +285,7 @@ struct NonceNum { #[repr(C)] pub union Nonce { num: NonceNum, - raw: [u8; 12], + raw: [u8; Self::len()], } impl ::core::fmt::Debug for Nonce { @@ -349,18 +302,18 @@ impl ::core::fmt::Debug for Nonce { impl Nonce { /// Generate a new random Nonce - pub fn new(rand: &::ring::rand::SystemRandom) -> Self { - use ring::rand::SecureRandom; - let mut raw = [0; 12]; + pub fn new(rand: &Random) -> Self { + let mut raw = [0; Self::len()]; rand.fill(&mut raw); - #[allow(unsafe_code)] - unsafe { - Self { raw } - } + Self { raw } } /// Length of this nonce in bytes pub const fn len() -> usize { - return 12; + // FIXME: was:12. xchacha20poly1305 requires 24. + // but we should change keys much earlier than that, and our + // nonces are not random, but sequential. + // we should change keys every 2^30 bytes to be sure (stream max window) + return 24; } /// Get reference to the nonce bytes pub fn as_bytes(&self) -> &[u8] { @@ -370,11 +323,8 @@ impl Nonce { } } /// Create Nonce from array - pub fn from_slice(raw: [u8; 12]) -> Self { - #[allow(unsafe_code)] - unsafe { - Self { raw } - } + pub fn from_slice(raw: [u8; Self::len()]) -> Self { + Self { raw } } /// Go to the next nonce pub fn advance(&mut self) { @@ -390,13 +340,14 @@ impl Nonce { } /// Synchronize the mutex acess with a nonce for multithread safety +// TODO: remove mutex, not needed anymore #[derive(Debug)] pub struct NonceSync { nonce: ::std::sync::Mutex, } impl NonceSync { /// Create a new thread safe nonce - pub fn new(rand: &::ring::rand::SystemRandom) -> Self { + pub fn new(rand: &Random) -> Self { Self { nonce: ::std::sync::Mutex::new(Nonce::new(rand)), } @@ -412,3 +363,28 @@ impl NonceSync { old_nonce } } +/// Select the best cipher from our supported list +/// and the other endpoint supported list. +/// Give priority to our list +pub fn server_select_cipher( + cfg: &Config, + client_supported: &Vec, +) -> Option { + cfg.ciphers + .iter() + .find(|c| client_supported.contains(c)) + .copied() +} +/// Select the best cipher from our supported list +/// and the other endpoint supported list. +/// Give priority to the server list +/// This is used only in the Directory synchronized handshake +pub fn client_select_cipher( + cfg: &Config, + server_supported: &Vec, +) -> Option { + server_supported + .iter() + .find(|c| cfg.ciphers.contains(c)) + .copied() +} diff --git a/src/enc/tests.rs b/src/enc/tests.rs new file mode 100644 index 0000000..ead07ee --- /dev/null +++ b/src/enc/tests.rs @@ -0,0 +1,135 @@ +use crate::{ + auth, + connection::{handshake::*, ID}, + enc::{self, asym::KeyID}, +}; + +#[test] +fn test_simple_encrypt_decrypt() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + let secret = enc::Secret::new_rand(&rand); + let secret2 = secret.clone(); + + let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand); + let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2); + + let mut data = Vec::new(); + let tot_len = cipher_recv.nonce_len().0 + 1234 + cipher_recv.tag_len().0; + data.resize(tot_len, 0); + rand.fill(&mut data); + data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]); + let last = data.len() - cipher_recv.tag_len().0; + data[last..].copy_from_slice(&[0; 16]); + let orig = data.clone(); + let raw_aad: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; + let aad = enc::sym::AAD(&raw_aad[..]); + let aad2 = enc::sym::AAD(&raw_aad[..]); + if cipher_send.encrypt(aad, &mut data).is_err() { + assert!(false, "Encrypt failed"); + } + if cipher_recv.decrypt(aad2, &mut data).is_err() { + assert!(false, "Decrypt failed"); + } + data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]); + let last = data.len() - cipher_recv.tag_len().0; + data[last..].copy_from_slice(&[0; 16]); + assert!(orig == data, "DIFFERENT!\n{:?}\n{:?}\n", orig, data); +} + +#[test] +fn test_encrypt_decrypt() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + let secret = enc::Secret::new_rand(&rand); + let secret2 = secret.clone(); + + let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand); + let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2); + let nonce_len = cipher_recv.nonce_len(); + let tag_len = cipher_recv.tag_len(); + + let service_key = enc::Secret::new_rand(&rand); + + let data = dirsync::RespInner::ClearText(dirsync::RespData { + client_nonce: dirsync::Nonce::new(&rand), + id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()), + service_connection_id: ID::ID( + ::core::num::NonZeroU64::new(434343).unwrap(), + ), + service_key, + }); + + let resp = dirsync::Resp { + client_key_id: KeyID(4444), + data, + }; + let encrypt_from = resp.encrypted_offset(); + let encrypt_to = encrypt_from + resp.encrypted_length(nonce_len, tag_len); + + let h_resp = + Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Resp(resp))); + + let mut bytes = Vec::::with_capacity( + h_resp.len(cipher.nonce_len(), cipher.tag_len()), + ); + bytes.resize(h_resp.len(cipher.nonce_len(), cipher.tag_len()), 0); + h_resp.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); + + let raw_aad: [u8; 7] = [0, 1, 2, 3, 4, 5, 6]; + let aad = enc::sym::AAD(&raw_aad[..]); + let aad2 = enc::sym::AAD(&raw_aad[..]); + + let pre_encrypt = bytes.clone(); + // encrypt + if cipher_send + .encrypt(aad, &mut bytes[encrypt_from..encrypt_to]) + .is_err() + { + assert!(false, "Encrypt failed"); + } + if cipher_recv + .decrypt(aad2, &mut bytes[encrypt_from..encrypt_to]) + .is_err() + { + assert!(false, "Decrypt failed"); + } + // make sure Nonce and Tag are 0 + bytes[encrypt_from..(encrypt_from + nonce_len.0)].copy_from_slice(&[0; 24]); + let tag_from = encrypt_to - tag_len.0; + bytes[tag_from..(tag_from + tag_len.0)].copy_from_slice(&[0; 16]); + assert!( + pre_encrypt == bytes, + "{}|{}=\n{:?}\n{:?}", + encrypt_from, + encrypt_to, + pre_encrypt, + bytes + ); + + // decrypt + + let mut deserialized = match Handshake::deserialize(&bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + // reparse + if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) = + &mut deserialized.data + { + let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; + if let Err(e) = r_a.data.deserialize_as_cleartext( + &bytes[enc_start..(bytes.len() - cipher.tag_len().0)], + ) { + assert!(false, "DirSync Resp Inner serialize: {}", e.to_string()); + } + }; + + assert!( + deserialized == h_resp, + "DirSync Resp (de)serialization not working", + ); +} diff --git a/src/inner/mod.rs b/src/inner/mod.rs new file mode 100644 index 0000000..001ca16 --- /dev/null +++ b/src/inner/mod.rs @@ -0,0 +1,14 @@ +//! Inner Fenrir tracking +//! This is meant to be **async-free** so that others might use it +//! without the tokio runtime + +pub(crate) mod worker; + +/// Track the total number of threads and our index +/// 65K cpus should be enough for anybody +#[derive(Debug, Clone, Copy)] +pub(crate) struct ThreadTracker { + pub total: u16, + /// Note: starts from 1 + pub id: u16, +} diff --git a/src/inner/worker.rs b/src/inner/worker.rs new file mode 100644 index 0000000..083a9ac --- /dev/null +++ b/src/inner/worker.rs @@ -0,0 +1,655 @@ +//! Worker thread implementation +use crate::{ + auth::{self, Domain, ServiceID, Token, TokenChecker, UserID}, + config::Config, + connection::{ + self, + handshake::{ + self, + dirsync::{self, DirSync}, + tracker::{HandshakeAction, HandshakeTracker}, + Handshake, HandshakeData, + }, + socket::{UdpClient, UdpServer}, + ConnList, Connection, IDSend, Packet, + }, + dnssec, + enc::{ + asym::{self, KeyID, PrivKey, PubKey}, + hkdf::{self, Hkdf, HkdfKind}, + sym, Random, Secret, + }, + inner::ThreadTracker, +}; +use ::std::{sync::Arc, vec::Vec}; +/// This worker must be cpu-pinned +use ::tokio::{ + net::UdpSocket, + sync::{mpsc, oneshot, Mutex}, +}; + +/// Track a raw Udp packet +pub(crate) struct RawUdp { + pub src: UdpClient, + pub dst: UdpServer, + pub data: Vec, + pub packet: Packet, +} + +pub(crate) struct ConnectInfo { + pub answer: oneshot::Sender, + pub resolved: dnssec::Record, + pub service_id: ServiceID, + pub domain: Domain, + // TODO: UserID, Token information +} + +pub(crate) enum Work { + /// ask the thread to report to the main thread the total number of + /// connections present + CountConnections(oneshot::Sender), + Connect(ConnectInfo), + DropHandshake(KeyID), + Recv(RawUdp), +} + +/// Actual worker implementation. +#[allow(missing_debug_implementations)] +pub struct Worker { + cfg: Config, + thread_id: ThreadTracker, + // PERF: rand uses syscalls. how to do that async? + rand: Random, + stop_working: crate::StopWorkingRecvCh, + token_check: Option>>, + sockets: Vec>, + queue: ::async_channel::Receiver, + queue_timeouts_recv: mpsc::UnboundedReceiver, + queue_timeouts_send: mpsc::UnboundedSender, + thread_channels: Vec<::async_channel::Sender>, + connections: ConnList, + handshakes: HandshakeTracker, +} + +#[allow(unsafe_code)] +unsafe impl Send for Worker {} + +impl Worker { + pub(crate) async fn new( + mut cfg: Config, + thread_id: ThreadTracker, + stop_working: crate::StopWorkingRecvCh, + token_check: Option>>, + sockets: Vec>, + queue: ::async_channel::Receiver, + ) -> ::std::io::Result { + let (queue_timeouts_send, queue_timeouts_recv) = + mpsc::unbounded_channel(); + let mut handshakes = HandshakeTracker::new( + thread_id, + cfg.ciphers.clone(), + cfg.key_exchanges.clone(), + ); + let mut server_keys = Vec::new(); + // make sure the keys are no longer in the config + ::core::mem::swap(&mut server_keys, &mut cfg.server_keys); + for k in server_keys.into_iter() { + if handshakes.add_server_key(k.id, k.priv_key).is_err() { + return Err(::std::io::Error::new( + ::std::io::ErrorKind::AlreadyExists, + "You can't use the same KeyID for multiple keys", + )); + } + } + for srv in cfg.servers.iter() { + if handshakes.add_server_domain(&srv.fqdn, &srv.keys).is_err() { + return Err(::std::io::Error::new( + ::std::io::ErrorKind::NotFound, + "Specified a KeyID that we don't have", + )); + } + } + + Ok(Self { + cfg, + thread_id, + rand: Random::new(), + stop_working, + token_check, + sockets, + queue, + queue_timeouts_recv, + queue_timeouts_send, + thread_channels: Vec::new(), + connections: ConnList::new(thread_id), + handshakes, + }) + } + /// Continuously loop and process work as needed + pub async fn work_loop(&mut self) { + 'mainloop: loop { + let work = ::tokio::select! { + tell_stopped = self.stop_working.recv() => { + if let Ok(stop_ch) = tell_stopped { + let _ = stop_ch.send( + crate::StopWorking::WorkerStopped).await; + } + break; + } + maybe_timeout = self.queue.recv() => { + match maybe_timeout { + Ok(work) => work, + Err(_) => break, + } + } + maybe_work = self.queue.recv() => { + match maybe_work { + Ok(work) => work, + Err(_) => break, + } + } + }; + match work { + Work::CountConnections(sender) => { + let conn_num = self.connections.len(); + let _ = sender.send(conn_num); + } + Work::Connect(conn_info) => { + // PERF: geolocation + + // Find the first destination with: + // * UDP port + // * a coherent pubkey/key exchange. + let destination = + conn_info.resolved.addresses.iter().find_map(|addr| { + if addr.port.is_none() + || addr + .handshake_ids + .iter() + .find(|h_srv| { + self.cfg.handshakes.contains(h_srv) + }) + .is_none() + { + // skip servers with no corresponding + // handshake types or no udp port + return None; + } + + // make sure this server has a public key + // that supports one of the key exchanges that + // *we* support + for idx in addr.public_key_idx.iter() { + // for each key, + // get all the possible key exchanges + let key_supported_k_x = + conn_info.resolved.public_keys + [idx.0 as usize] + .1 + .kind() + .key_exchanges(); + // consider only the key exchanges allowed + // in the dnssec record + let filtered_key_exchanges: Vec< + asym::KeyExchangeKind, + > = key_supported_k_x + .into_iter() + .filter(|k_x| { + conn_info + .resolved + .key_exchanges + .contains(k_x) + }) + .copied() + .collect(); + + // finally make sure that we support those + match self.cfg.key_exchanges.iter().find(|x| { + filtered_key_exchanges.contains(x) + }) { + Some(exchange) => { + return Some(( + addr, + conn_info.resolved.public_keys + [idx.0 as usize], + exchange.clone(), + )) + } + None => return None, + } + } + return None; + }); + let (addr, key, exchange) = match destination { + Some((addr, key, exchange)) => (addr, key, exchange), + None => { + let _ = + conn_info + .answer + .send(Err(crate::Error::Resolution( + "No selectable address and key combination" + .to_owned(), + ))); + continue 'mainloop; + } + }; + let hkdf_selected = match hkdf::client_select_hkdf( + &self.cfg, + &conn_info.resolved.hkdfs, + ) { + Some(hkdf_selected) => hkdf_selected, + None => { + let _ = conn_info.answer.send(Err( + handshake::Error::Negotiation.into(), + )); + continue 'mainloop; + } + }; + let cipher_selected = match sym::client_select_cipher( + &self.cfg, + &conn_info.resolved.ciphers, + ) { + Some(cipher_selected) => cipher_selected, + None => { + let _ = conn_info.answer.send(Err( + handshake::Error::Negotiation.into(), + )); + continue 'mainloop; + } + }; + + let (priv_key, pub_key) = + match exchange.new_keypair(&self.rand) { + Ok(pair) => pair, + Err(_) => { + ::tracing::error!("Failed to generate keys"); + let _ = conn_info.answer.send(Err( + handshake::Error::KeyGeneration.into(), + )); + continue 'mainloop; + } + }; + let hkdf; + + if let PubKey::Exchange(srv_pub) = key.1 { + let secret = + match priv_key.key_exchange(exchange, srv_pub) { + Ok(secret) => secret, + Err(_) => { + ::tracing::warn!( + "Could not run the key exchange" + ); + let _ = conn_info.answer.send(Err( + handshake::Error::Negotiation.into(), + )); + continue 'mainloop; + } + }; + hkdf = Hkdf::new(hkdf_selected, b"fenrir", secret); + } else { + // crate::dnssec already verifies that the keys + // listed in dnssec::Record.addresses.public_key_idx + // are PubKey::Exchange + unreachable!() + } + let mut conn = Connection::new( + hkdf, + cipher_selected, + connection::Role::Client, + &self.rand, + ); + + let auth_recv_id = self.connections.reserve_first(); + let service_conn_id = self.connections.reserve_first(); + conn.id_recv = auth_recv_id; + let (client_key_id, hshake) = match self + .handshakes + .add_client( + PrivKey::Exchange(priv_key), + PubKey::Exchange(pub_key), + conn_info.service_id, + service_conn_id, + conn, + conn_info.answer, + key.0, + ) { + Ok((client_key_id, hshake)) => (client_key_id, hshake), + Err(answer) => { + ::tracing::warn!("Too many client handshakes"); + let _ = answer.send(Err( + handshake::Error::TooManyClientHandshakes + .into(), + )); + continue 'mainloop; + } + }; + + // build request + let auth_info = dirsync::AuthInfo { + user: UserID::new_anonymous(), + token: Token::new_anonymous(&self.rand), + service_id: conn_info.service_id, + domain: conn_info.domain, + }; + let req_data = dirsync::ReqData { + nonce: dirsync::Nonce::new(&self.rand), + client_key_id, + id: auth_recv_id.0, //FIXME: is zero + auth: auth_info, + }; + let req = dirsync::Req { + key_id: key.0, + exchange, + hkdf: hkdf_selected, + cipher: cipher_selected, + exchange_key: pub_key, + data: dirsync::ReqInner::ClearText(req_data), + }; + let encrypt_start = ID::len() + req.encrypted_offset(); + let encrypt_end = encrypt_start + + req.encrypted_length( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + ); + let h_req = Handshake::new(HandshakeData::DirSync( + DirSync::Req(req), + )); + use connection::{PacketData, ID}; + let packet = Packet { + id: ID::Handshake, + data: PacketData::Handshake(h_req), + }; + + let tot_len = packet.len( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + ); + let mut raw = Vec::::with_capacity(tot_len); + raw.resize(tot_len, 0); + packet.serialize( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + &mut raw[..], + ); + // encrypt + if let Err(e) = hshake.connection.cipher_send.encrypt( + sym::AAD(&[]), + &mut raw[encrypt_start..encrypt_end], + ) { + ::tracing::error!("Can't encrypt DirSync Request"); + if let Some(client) = + self.handshakes.remove_client(client_key_id) + { + let _ = client.answer.send(Err(e.into())); + }; + continue 'mainloop; + } + + // send always from the first socket + // FIXME: select based on routing table + let sender = self.sockets[0].local_addr().unwrap(); + let dest = UdpClient(addr.as_sockaddr().unwrap()); + + // start the timeout right before sending the packet + hshake.timeout = Some(::tokio::task::spawn_local( + Self::handshake_timeout( + self.queue_timeouts_send.clone(), + client_key_id, + ), + )); + + // send packet + self.send_packet(raw, dest, UdpServer(sender)).await; + + continue 'mainloop; + } + Work::DropHandshake(key_id) => { + if let Some(connections) = + self.handshakes.timeout_client(key_id) + { + for conn_id in connections.into_iter() { + self.connections.remove(conn_id); + } + }; + } + Work::Recv(pkt) => { + self.recv(pkt).await; + } + } + } + } + async fn handshake_timeout( + timeout_queue: mpsc::UnboundedSender, + key_id: KeyID, + ) { + ::tokio::time::sleep(::std::time::Duration::from_secs(10)).await; + let _ = timeout_queue.send(Work::DropHandshake(key_id)); + } + /// Read and do stuff with the raw udp packet + async fn recv(&mut self, mut udp: RawUdp) { + if udp.packet.id.is_handshake() { + let handshake = match Handshake::deserialize( + &udp.data[connection::ID::len()..], + ) { + Ok(handshake) => handshake, + Err(e) => { + ::tracing::debug!("Handshake parsing: {}", e); + return; + } + }; + let action = match self.handshakes.recv_handshake( + handshake, + &mut udp.data[connection::ID::len()..], + ) { + Ok(action) => action, + Err(err) => { + ::tracing::debug!("Handshake recv error {}", err); + return; + } + }; + match action { + HandshakeAction::AuthNeeded(authinfo) => { + let req; + if let HandshakeData::DirSync(DirSync::Req(r)) = + authinfo.handshake.data + { + req = r; + } else { + ::tracing::error!("AuthInfo on non DS::Req"); + return; + } + use dirsync::ReqInner; + let req_data = match req.data { + ReqInner::ClearText(req_data) => req_data, + _ => { + ::tracing::error!("AuthNeeded: expected ClearText"); + assert!(false, "AuthNeeded: unreachable"); + return; + } + }; + // FIXME: This part can take a while, + // we should just spawn it probably + let maybe_auth_check = { + match &self.token_check { + None => { + if req_data.auth.user == auth::USERID_ANONYMOUS + { + Ok(true) + } else { + Ok(false) + } + } + Some(token_check) => { + let tk_check = token_check.lock().await; + tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await + } + } + }; + let is_authenticated = match maybe_auth_check { + Ok(is_authenticated) => is_authenticated, + Err(_) => { + ::tracing::error!("error in token auth"); + // TODO: retry? + return; + } + }; + if !is_authenticated { + ::tracing::warn!( + "Wrong authentication for user {:?}", + req_data.auth.user + ); + // TODO: error response + return; + } + // Client has correctly authenticated + // TODO: contact the service, get the key and + // connection ID + let srv_conn_id = ID::new_rand(&self.rand); + let srv_secret = Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); + + let mut auth_conn = Connection::new( + authinfo.hkdf, + req.cipher, + connection::Role::Server, + &self.rand, + ); + auth_conn.id_send = IDSend(req_data.id); + // track connection + let auth_id_recv = self.connections.reserve_first(); + auth_conn.id_recv = auth_id_recv; + + let resp_data = dirsync::RespData { + client_nonce: req_data.nonce, + id: auth_conn.id_recv.0, + service_connection_id: srv_conn_id, + service_key: srv_secret, + }; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + + use dirsync::RespInner; + let resp = dirsync::Resp { + client_key_id: req_data.client_key_id, + data: RespInner::ClearText(resp_data), + }; + let encrypt_from = ID::len() + resp.encrypted_offset(); + let encrypt_until = + encrypt_from + resp.encrypted_length(head_len, tag_len); + let resp_handshake = Handshake::new( + HandshakeData::DirSync(DirSync::Resp(resp)), + ); + use connection::{PacketData, ID}; + let packet = Packet { + id: ID::new_handshake(), + data: PacketData::Handshake(resp_handshake), + }; + let tot_len = packet.len(head_len, tag_len); + let mut raw_out = Vec::::with_capacity(tot_len); + raw_out.resize(tot_len, 0); + packet.serialize(head_len, tag_len, &mut raw_out); + + if let Err(e) = auth_conn + .cipher_send + .encrypt(aad, &mut raw_out[encrypt_from..encrypt_until]) + { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + self.send_packet(raw_out, udp.src, udp.dst).await; + } + HandshakeAction::ClientConnect(cci) => { + let ds_resp; + if let HandshakeData::DirSync(DirSync::Resp(resp)) = + cci.handshake.data + { + ds_resp = resp; + } else { + ::tracing::error!("ClientConnect on non DS::Resp"); + return; + } + // track connection + let resp_data; + if let dirsync::RespInner::ClearText(r_data) = ds_resp.data + { + resp_data = r_data; + } else { + ::tracing::error!( + "ClientConnect on non DS::Resp::ClearText" + ); + unreachable!(); + } + let auth_srv_conn = IDSend(resp_data.id); + let mut conn = cci.connection; + conn.id_send = auth_srv_conn; + let id_recv = conn.id_recv; + let cipher = conn.cipher_recv.kind(); + // track the connection to the authentication server + if self.connections.track(conn.into()).is_err() { + ::tracing::error!("Could not track new connection"); + self.connections.remove(id_recv); + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + if cci.service_id != auth::SERVICEID_AUTH { + // create and track the connection to the service + // SECURITY: xor with secrets + //FIXME: the Secret should be XORed with the client + // stored secret (if any) + let hkdf = Hkdf::new( + HkdfKind::Sha3, + cci.service_id.as_bytes(), + resp_data.service_key, + ); + let mut service_connection = Connection::new( + hkdf, + cipher, + connection::Role::Client, + &self.rand, + ); + service_connection.id_recv = cci.service_connection_id; + service_connection.id_send = + IDSend(resp_data.service_connection_id); + let _ = + self.connections.track(service_connection.into()); + } + let _ = + cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn))); + } + HandshakeAction::Nothing => {} + }; + } + } + async fn send_packet( + &self, + data: Vec, + client: UdpClient, + server: UdpServer, + ) { + let src_sock = match self + .sockets + .iter() + .find(|&s| s.local_addr().unwrap() == server.0) + { + Some(src_sock) => src_sock, + None => { + ::tracing::error!( + "Can't send packet: Server changed listening ip{}!", + server.0 + ); + return; + } + }; + let _res = src_sock.send_to(&data, client.0).await; + } +} diff --git a/src/lib.rs b/src/lib.rs index c9543fb..fbca09a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,23 +18,27 @@ mod config; pub mod connection; pub mod dnssec; pub mod enc; +mod inner; -use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{net::SocketAddr, pin::Pin, sync::Arc, vec, vec::Vec}; -use ::tokio::{ - macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle, -}; +#[cfg(test)] +mod tests; -use crate::enc::{ - asym, - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend}, +use ::std::{sync::Arc, vec::Vec}; +use ::tokio::{net::UdpSocket, sync::Mutex}; + +use crate::{ + auth::{Domain, ServiceID, TokenChecker}, + connection::{ + handshake, + socket::{SocketTracker, UdpClient, UdpServer}, + AuthServerConnections, Packet, + }, + inner::{ + worker::{ConnectInfo, RawUdp, Work, Worker}, + ThreadTracker, + }, }; pub use config::Config; -use connection::{ - handshake::{self, Handshake, HandshakeKey}, - Connection, -}; /// Main fenrir library errors #[derive(::thiserror::Error, Debug)] @@ -42,6 +46,9 @@ pub enum Error { /// The library was not initialized (run .start()) #[error("not initialized")] NotInitialized, + /// Error in setting up worker threads + #[error("Setup err: {0}")] + Setup(String), /// General I/O error #[error("IO: {0:?}")] IO(#[from] ::std::io::Error), @@ -54,295 +61,24 @@ pub enum Error { /// Key error #[error("key: {0:?}")] Key(#[from] crate::enc::Error), + /// Resolution problems. wrong or incomplete DNSSEC data + #[error("DNSSEC resolution: {0}")] + Resolution(String), + /// Wrapper on encryption errors + #[error("Encrypt: {0}")] + Encrypt(enc::Error), } -/// Information needed to reply after the key exchange -#[derive(Debug, Clone)] -pub struct AuthNeededInfo { - /// Parsed handshake - pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: HkdfSha3, - /// cipher to be used in both directions - pub cipher: CipherKind, +pub(crate) enum StopWorking { + WorkerStopped, + ListenerStopped, } -/// Intermediate actions to be taken while parsing the handshake -#[derive(Debug, Clone)] -pub enum HandshakeAction { - /// Parsing finished, all ok, nothing to do - None, - /// Packet parsed, now go perform authentication - AuthNeeded(AuthNeededInfo), -} -// No async here -struct FenrirInner { - key_exchanges: ArcSwapAny>>, - ciphers: ArcSwapAny>>, - keys: ArcSwapAny>>, -} - -#[allow(unsafe_code)] -unsafe impl Send for FenrirInner {} -#[allow(unsafe_code)] -unsafe impl Sync for FenrirInner {} - -// No async here -impl FenrirInner { - fn recv_handshake( - &self, - mut handshake: Handshake, - ) -> Result { - use connection::handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; - match handshake.data { - HandshakeData::DirSync(ref mut ds) => match ds { - DirSync::Req(ref mut req) => { - let ephemeral_key = { - // Keep this block short to avoid contention - // on self.keys - let keys = self.keys.load(); - if let Some(h_k) = - keys.iter().find(|k| k.id == req.key_id) - { - use enc::asym::PrivKey; - // Directory synchronized can only use keys - // for key exchange, not signing keys - if let PrivKey::Exchange(k) = &h_k.key { - Some(k.clone()) - } else { - None - } - } else { - None - } - }; - if ephemeral_key.is_none() { - ::tracing::debug!("No such key id: {:?}", req.key_id); - return Err(handshake::Error::UnknownKeyID.into()); - } - let ephemeral_key = ephemeral_key.unwrap(); - { - let exchanges = self.key_exchanges.load(); - if None - == exchanges.iter().find(|&x| { - *x == (ephemeral_key.kind(), req.exchange) - }) - { - return Err( - enc::Error::UnsupportedKeyExchange.into() - ); - } - } - { - let ciphers = self.ciphers.load(); - if None == ciphers.iter().find(|&x| *x == req.cipher) { - return Err(enc::Error::UnsupportedCipher.into()); - } - } - let shared_key = match ephemeral_key - .key_exchange(req.exchange, req.exchange_key) - { - Ok(shared_key) => shared_key, - Err(e) => return Err(handshake::Error::Key(e).into()), - }; - let hkdf = HkdfSha3::new(b"fenrir", shared_key); - let secret_recv = hkdf.get_secret(b"to_server"); - let cipher_recv = CipherRecv::new(req.cipher, secret_recv); - use crate::enc::sym::AAD; - let aad = AAD(&mut []); // no aad for now - match cipher_recv.decrypt(aad, &mut req.data.ciphertext()) { - Ok(()) => req.data.mark_as_cleartext(), - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - req.set_data(dirsync::ReqData::deserialize(&req.data)?); - - let cipher = req.cipher; - - return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { - handshake, - hkdf, - cipher, - })); - } - DirSync::Resp(resp) => { - todo!(); - } - }, - } - } -} - -type TokenChecker = - fn( - user: auth::UserID, - token: auth::Token, - service_id: auth::ServiceID, - domain: auth::Domain, - ) -> ::futures::future::BoxFuture<'static, Result>; - -// PERF: move to arcswap: -// We use multiple UDP sockets. -// We need a list of them all and to scan this list. -// But we want to be able to update this list without locking everyone -// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and -// `from_raw_fd` to update the list. -// This means that we will have multiple `UdpSocket` per actual socket -// so we have to handle `drop()` manually, and garbage-collect the ones we -// are no longer using in the background. sigh. -// Just go with a ArcSwapAny, Arc>>); -struct SocketList { - list: ArcSwap>, -} -impl SocketList { - fn new() -> Self { - Self { - list: ArcSwap::new(Arc::new(Vec::new())), - } - } - // TODO: fn rm_socket() - fn rm_all(&self) -> Self { - let new_list = Arc::new(Vec::new()); - let old_list = self.list.swap(new_list); - Self { - list: old_list.into(), - } - } - async fn add_socket( - &self, - socket: Arc, - handle: JoinHandle<::std::io::Result<()>>, - ) { - // we could simplify this into just a `.swap` instead of `.rcu` but - // it is not yet guaranteed that only one thread will call this fn - // ...we don't need performance here anyway - let arc_handle = Arc::new(handle); - self.list.rcu(|old_list| { - let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); - new_list = old_list.to_vec().into(); - Arc::get_mut(&mut new_list) - .unwrap() - .push((socket.clone(), arc_handle.clone())); - new_list - }); - } - /// This method assumes no other `add_sockets` are being run - async fn stop_all(mut self) { - let mut arc_list = self.list.into_inner(); - let list = loop { - match Arc::try_unwrap(arc_list) { - Ok(list) => break list, - Err(arc_retry) => { - arc_list = arc_retry; - ::tokio::time::sleep(::core::time::Duration::from_millis( - 50, - )) - .await; - } - } - }; - for (_socket, mut handle) in list.into_iter() { - Arc::get_mut(&mut handle).unwrap().await; - } - } - fn lock(&self) -> SocketListRef { - SocketListRef { - list: self.list.load_full(), - } - } -} -// TODO: impl Drop for SocketList -struct SocketListRef { - list: Arc>, -} -impl SocketListRef { - fn find(&self, sock: UdpServer) -> Option> { - match self - .list - .iter() - .find(|&(s, _)| s.local_addr().unwrap() == sock.0) - { - Some((sock_srv, _)) => Some(sock_srv.clone()), - None => None, - } - } -} - -#[derive(Debug, Copy, Clone)] -struct UdpClient(SocketAddr); -#[derive(Debug, Copy, Clone)] -struct UdpServer(SocketAddr); - -struct RawUdp { - data: Vec, - src: UdpClient, - dst: UdpServer, -} - -enum Work { - Recv(RawUdp), -} - -// PERF: Arc> loks a bit too much, need to find -// faster ways to do this -struct ConnList { - connections: Vec>>, - /// Bitmap to track which connection ids are used or free - ids_used: Vec<::bitmaps::Bitmap<1024>>, -} - -impl ConnList { - fn new() -> Self { - let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); - bitmap_id.set(0, true); // ID(0) == handshake - Self { - connections: Vec::with_capacity(128), - ids_used: vec![bitmap_id], - } - } - fn reserve_first(&mut self, mut conn: Connection) -> Arc { - // uhm... bad things are going on here: - // * id must be initialized, but only because: - // * rust does not understand that after the `!found` id is always - // initialized - // * `ID::new_u64` is really safe only with >0, but here it always is - // ...we should probably rewrite it in better, safer rust - let mut id: u64 = 0; - let mut found = false; - for (i, b) in self.ids_used.iter_mut().enumerate() { - match b.first_false_index() { - Some(idx) => { - b.set(idx, true); - id = ((i as u64) * 1024) + (idx as u64); - found = true; - break; - } - None => {} - } - } - if !found { - let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); - new_bitmap.set(0, true); - id = (self.ids_used.len() as u64) * 1024; - self.ids_used.push(new_bitmap); - } - let new_id = connection::ID::new_u64(id); - conn.id = new_id; - let conn = Arc::new(conn); - if (self.connections.len() as u64) < id { - self.connections.push(Some(conn.clone())); - } else { - // very probably redundant - self.connections[id as usize] = Some(conn.clone()); - } - conn - } -} +pub(crate) type StopWorkingSendCh = + ::tokio::sync::broadcast::Sender<::tokio::sync::mpsc::Sender>; +pub(crate) type StopWorkingRecvCh = ::tokio::sync::broadcast::Receiver< + ::tokio::sync::mpsc::Sender, +>; /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] @@ -350,397 +86,576 @@ pub struct Fenrir { /// library Configuration cfg: Config, /// listening udp sockets - sockets: SocketList, + sockets: Vec, /// DNSSEC resolver, with failovers - dnssec: Option, + dnssec: dnssec::Dnssec, /// Broadcast channel to tell workers to stop working - stop_working: ::tokio::sync::broadcast::Sender, - /// Private keys used in the handshake - _inner: Arc, + stop_working: StopWorkingSendCh, /// where to ask for token check - token_check: Arc>, - /// MPMC work queue. sender - work_send: Arc<::async_channel::Sender>, - /// MPMC work queue. receiver - work_recv: Arc<::async_channel::Receiver>, - // PERF: rand uses syscalls. should we do that async? - rand: ::ring::rand::SystemRandom, - /// list of Established connections - connections: Arc>, + token_check: Option>>, + /// tracks the connections to authentication servers + /// so that we can reuse them + conn_auth_srv: Mutex, + // TODO: find a way to both increase and decrease these two in a thread-safe + // manner + _thread_pool: Vec<::std::thread::JoinHandle<()>>, + _thread_work: Arc>>, } - // TODO: graceful vs immediate stop impl Drop for Fenrir { fn drop(&mut self) { - self.stop_sync() + ::tracing::debug!( + "Fenrir fast shutdown.\ + Some threads might remain a bit longer" + ); + let _ = self.stop_sync(); } } impl Fenrir { + /// Gracefully stop all listeners and workers + /// only return when all resources have been deallocated + pub async fn graceful_stop(mut self) { + ::tracing::debug!("Fenrir full shut down"); + if let Some((ch, listeners, workers)) = self.stop_sync() { + self.stop_wait(ch, listeners, workers).await; + } + } + fn stop_sync( + &mut self, + ) -> Option<( + ::tokio::sync::mpsc::Receiver, + Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>, + usize, + )> { + let workers_num = self._thread_work.len(); + if self.sockets.len() > 0 || self._thread_work.len() > 0 { + let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4); + let _ = self.stop_working.send(ch_send); + let mut old_listeners = Vec::with_capacity(self.sockets.len()); + ::core::mem::swap(&mut old_listeners, &mut self.sockets); + self._thread_pool.clear(); + let listeners = old_listeners + .into_iter() + .map(|(_, joinable)| joinable) + .collect(); + Some((ch_recv, listeners, workers_num)) + } else { + None + } + } + async fn stop_wait( + &mut self, + mut ch: ::tokio::sync::mpsc::Receiver, + listeners: Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>, + mut workers_num: usize, + ) { + let mut listeners_num = listeners.len(); + while listeners_num > 0 && workers_num > 0 { + match ch.recv().await { + Some(stopped) => match stopped { + StopWorking::WorkerStopped => workers_num = workers_num - 1, + StopWorking::ListenerStopped => { + listeners_num = listeners_num - 1 + } + }, + _ => break, + } + } + for l in listeners.into_iter() { + if let Err(e) = l.await { + ::tracing::error!("Unclean shutdown of listener: {:?}", e); + } + } + } /// Create a new Fenrir endpoint - pub fn new(config: &Config) -> Result { - let listen_num = config.listen.len(); + /// spawn threads pinned to cpus in our own way with tokio's runtime + pub async fn with_threads( + config: &Config, + tokio_rt: Arc<::tokio::runtime::Runtime>, + ) -> Result { let (sender, _) = ::tokio::sync::broadcast::channel(1); - let (work_send, work_recv) = ::async_channel::unbounded::(); - let endpoint = Fenrir { - cfg: config.clone(), - sockets: SocketList::new(), - dnssec: None, - stop_working: sender, - _inner: Arc::new(FenrirInner { - ciphers: ArcSwapAny::new(Arc::new(Vec::new())), - key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), - keys: ArcSwapAny::new(Arc::new(Vec::new())), - }), - token_check: Arc::new(ArcSwapOption::from(None)), - work_send: Arc::new(work_send), - work_recv: Arc::new(work_recv), - rand: ::ring::rand::SystemRandom::new(), - connections: Arc::new(RwLock::new(ConnList::new())), + let dnssec = dnssec::Dnssec::new(&config.resolvers)?; + // bind sockets early so we can change "port 0" (aka: random) + // in the config + let binded_sockets = Self::bind_sockets(&config).await?; + let socket_addrs = binded_sockets + .iter() + .map(|s| s.local_addr().unwrap()) + .collect(); + let cfg = { + let mut tmp = config.clone(); + tmp.listen = socket_addrs; + tmp }; + let mut endpoint = Self { + cfg, + sockets: Vec::with_capacity(config.listen.len()), + dnssec, + stop_working: sender, + token_check: None, + conn_auth_srv: Mutex::new(AuthServerConnections::new()), + _thread_pool: Vec::new(), + _thread_work: Arc::new(Vec::new()), + }; + endpoint + .start_work_threads_pinned(tokio_rt, binded_sockets.clone()) + .await?; + endpoint.run_listeners(binded_sockets).await?; Ok(endpoint) } - - /// Start all workers, listeners - pub async fn start(&mut self) -> Result<(), Error> { - if let Err(e) = self.add_sockets().await { - self.stop().await; - return Err(e.into()); + /// Create a new Fenrir endpoint + /// Get the workers that you can use in a tokio LocalSet + /// You should: + /// * move these workers each in its own thread + /// * make sure that the threads are pinned on the cpu + pub async fn with_workers( + config: &Config, + ) -> Result<(Self, Vec), Error> { + let (stop_working, _) = ::tokio::sync::broadcast::channel(1); + let dnssec = dnssec::Dnssec::new(&config.resolvers)?; + // bind sockets early so we can change "port 0" (aka: random) + // in the config + let binded_sockets = Self::bind_sockets(&config).await?; + let socket_addrs = binded_sockets + .iter() + .map(|s| s.local_addr().unwrap()) + .collect(); + let cfg = { + let mut tmp = config.clone(); + tmp.listen = socket_addrs; + tmp + }; + let mut endpoint = Self { + cfg, + sockets: Vec::with_capacity(config.listen.len()), + dnssec, + stop_working: stop_working.clone(), + token_check: None, + conn_auth_srv: Mutex::new(AuthServerConnections::new()), + _thread_pool: Vec::new(), + _thread_work: Arc::new(Vec::new()), + }; + let worker_num = config.threads.unwrap().get(); + let mut workers = Vec::with_capacity(worker_num); + for _ in 0..worker_num { + workers.push( + endpoint.start_single_worker(binded_sockets.clone()).await?, + ); } - self.dnssec = Some(dnssec::Dnssec::new(&self.cfg.resolvers).await?); - Ok(()) + endpoint.run_listeners(binded_sockets).await?; + Ok((endpoint, workers)) + } + /// Returns the list of the actual addresses we are listening on + /// Note that this can be different from what was configured: + /// if you specified UDP port 0 a random one has been assigned to you + /// by the operating system. + pub fn addresses(&self) -> Vec<::std::net::SocketAddr> { + self.sockets.iter().map(|(s, _)| s.clone()).collect() } - /// Stop all workers, listeners - /// asyncronous version for Drop - fn stop_sync(&mut self) { - let _ = self.stop_working.send(true); - let mut toempty_sockets = self.sockets.rm_all(); - let task = ::tokio::task::spawn(toempty_sockets.stop_all()); - let _ = ::futures::executor::block_on(task); - self.dnssec = None; - } - - /// Stop all workers, listeners - pub async fn stop(&mut self) { - let _ = self.stop_working.send(true); - let mut toempty_sockets = self.sockets.rm_all(); - toempty_sockets.stop_all().await; - self.dnssec = None; - } - - /// Enable some common socket options. This is just the unsafe part - fn enable_sock_opt( - fd: ::std::os::fd::RawFd, - option: ::libc::c_int, - value: ::libc::c_int, - ) -> ::std::io::Result<()> { - #[allow(unsafe_code)] - unsafe { - #[allow(trivial_casts)] - let val = &value as *const _ as *const ::libc::c_void; - let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; - // always clear the error bit before doing a new syscall - let _ = ::std::io::Error::last_os_error(); - let ret = - ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); - if ret != 0 { - return Err(::std::io::Error::last_os_error()); - } - } - Ok(()) - } - - /// Add all UDP sockets found in config - /// and start listening for packets - async fn add_sockets(&self) -> ::std::io::Result<()> { - let sockets = self.cfg.listen.iter().map(|s_addr| async { - let socket = - ::tokio::spawn(Self::bind_udp(s_addr.clone())).await??; - Ok(socket) + // only call **before** starting all threads + /// bind all UDP sockets found in config + async fn bind_sockets(cfg: &Config) -> Result>, Error> { + // try to bind multiple sockets in parallel + let mut sock_set = ::tokio::task::JoinSet::new(); + cfg.listen.iter().for_each(|s_addr| { + let socket_address = s_addr.clone(); + sock_set.spawn(async move { + connection::socket::bind_udp(socket_address).await + }); }); - let sockets = ::futures::future::join_all(sockets).await; - for s_res in sockets.into_iter() { - match s_res { - Ok(s) => { - let stop_working = self.stop_working.subscribe(); - let arc_s = Arc::new(s); - let join = ::tokio::spawn(Self::listen_udp( - stop_working, - self.work_send.clone(), - arc_s.clone(), - )); - self.sockets.add_socket(arc_s, join); - } + // make sure we either return all of them, or none + let mut all_socks = Vec::with_capacity(cfg.listen.len()); + while let Some(join_res) = sock_set.join_next().await { + match join_res { + Ok(s_res) => match s_res { + Ok(s) => { + all_socks.push(Arc::new(s)); + } + Err(e) => { + return Err(e.into()); + } + }, Err(e) => { - return Err(e); + return Err(Error::Setup(e.to_string())); } } } - Ok(()) + assert!(all_socks.len() == cfg.listen.len(), "missing socks"); + Ok(all_socks) } - - /// Add an async udp listener - async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { - let socket = UdpSocket::bind(sock).await?; - - use ::std::os::fd::AsRawFd; - let fd = socket.as_raw_fd(); - // can be useful later on for reloads - Self::enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; - Self::enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; - - // We will do path MTU discovery by ourselves, - // always set the "don't fragment" bit - if sock.is_ipv6() { - Self::enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; - } else { - // FIXME: linux only - Self::enable_sock_opt( - fd, - ::libc::IP_MTU_DISCOVER, - ::libc::IP_PMTUDISC_DO, - )?; + // only call **after** starting all threads + /// spawn all listeners + async fn run_listeners( + &mut self, + socks: Vec>, + ) -> Result<(), Error> { + for sock in socks.into_iter() { + let sockaddr = sock.local_addr().unwrap(); + let stop_working = self.stop_working.subscribe(); + let th_work = self._thread_work.clone(); + let joinable = ::tokio::spawn(async move { + Self::listen_udp(stop_working, th_work, sock.clone()).await + }); + self.sockets.push((sockaddr, joinable)); } - - Ok(socket) + Ok(()) } /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( - mut stop_working: ::tokio::sync::broadcast::Receiver, - work_queue: Arc<::async_channel::Sender>, + mut stop_working: StopWorkingRecvCh, + work_queues: Arc>>, socket: Arc, ) -> ::std::io::Result<()> { // jumbo frames are 9K max let sock_receiver = UdpServer(socket.local_addr()?); let mut buffer: [u8; 9000] = [0; 9000]; + let queues_num = work_queues.len() as u64; loop { let (bytes, sock_sender) = ::tokio::select! { - _done = stop_working.recv() => { - break; + tell_stopped = stop_working.recv() => { + drop(socket); + if let Ok(stop_ch) = tell_stopped { + let _ = stop_ch + .send(StopWorking::ListenerStopped).await; + } + return Ok(()); } result = socket.recv_from(&mut buffer) => { - result? + let (bytes, from) = result?; + (bytes, UdpClient(from)) } }; let data: Vec = buffer[..bytes].to_vec(); - work_queue.send(Work::Recv(RawUdp { - data, - src: UdpClient(sock_sender), - dst: sock_receiver, - })); - } - Ok(()) - } + // we very likely have multiple threads, pinned to different cpus. + // use the ConnectionID to send the same connection + // to the same thread. + // Handshakes have connection ID 0, so we use the sender's UDP port + + let packet = match Packet::deserialize_id(&data) { + Ok(packet) => packet, + Err(_) => continue, // packet way too short, ignore. + }; + let thread_idx: usize = { + use connection::packet::ConnectionID; + match packet.id { + ConnectionID::Handshake => { + let send_port = sock_sender.0.port() as u64; + (send_port % queues_num) as usize + } + ConnectionID::ID(id) => (id.get() % queues_num) as usize, + } + }; + let _ = work_queues[thread_idx] + .send(Work::Recv(RawUdp { + src: sock_sender, + dst: sock_receiver, + packet, + data, + })) + .await; + } + } /// Get the raw TXT record of a Fenrir domain - pub async fn resolv_str(&self, domain: &str) -> Result { - match &self.dnssec { - Some(dnssec) => Ok(dnssec.resolv(domain).await?), - None => Err(Error::NotInitialized), + pub async fn resolv_txt(&self, domain: &Domain) -> Result { + match self.dnssec.resolv(domain).await { + Ok(res) => Ok(res), + Err(e) => Err(e.into()), } } /// Get the raw TXT record of a Fenrir domain - pub async fn resolv(&self, domain: &str) -> Result { - let record_str = self.resolv_str(domain).await?; + pub async fn resolv( + &self, + domain: &Domain, + ) -> Result { + let record_str = self.resolv_txt(domain).await?; Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } - /// Loop continuously and parse packets and other work - pub async fn work_loop(&self) { - let mut stop_working = self.stop_working.subscribe(); + /// Connect to a service, doing the dnssec resolution ourselves + pub async fn connect( + &self, + domain: &Domain, + service: ServiceID, + ) -> Result<(), Error> { + let resolved = self.resolv(domain).await?; + self.connect_resolved(resolved, domain, service).await + } + /// Connect to a service, with the user provided details + pub async fn connect_resolved( + &self, + resolved: dnssec::Record, + domain: &Domain, + service: ServiceID, + ) -> Result<(), Error> { loop { - let work = ::tokio::select! { - _done = stop_working.recv() => { - break; + // check if we already have a connection to that auth. srv + let is_reserved = { + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.get_or_reserve(&resolved) + }; + use connection::Reservation; + match is_reserved { + Reservation::Waiting => { + use ::std::time::Duration; + use ::tokio::time::sleep; + // PERF: exponential backoff. + // or we can have a broadcast channel + sleep(Duration::from_millis(50)).await; + continue; } - maybe_work = self.work_recv.recv() => { - match maybe_work { - Ok(work) => work, - Err(_) => break, + Reservation::Reserved => break, + Reservation::Present(_id_send) => { + //TODO: reuse connection + todo!() + } + } + } + // Spot reserved for the connection + + // find the thread with less connections + let th_num = self._thread_work.len(); + let mut conn_count = Vec::::with_capacity(th_num); + let mut wait_res = + Vec::<::tokio::sync::oneshot::Receiver>::with_capacity( + th_num, + ); + for th in self._thread_work.iter() { + let (send, recv) = ::tokio::sync::oneshot::channel(); + wait_res.push(recv); + let _ = th.send(Work::CountConnections(send)).await; + } + for ch in wait_res.into_iter() { + if let Ok(conn_num) = ch.await { + conn_count.push(conn_num); + } + } + if conn_count.len() != th_num { + return Err(Error::IO(::std::io::Error::new( + ::std::io::ErrorKind::NotConnected, + "can't connect to a thread", + ))); + } + let thread_idx = conn_count + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.cmp(b)) + .map(|(index, _)| index) + .unwrap(); + + // and tell that thread to connect somewhere + let (send, recv) = ::tokio::sync::oneshot::channel(); + let _ = self._thread_work[thread_idx] + .send(Work::Connect(ConnectInfo { + answer: send, + resolved: resolved.clone(), + service_id: service, + domain: domain.clone(), + })) + .await; + + match recv.await { + Ok(res) => { + match res { + Err(e) => { + let mut conn_auth_lock = + self.conn_auth_srv.lock().await; + conn_auth_lock.remove_reserved(&resolved); + Err(e) + } + Ok((key_id, id_send)) => { + let key = resolved + .public_keys + .iter() + .find(|k| k.0 == key_id) + .unwrap(); + let mut conn_auth_lock = + self.conn_auth_srv.lock().await; + conn_auth_lock.add(&key.1, id_send, &resolved); + + //FIXME: user needs to somehow track the connection + Ok(()) } } - }; - match work { - Work::Recv(pkt) => { - self.recv(pkt).await; - } + } + Err(e) => { + // Thread dropped the sender. no more thread? + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.remove_reserved(&resolved); + Err(Error::IO(::std::io::Error::new( + ::std::io::ErrorKind::Interrupted, + "recv failure on connect: ".to_owned() + &e.to_string(), + ))) } } } - const MIN_PACKET_BYTES: usize = 8; - /// Read and do stuff with the raw udp packet - async fn recv(&self, udp: RawUdp) { - if udp.data.len() < Self::MIN_PACKET_BYTES { - return; + // needs to be called before run_listeners + async fn start_single_worker( + &mut self, + socks: Vec>, + ) -> ::std::result::Result { + let thread_idx = self._thread_work.len() as u16; + let max_threads = self.cfg.threads.unwrap().get() as u16; + if thread_idx >= max_threads { + ::tracing::error!( + "thread id higher than number of threads in config" + ); + assert!( + false, + "thread_idx is an index that can't reach cfg.threads" + ); + return Err(Error::Setup("Thread id > threads_max".to_owned())); } - use connection::ID; - let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable"); - if ID::from(raw_id).is_handshake() { - use connection::handshake::Handshake; - let handshake = match Handshake::deserialize(&udp.data[8..]) { - Ok(handshake) => handshake, - Err(e) => { - ::tracing::warn!("Handshake parsing: {}", e); - return; + let thread_id = ThreadTracker { + id: thread_idx, + total: max_threads, + }; + let (work_send, work_recv) = ::async_channel::unbounded::(); + let worker = Worker::new( + self.cfg.clone(), + thread_id, + self.stop_working.subscribe(), + self.token_check.clone(), + socks, + work_recv, + ) + .await?; + // don't keep around private keys too much + if (thread_idx + 1) == max_threads { + self.cfg.server_keys.clear(); + } + loop { + let queues_lock = match Arc::get_mut(&mut self._thread_work) { + Some(queues_lock) => queues_lock, + None => { + // should not even ever happen + ::tokio::time::sleep(::std::time::Duration::from_millis( + 50, + )) + .await; + continue; } }; - let action = match self._inner.recv_handshake(handshake) { - Ok(action) => action, - Err(err) => { - ::tracing::debug!("Handshake recv error {}", err); - return; + queues_lock.push(work_send); + break; + } + Ok(worker) + } + + // needs to be called before add_sockets + /// Start one working thread for each physical cpu + /// threads are pinned to each cpu core. + /// Work will be divided and rerouted so that there is no need to lock + async fn start_work_threads_pinned( + &mut self, + tokio_rt: Arc<::tokio::runtime::Runtime>, + sockets: Vec>, + ) -> ::std::result::Result<(), Error> { + use ::std::sync::Mutex; + let hw_topology = match ::hwloc2::Topology::new() { + Some(hw_topology) => Arc::new(Mutex::new(hw_topology)), + None => { + return Err(Error::Setup( + "Can't get hardware topology".to_owned(), + )) + } + }; + let cores; + { + let topology_lock = hw_topology.lock().unwrap(); + let all_cores = match topology_lock + .objects_with_type(&::hwloc2::ObjectType::Core) + { + Ok(all_cores) => all_cores, + Err(_) => { + return Err(Error::Setup("can't list cores".to_owned())) } }; - match action { - HandshakeAction::AuthNeeded(authinfo) => { - let tk_check = match self.token_check.load_full() { - Some(tk_check) => tk_check, - None => { - ::tracing::error!( - "Handshake received, but no tocken_checker" - ); - return; - } - }; - use handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; - match authinfo.handshake.data { - HandshakeData::DirSync(ds) => match ds { - DirSync::Req(req) => { - use dirsync::ReqInner; - let req_data = match req.data { - ReqInner::Data(req_data) => req_data, - _ => { - ::tracing::error!( - "token_check: expected Data" - ); - return; - } - }; - let is_authenticated = match tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await - { - Ok(is_authenticated) => is_authenticated, - Err(_) => { - ::tracing::error!( - "error in token auth" - ); - // TODO: retry? - return; - } - }; - if !is_authenticated { - ::tracing::warn!( - "Wrong authentication for user {:?}", - req_data.auth.user - ); - // TODO: error response - return; - } - // Client has correctly authenticated - // TODO: contact the service, get the key and - // connection ID - let srv_conn_id = - connection::ID::new_rand(&self.rand); - let srv_secret = - enc::sym::Secret::new_rand(&self.rand); + cores = all_cores.len(); + if cores <= 0 || !topology_lock.support().cpu().set_thread() { + ::tracing::error!("No support for CPU pinning"); + return Err(Error::Setup("No cpu pinning support".to_owned())); + } + } + for core in 0..cores { + ::tracing::debug!("Spawning thread {}", core); + let th_topology = hw_topology.clone(); + let th_tokio_rt = tokio_rt.clone(); + let th_config = self.cfg.clone(); + let (work_send, work_recv) = ::async_channel::unbounded::(); + let th_stop_working = self.stop_working.subscribe(); + let th_token_check = self.token_check.clone(); + let th_sockets = sockets.clone(); + let thread_id = ThreadTracker { + total: cores as u16, + id: 1 + (core as u16), + }; - let raw_conn = Connection::new( - authinfo.hkdf, - req.cipher, - connection::Role::Server, - &self.rand, - ); - // track connection - let auth_conn = { - let mut lock = - self.connections.write().await; - lock.reserve_first(raw_conn) - }; - - // TODO: move all the next bits into - // dirsync::Req::respond(...) - - let resp_data = dirsync::RespData { - client_nonce: req_data.nonce, - id: auth_conn.id, - service_id: srv_conn_id, - service_key: srv_secret, - }; - use crate::enc::sym::AAD; - let aad = AAD(&mut []); // no aad for now - let mut data = auth_conn - .cipher_send - .make_data(dirsync::RespData::len()); - - if let Err(e) = auth_conn - .cipher_send - .encrypt(aad, &mut data) - { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } - let resp = dirsync::Resp { - client_key_id: req_data.client_key_id, - enc: data.get_raw(), - }; - let resp_handshake = Handshake::new( - HandshakeData::DirSync(DirSync::Resp(resp)), - ); - use connection::{Packet, PacketData, ID}; - let packet = Packet { - id: ID::new_handshake(), - data: PacketData::Handshake(resp_handshake), - }; - let mut raw_out = - Vec::::with_capacity(packet.len()); - packet.serialize(&mut raw_out); - self.send_packet(raw_out, udp.src, udp.dst) - .await; - } - _ => { - todo!() - } - }, + let join_handle = ::std::thread::spawn(move || { + // bind to a specific core + let th_pinning; + { + let mut th_topology_lock = th_topology.lock().unwrap(); + let th_cores = th_topology_lock + .objects_with_type(&::hwloc2::ObjectType::Core) + .unwrap(); + let cpuset = th_cores.get(core).unwrap().cpuset().unwrap(); + th_pinning = th_topology_lock.set_cpubind( + cpuset, + ::hwloc2::CpuBindFlags::CPUBIND_THREAD, + ); + } + match th_pinning { + Ok(_) => {} + Err(_) => { + ::tracing::error!("Can't bind thread to cpu"); + return; } } - _ => {} - }; + // finally run the main worker. + // make sure things stay on this thread + let tk_local = ::tokio::task::LocalSet::new(); + let _ = tk_local.block_on(&th_tokio_rt, async move { + let mut worker = match Worker::new( + th_config, + thread_id, + th_stop_working, + th_token_check, + th_sockets, + work_recv, + ) + .await + { + Ok(worker) => worker, + Err(_) => return, + }; + worker.work_loop().await + }); + }); + loop { + let queues_lock = match Arc::get_mut(&mut self._thread_work) { + Some(queues_lock) => queues_lock, + None => { + ::tokio::time::sleep( + ::std::time::Duration::from_millis(50), + ) + .await; + continue; + } + }; + queues_lock.push(work_send); + break; + } + self._thread_pool.push(join_handle); } - // copy packet, spawn - todo!(); - } - async fn send_packet( - &self, - data: Vec, - client: UdpClient, - server: UdpServer, - ) { - let src_sock; - { - let sockets = self.sockets.lock(); - src_sock = match sockets.find(server) { - Some(src_sock) => src_sock, - None => { - ::tracing::error!( - "Can't send packet: Server changed listening ip!" - ); - return; - } - }; - } - src_sock.send_to(&data, client.0); + // don't keep around private keys too much + self.cfg.server_keys.clear(); + Ok(()) } } diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..acf57cc --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,102 @@ +use crate::*; + +#[::tracing_test::traced_test] +#[::tokio::test] +async fn test_connection_dirsync() { + use enc::asym::{KeyID, PrivKey, PubKey}; + let rand = enc::Random::new(); + let (priv_exchange_key, pub_exchange_key) = + match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand) + { + Ok((privkey, pubkey)) => { + (PrivKey::Exchange(privkey), PubKey::Exchange(pubkey)) + } + Err(_) => { + assert!(false, "Can't generate random keypair"); + return; + } + }; + let cfg_client = { + let mut cfg = config::Config::default(); + cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap()); + cfg + }; + let test_domain: Domain = "example.com".into(); + let cfg_server = { + let mut cfg = cfg_client.clone(); + cfg.server_keys = [config::ServerKey { + id: KeyID(42), + priv_key: priv_exchange_key, + pub_key: pub_exchange_key, + }] + .to_vec(); + cfg.servers = [config::AuthServer { + fqdn: test_domain.clone(), + keys: [KeyID(42)].to_vec(), + }] + .to_vec(); + cfg + }; + + let (server, mut srv_workers) = + Fenrir::with_workers(&cfg_server).await.unwrap(); + let (client, mut cli_workers) = + Fenrir::with_workers(&cfg_client).await.unwrap(); + let mut srv_worker = srv_workers.pop().unwrap(); + let mut cli_worker = cli_workers.pop().unwrap(); + + ::std::thread::spawn(move || { + let rt = ::tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + let local_thread = ::tokio::task::LocalSet::new(); + local_thread.spawn_local(async move { + srv_worker.work_loop().await; + }); + + local_thread.spawn_local(async move { + ::tokio::time::sleep(::std::time::Duration::from_millis(100)).await; + cli_worker.work_loop().await; + }); + rt.block_on(local_thread); + }); + + use crate::{ + connection::handshake::HandshakeID, + dnssec::{record, Record}, + }; + + let port: u16 = server.addresses()[0].port(); + + let dnssec_record = Record { + public_keys: [(KeyID(42), pub_exchange_key)].to_vec(), + addresses: [record::Address { + ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)), + port: Some(::core::num::NonZeroU16::new(port).unwrap()), + priority: record::AddressPriority::P1, + weight: record::AddressWeight::W1, + handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(), + public_key_idx: [record::PubKeyIdx(0)].to_vec(), + }] + .to_vec(), + key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman] + .to_vec(), + hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), + ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), + }; + + ::tokio::time::sleep(::std::time::Duration::from_millis(500)).await; + match client + .connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH) + .await + { + Ok(()) => {} + Err(e) => { + assert!(false, "Err on client connection: {:?}", e); + } + } + + server.graceful_stop().await; + client.graceful_stop().await; +} diff --git a/var/git-pre-commit b/var/git-pre-commit new file mode 100755 index 0000000..81ac843 --- /dev/null +++ b/var/git-pre-commit @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +RUSTFLAGS=-Awarnings exec cargo test --offline --profile dev +