From 7a129dbe90bad3704ea4db209eeba627c87e6e46 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 11 May 2023 11:28:30 +0200 Subject: [PATCH 01/34] Handhsake DirSync RespInner Signed-off-by: Luca Fulchir --- flake.lock | 72 +++++++++++++++++++------- flake.nix | 5 +- src/connection/handshake/dirsync.rs | 80 ++++++++++++++++++++++++++--- src/connection/handshake/mod.rs | 14 +++-- src/lib.rs | 61 +++++++++++++++++++--- 5 files changed, 195 insertions(+), 37 deletions(-) diff --git a/flake.lock b/flake.lock index 2920402..ce2e37c 100644 --- a/flake.lock +++ b/flake.lock @@ -1,12 +1,15 @@ { "nodes": { "flake-utils": { + "inputs": { + "systems": "systems" + }, "locked": { - "lastModified": 1676283394, - "narHash": "sha256-XX2f9c3iySLCw54rJ/CZs+ZK6IQy7GXNY4nSOyu2QG4=", + "lastModified": 1681202837, + "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", "owner": "numtide", "repo": "flake-utils", - "rev": "3db36a8b464d0c4532ba1c7dda728f4576d6d073", + "rev": "cfacdce06f30d2b68473a46042957675eebb3401", "type": "github" }, "original": { @@ -16,12 +19,15 @@ } }, "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, "locked": { - "lastModified": 1659877975, - "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "lastModified": 1681202837, + "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", "owner": "numtide", "repo": "flake-utils", - "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "rev": "cfacdce06f30d2b68473a46042957675eebb3401", "type": "github" }, "original": { @@ -32,11 +38,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1677624842, - "narHash": "sha256-4DF9DbDuK4/+KYx0L6XcPBeDHUFVCtzok2fWtwXtb5w=", + "lastModified": 1683478192, + "narHash": "sha256-7f7RR71w0jRABDgBwjq3vE1yY3nrVJyXk8hDzu5kl1E=", "owner": "nixos", "repo": "nixpkgs", - "rev": "d70f5cd5c3bef45f7f52698f39e7cc7a89daa7f0", + "rev": "c568239bcc990050b7aedadb7387832440ad8fb1", "type": "github" }, "original": { @@ -48,11 +54,11 @@ }, "nixpkgs-unstable": { "locked": { - "lastModified": 1677407201, - "narHash": "sha256-3blwdI9o1BAprkvlByHvtEm5HAIRn/XPjtcfiunpY7s=", + "lastModified": 1683408522, + "narHash": "sha256-9kcPh6Uxo17a3kK3XCHhcWiV1Yu1kYj22RHiymUhMkU=", "owner": "nixos", "repo": "nixpkgs", - "rev": "7f5639fa3b68054ca0b062866dc62b22c3f11505", + "rev": "897876e4c484f1e8f92009fd11b7d988a121a4e7", "type": "github" }, "original": { @@ -64,11 +70,11 @@ }, "nixpkgs_2": { "locked": { - "lastModified": 1665296151, - "narHash": "sha256-uOB0oxqxN9K7XGF1hcnY+PQnlQJ+3bP2vCn/+Ru/bbc=", + "lastModified": 1681358109, + "narHash": "sha256-eKyxW4OohHQx9Urxi7TQlFBTDWII+F+x2hklDOQPB50=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "14ccaaedd95a488dd7ae142757884d8e125b3363", + "rev": "96ba1c52e54e74c3197f4d43026b3f3d92e83ff9", "type": "github" }, "original": { @@ -92,11 +98,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1677638104, - "narHash": "sha256-vbdOoDYnQ1QYSchMb3fYGCLYeta3XwmGvMrlXchST5s=", + "lastModified": 1683512408, + "narHash": "sha256-QMJGp/37En+d5YocJuSU89GL14bBYkIJQ6mqhRfqkkc=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "f388187efb41ce4195b2f4de0b6bb463d3cd0a76", + "rev": "75b07756c3feb22cf230e75fb064c1b4c725b9bc", "type": "github" }, "original": { @@ -104,6 +110,36 @@ "repo": "rust-overlay", "type": "github" } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } } }, "root": "root", diff --git a/flake.nix b/flake.nix index 89bfdb9..37e97ab 100644 --- a/flake.nix +++ b/flake.nix @@ -15,6 +15,9 @@ pkgs = import nixpkgs { inherit system overlays; }; + pkgs-unstable = import nixpkgs-unstable { + inherit system overlays; + }; in { devShells.default = pkgs.mkShell { @@ -36,7 +39,7 @@ cargo-watch cargo-license lld - rust-bin.stable.latest.default + rust-bin.stable."1.69.0".default rustfmt rust-analyzer ]; diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 9c0a299..902e4bd 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -63,6 +63,7 @@ pub struct Req { pub exchange_key: ExchangePubKey, /// encrypted data pub data: ReqInner, + // Security: Add padding to min: 1200 bytes to avoid amplification attaks } impl Req { @@ -79,7 +80,7 @@ impl Req { + self.data.len() } /// Serialize into raw bytes - /// NOTE: assumes that there is exactly asa much buffer as needed + /// NOTE: assumes that there is exactly as much buffer as needed pub fn serialize(&self, out: &mut [u8]) { //assert!(out.len() > , ": not enough buffer to serialize"); todo!() @@ -280,35 +281,94 @@ impl ReqData { } } +/// Quick way to avoid mixing cipher and clear text +#[derive(Debug, Clone)] +pub enum RespInner { + /// Server data, still in ciphertext + CipherText(VecDeque), + /// Server data, decrypted but unprocessed + ClearText(VecDeque), + /// Parsed server data + Data(RespData), +} +impl RespInner { + /// The length of the data + pub fn len(&self) -> usize { + match self { + RespInner::CipherText(c) => c.len(), + RespInner::ClearText(c) => c.len(), + RespInner::Data(d) => RespData::len(), + } + } + /// Get the ciptertext, or panic + pub fn ciphertext<'a>(&'a mut self) -> &'a mut VecDeque { + match self { + RespInner::CipherText(data) => data, + _ => panic!(), + } + } + /// switch from ciphertext to cleartext + pub fn mark_as_cleartext(&mut self) { + let mut newdata: VecDeque; + match self { + RespInner::CipherText(data) => { + newdata = VecDeque::new(); + ::core::mem::swap(&mut newdata, data); + } + _ => return, + } + *self = RespInner::ClearText(newdata); + } + /// serialize, but only if ciphertext + pub fn serialize(&self, out: &mut [u8]) { + todo!() + } +} + /// Server response in a directory synchronized handshake #[derive(Debug, Clone)] pub struct Resp { /// Tells the client with which key the exchange was done pub client_key_id: KeyID, - /// encrypted data - pub enc: Vec, + /// actual response data, might be encrypted + pub data: RespInner, } impl super::HandshakeParsing for Resp { fn deserialize(raw: &[u8]) -> Result { - todo!() + const MIN_PKT_LEN: usize = 68; + if raw.len() < MIN_PKT_LEN { + return Err(Error::NotEnoughData); + } + let client_key_id: KeyID = + KeyID(u16::from_le_bytes(raw[0..1].try_into().unwrap())); + Ok(HandshakeData::DirSync(DirSync::Resp(Self { + client_key_id, + data: RespInner::CipherText(raw[KeyID::len()..].to_vec().into()), + }))) } } impl Resp { /// Total length of the response handshake pub fn len(&self) -> usize { - KeyID::len() + self.enc.len() + KeyID::len() + self.data.len() } /// Serialize into raw bytes - /// NOTE: assumes that there is exactly asa much buffer as needed + /// NOTE: assumes that there is exactly as much buffer as needed + /// NOTE: assumes that the data is encrypted pub fn serialize(&self, out: &mut [u8]) { assert!( - out.len() == KeyID::len() + self.enc.len(), + out.len() == KeyID::len() + self.data.len(), "DirSync Resp: not enough buffer to serialize" ); self.client_key_id.serialize(array_mut_ref![out, 0, 2]); - out[2..].copy_from_slice(&self.enc[..]); + let end_data = 2 + self.data.len(); + self.data.serialize(&mut out[2..end_data]); + } + /// Set the cleartext data after it was parsed + pub fn set_data(&mut self, data: RespData) { + self.data = RespInner::Data(data); } } @@ -348,4 +408,8 @@ impl RespData { end = end + Self::NONCE_LEN; out[start..end].copy_from_slice(self.service_key.as_ref()); } + /// Parse the cleartext raw data + pub fn deserialize(raw: &RespInner) -> Result { + todo!(); + } } diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 763191d..6e91c28 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -24,11 +24,19 @@ pub enum Error { NotEnoughData, } -pub(crate) struct HandshakeKey { +pub(crate) struct HandshakeServer { pub id: crate::enc::asym::KeyID, pub key: crate::enc::asym::PrivKey, } +#[derive(Clone)] +pub(crate) struct HandshakeClient { + pub id: crate::enc::asym::KeyID, + pub key: crate::enc::asym::PrivKey, + pub hkdf: crate::enc::hkdf::HkdfSha3, + pub cipher: crate::enc::sym::CipherKind, +} + /// Parsed handshake #[derive(Debug, Clone)] pub enum HandshakeData { @@ -117,14 +125,14 @@ impl Handshake { }) } /// serialize the handshake into bytes - /// NOTE: assumes that there is exactly asa much buffer as needed + /// NOTE: assumes that there is exactly as much buffer as needed pub fn serialize(&self, out: &mut [u8]) { assert!(out.len() > 1, "Handshake: not enough buffer to serialize"); self.fenrir_version.serialize(&mut out[0]); self.data.serialize(&mut out[1..]); } - pub(crate) fn work(&self, keys: &[HandshakeKey]) -> Result<(), Error> { + pub(crate) fn work(&self, keys: &[HandshakeServer]) -> Result<(), Error> { todo!() } } diff --git a/src/lib.rs b/src/lib.rs index c9543fb..b6965e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,7 +32,7 @@ use crate::enc::{ }; pub use config::Config; use connection::{ - handshake::{self, Handshake, HandshakeKey}, + handshake::{self, Handshake, HandshakeClient, HandshakeServer}, Connection, }; @@ -79,7 +79,10 @@ pub enum HandshakeAction { struct FenrirInner { key_exchanges: ArcSwapAny>>, ciphers: ArcSwapAny>>, - keys: ArcSwapAny>>, + /// ephemeral keys used server side in key exchange + keys_srv: ArcSwapAny>>, + /// ephemeral keys used client side in key exchange + hshake_cli: ArcSwapAny>>, } #[allow(unsafe_code)] @@ -102,8 +105,8 @@ impl FenrirInner { DirSync::Req(ref mut req) => { let ephemeral_key = { // Keep this block short to avoid contention - // on self.keys - let keys = self.keys.load(); + // on self.keys_srv + let keys = self.keys_srv.load(); if let Some(h_k) = keys.iter().find(|k| k.id == req.key_id) { @@ -120,7 +123,10 @@ impl FenrirInner { } }; if ephemeral_key.is_none() { - ::tracing::debug!("No such key id: {:?}", req.key_id); + ::tracing::debug!( + "No such server key id: {:?}", + req.key_id + ); return Err(handshake::Error::UnknownKeyID.into()); } let ephemeral_key = ephemeral_key.unwrap(); @@ -170,6 +176,40 @@ impl FenrirInner { })); } DirSync::Resp(resp) => { + let hshake = { + // Keep this block short to avoid contention + // on self.hshake_cli + let hshake_cli_lock = self.hshake_cli.load(); + match hshake_cli_lock + .iter() + .find(|h| h.id == resp.client_key_id) + { + Some(h) => Some(h.clone()), + None => None, + } + }; + if hshake.is_none() { + ::tracing::debug!( + "No such client key id: {:?}", + resp.client_key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + let hshake = hshake.unwrap(); + let secret_recv = hshake.hkdf.get_secret(b"to_client"); + let cipher_recv = + CipherRecv::new(hshake.cipher, secret_recv); + use crate::enc::sym::AAD; + let aad = AAD(&mut []); // no aad for now + match cipher_recv.decrypt(aad, &mut resp.data.ciphertext()) + { + Ok(()) => resp.data.mark_as_cleartext(), + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + resp.set_data(dirsync::RespData::deserialize(&resp.data)?); + todo!(); } }, @@ -391,7 +431,8 @@ impl Fenrir { _inner: Arc::new(FenrirInner { ciphers: ArcSwapAny::new(Arc::new(Vec::new())), key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), - keys: ArcSwapAny::new(Arc::new(Vec::new())), + keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), + hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), }), token_check: Arc::new(ArcSwapOption::from(None)), work_send: Arc::new(work_send), @@ -692,9 +733,12 @@ impl Fenrir { ::tracing::error!("can't encrypt: {:?}", e); return; } + use dirsync::RespInner; let resp = dirsync::Resp { client_key_id: req_data.client_key_id, - enc: data.get_raw(), + data: RespInner::CipherText( + data.get_raw().into(), + ), }; let resp_handshake = Handshake::new( HandshakeData::DirSync(DirSync::Resp(resp)), @@ -710,6 +754,9 @@ impl Fenrir { self.send_packet(raw_out, udp.src, udp.dst) .await; } + DirSync::Resp(resp) => { + todo!() + } _ => { todo!() } From a5f18ac533e420f9547ea75b4f190e6f25a11e3d Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 17 May 2023 10:26:39 +0200 Subject: [PATCH 02/34] DirSync::Resp work Signed-off-by: Luca Fulchir --- src/connection/handshake/dirsync.rs | 163 ++++++++++++++++++---------- src/connection/handshake/mod.rs | 23 +++- src/connection/packet.rs | 21 +++- src/enc/sym.rs | 78 +++++++------ src/lib.rs | 114 ++++++++++++------- 5 files changed, 259 insertions(+), 140 deletions(-) diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 902e4bd..721ab9c 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -11,15 +11,16 @@ use super::{Error, HandshakeData}; use crate::{ auth, - connection::ID, + connection::{ProtocolVersion, ID}, enc::{ asym::{ExchangePubKey, KeyExchange, KeyID}, - sym::{CipherKind, Secret}, + sym::{CipherKind, HeadLen, Secret, TagLen}, }, }; use ::arrayref::array_mut_ref; use ::std::{collections::VecDeque, num::NonZeroU64, vec::Vec}; +use trust_dns_client::rr::rdata::key::Protocol; type Nonce = [u8; 16]; @@ -42,10 +43,15 @@ impl DirSync { } /// Serialize into raw bytes /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { match self { - DirSync::Req(req) => req.serialize(out), - DirSync::Resp(resp) => resp.serialize(out), + DirSync::Req(req) => req.serialize(head_len, tag_len, out), + DirSync::Resp(resp) => resp.serialize(head_len, tag_len, out), } } } @@ -67,9 +73,21 @@ pub struct Req { } impl Req { - /// Set the cleartext data after it was parsed - pub fn set_data(&mut self, data: ReqData) { - self.data = ReqInner::Data(data); + /// return the offset of the encrypted data + /// NOTE: starts from the beginning of the fenrir packet + pub fn encrypted_offset(&self) -> usize { + ProtocolVersion::len() + + KeyID::len() + + KeyExchange::len() + + CipherKind::len() + + self.exchange_key.len() + } + /// return the total length of the cleartext data + pub fn encrypted_length(&self) -> usize { + match &self.data { + ReqInner::ClearText(data) => data.len(), + _ => 0, + } } /// actual length of the directory synchronized request pub fn len(&self) -> usize { @@ -81,7 +99,12 @@ impl Req { } /// Serialize into raw bytes /// NOTE: assumes that there is exactly as much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { //assert!(out.len() > , ": not enough buffer to serialize"); todo!() } @@ -108,10 +131,7 @@ impl super::HandshakeParsing for Req { Ok(exchange_key) => exchange_key, Err(e) => return Err(e.into()), }; - let mut vec = VecDeque::with_capacity(raw.len() - (4 + len)); - vec.extend(raw[(4 + len)..].iter().copied()); - let _ = vec.make_contiguous(); - let data = ReqInner::CipherText(vec); + let data = ReqInner::CipherText(raw.len() - (4 + len)); Ok(HandshakeData::DirSync(DirSync::Req(Self { key_id, exchange, @@ -125,40 +145,35 @@ impl super::HandshakeParsing for Req { /// Quick way to avoid mixing cipher and clear text #[derive(Debug, Clone)] pub enum ReqInner { - /// Client data, still in ciphertext - CipherText(VecDeque), - /// Client data, decrypted but unprocessed - ClearText(VecDeque), - /// Parsed client data - Data(ReqData), + /// Data is still encrytped, we only keep the length + CipherText(usize), + /// Client data, decrypted and parsed + ClearText(ReqData), } impl ReqInner { /// The length of the data pub fn len(&self) -> usize { match self { - ReqInner::CipherText(c) => c.len(), - ReqInner::ClearText(c) => c.len(), - ReqInner::Data(d) => d.len(), + ReqInner::CipherText(len) => *len, + ReqInner::ClearText(data) => data.len(), } } - /// Get the ciptertext, or panic - pub fn ciphertext<'a>(&'a mut self) -> &'a mut VecDeque { - match self { - ReqInner::CipherText(data) => data, - _ => panic!(), - } - } - /// switch from ciphertext to cleartext - pub fn mark_as_cleartext(&mut self) { - let mut newdata: VecDeque; - match self { - ReqInner::CipherText(data) => { - newdata = VecDeque::new(); - ::core::mem::swap(&mut newdata, data); + /// parse the cleartext + pub fn deserialize_as_cleartext(&mut self, raw: &[u8]) { + let clear = match self { + ReqInner::CipherText(len) => { + assert!( + *len == raw.len(), + "DirSync::ReqInner::CipherText length mismatch" + ); + match ReqData::deserialize(raw) { + Ok(clear) => clear, + Err(_) => return, + } } _ => return, - } - *self = ReqInner::ClearText(newdata); + }; + *self = ReqInner::ClearText(clear); } } @@ -246,12 +261,7 @@ impl ReqData { pub const MIN_PKT_LEN: usize = 16 + KeyID::len() + ID::len() + AuthInfo::MIN_PKT_LEN; /// Parse the cleartext raw data - pub fn deserialize(raw: &ReqInner) -> Result { - let raw = match raw { - // raw is VecDeque, assume everything is on the first slice - ReqInner::ClearText(raw) => raw.as_slices().0, - _ => return Err(Error::Parsing), - }; + pub fn deserialize(raw: &[u8]) -> Result { if raw.len() < Self::MIN_PKT_LEN { return Err(Error::NotEnoughData); } @@ -285,21 +295,19 @@ impl ReqData { #[derive(Debug, Clone)] pub enum RespInner { /// Server data, still in ciphertext - CipherText(VecDeque), - /// Server data, decrypted but unprocessed - ClearText(VecDeque), - /// Parsed server data - Data(RespData), + CipherText(usize), + /// Parsed, cleartext server data + ClearText(RespData), } impl RespInner { /// The length of the data pub fn len(&self) -> usize { match self { - RespInner::CipherText(c) => c.len(), - RespInner::ClearText(c) => c.len(), - RespInner::Data(d) => RespData::len(), + RespInner::CipherText(len) => *len, + RespInner::ClearText(d) => RespData::len(), } } + /* /// Get the ciptertext, or panic pub fn ciphertext<'a>(&'a mut self) -> &'a mut VecDeque { match self { @@ -307,6 +315,25 @@ impl RespInner { _ => panic!(), } } + */ + /// parse the cleartext + pub fn deserialize_as_cleartext(&mut self, raw: &[u8]) { + let clear = match self { + RespInner::CipherText(len) => { + assert!( + *len == raw.len(), + "DirSync::RespInner::CipherText length mismatch" + ); + match RespData::deserialize(raw) { + Ok(clear) => clear, + Err(_) => return, + } + } + _ => return, + }; + *self = RespInner::ClearText(clear); + } + /* /// switch from ciphertext to cleartext pub fn mark_as_cleartext(&mut self) { let mut newdata: VecDeque; @@ -319,6 +346,7 @@ impl RespInner { } *self = RespInner::ClearText(newdata); } + */ /// serialize, but only if ciphertext pub fn serialize(&self, out: &mut [u8]) { todo!() @@ -344,31 +372,48 @@ impl super::HandshakeParsing for Resp { KeyID(u16::from_le_bytes(raw[0..1].try_into().unwrap())); Ok(HandshakeData::DirSync(DirSync::Resp(Self { client_key_id, - data: RespInner::CipherText(raw[KeyID::len()..].to_vec().into()), + data: RespInner::CipherText(raw[KeyID::len()..].len()), }))) } } impl Resp { + /// return the offset of the encrypted data + /// NOTE: starts from the beginning of the fenrir packet + pub fn encrypted_offset(&self) -> usize { + ProtocolVersion::len() + KeyID::len() + } + /// return the total length of the cleartext data + pub fn encrypted_length(&self) -> usize { + match &self.data { + RespInner::ClearText(_data) => RespData::len(), + _ => 0, + } + } /// Total length of the response handshake pub fn len(&self) -> usize { KeyID::len() + self.data.len() } /// Serialize into raw bytes /// NOTE: assumes that there is exactly as much buffer as needed - /// NOTE: assumes that the data is encrypted - pub fn serialize(&self, out: &mut [u8]) { + /// NOTE: assumes that the data is *ClearText* + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { assert!( out.len() == KeyID::len() + self.data.len(), "DirSync Resp: not enough buffer to serialize" ); self.client_key_id.serialize(array_mut_ref![out, 0, 2]); - let end_data = 2 + self.data.len(); - self.data.serialize(&mut out[2..end_data]); + let end_data = (2 + self.data.len()) - tag_len.0; + self.data.serialize(&mut out[(2 + head_len.0)..end_data]); } /// Set the cleartext data after it was parsed pub fn set_data(&mut self, data: RespData) { - self.data = RespInner::Data(data); + self.data = RespInner::ClearText(data); } } @@ -409,7 +454,7 @@ impl RespData { out[start..end].copy_from_slice(self.service_key.as_ref()); } /// Parse the cleartext raw data - pub fn deserialize(raw: &RespInner) -> Result { + pub fn deserialize(raw: &[u8]) -> Result { todo!(); } } diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 6e91c28..74d3be1 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -4,7 +4,10 @@ pub mod dirsync; use ::num_traits::FromPrimitive; -use crate::connection::{self, ProtocolVersion}; +use crate::{ + connection::{self, ProtocolVersion}, + enc::sym::{HeadLen, TagLen}, +}; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -53,9 +56,14 @@ impl HandshakeData { } /// Serialize into raw bytes /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { match self { - HandshakeData::DirSync(d) => d.serialize(out), + HandshakeData::DirSync(d) => d.serialize(head_len, tag_len, out), } } } @@ -126,10 +134,15 @@ impl Handshake { } /// serialize the handshake into bytes /// NOTE: assumes that there is exactly as much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { assert!(out.len() > 1, "Handshake: not enough buffer to serialize"); self.fenrir_version.serialize(&mut out[0]); - self.data.serialize(&mut out[1..]); + self.data.serialize(head_len, tag_len, &mut out[1..]); } pub(crate) fn work(&self, keys: &[HandshakeServer]) -> Result<(), Error> { diff --git a/src/connection/packet.rs b/src/connection/packet.rs index b59e02a..9e5eed3 100644 --- a/src/connection/packet.rs +++ b/src/connection/packet.rs @@ -1,6 +1,8 @@ // //! Raw packet handling, encryption, decryption, parsing +use crate::enc::sym::{HeadLen, TagLen}; + /// Fenrir Connection id /// 0 is special as it represents the handshake /// Connection IDs are to be considered u64 little endian @@ -100,10 +102,15 @@ impl PacketData { } /// serialize data into bytes /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { assert!(self.len() == out.len(), "PacketData: wrong buffer length"); match self { - PacketData::Handshake(h) => h.serialize(out), + PacketData::Handshake(h) => h.serialize(head_len, tag_len, out), } } } @@ -124,12 +131,18 @@ impl Packet { } /// serialize packet into buffer /// NOTE: assumes that there is exactly asa much buffer as needed - pub fn serialize(&self, out: &mut [u8]) { + pub fn serialize( + &self, + head_len: HeadLen, + tag_len: TagLen, + out: &mut [u8], + ) { assert!( out.len() > ConnectionID::len(), "Packet: not enough buffer to serialize" ); self.id.serialize(&mut out[0..ConnectionID::len()]); - self.data.serialize(&mut out[ConnectionID::len()..]); + self.data + .serialize(head_len, tag_len, &mut out[ConnectionID::len()..]); } } diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 366f665..e8db1d6 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -59,9 +59,9 @@ impl CipherKind { 1 } /// required length of the nonce - pub fn nonce_len(&self) -> usize { + pub fn nonce_len(&self) -> HeadLen { // TODO: how the hell do I take this from ::chacha20poly1305? - Nonce::len() + HeadLen(Nonce::len()) } /// required length of the key pub fn key_len(&self) -> usize { @@ -69,9 +69,9 @@ impl CipherKind { XChaCha20Poly1305::key_size() } /// Length of the authentication tag - pub fn tag_len(&self) -> usize { + pub fn tag_len(&self) -> TagLen { // TODO: how the hell do I take this from ::chacha20poly1305? - ::ring::aead::CHACHA20_POLY1305.tag_len() + TagLen(::ring::aead::CHACHA20_POLY1305.tag_len()) } } @@ -90,6 +90,16 @@ pub enum CipherDirection { Send, } +/// strong typedef for header length +/// aka: nonce length in the encrypted data) +#[derive(Debug, Copy, Clone)] +pub struct HeadLen(pub usize); +/// strong typedef for the Tag length +/// aka: cryptographic authentication tag length at the end +/// of the encrypted data +#[derive(Debug, Copy, Clone)] +pub struct TagLen(pub usize); + /// actual ciphers enum Cipher { /// Cipher XChaha20_Poly1305 @@ -105,31 +115,33 @@ impl Cipher { } } } - fn nonce_len(&self) -> usize { + fn nonce_len(&self) -> HeadLen { match self { Cipher::XChaCha20Poly1305(_) => { // TODO: how the hell do I take this from ::chacha20poly1305? - ::ring::aead::CHACHA20_POLY1305.nonce_len() + HeadLen(::ring::aead::CHACHA20_POLY1305.nonce_len()) } } } - fn tag_len(&self) -> usize { + fn tag_len(&self) -> TagLen { match self { Cipher::XChaCha20Poly1305(_) => { // TODO: how the hell do I take this from ::chacha20poly1305? - ::ring::aead::CHACHA20_POLY1305.tag_len() + TagLen(::ring::aead::CHACHA20_POLY1305.tag_len()) } } } - fn decrypt(&self, aad: AAD, data: &mut VecDeque) -> Result<(), Error> { + fn decrypt<'a>( + &self, + aad: AAD, + raw_data: &'a mut [u8], + ) -> Result<&'a [u8], Error> { match self { Cipher::XChaCha20Poly1305(cipher) => { use ::chacha20poly1305::{ aead::generic_array::GenericArray, AeadInPlace, }; - let final_len: usize; - { - let raw_data = data.as_mut_slices().0; + let final_len: usize = { // FIXME: check min data length let (nonce_bytes, data_and_tag) = raw_data.split_at_mut(13); let (data_notag, tag_bytes) = data_and_tag.split_at_mut( @@ -147,11 +159,11 @@ impl Cipher { if maybe.is_err() { return Err(Error::Decrypt); } - final_len = data_notag.len(); - } - data.drain(..Nonce::len()); - data.truncate(final_len); - Ok(()) + data_notag.len() + }; + //data.drain(..Nonce::len()); + //data.truncate(final_len); + Ok(&raw_data[Nonce::len()..Nonce::len() + final_len]) } } } @@ -159,7 +171,7 @@ impl Cipher { match self { Cipher::XChaCha20Poly1305(cipher) => { let cipher = CipherKind::XChaCha20Poly1305; - cipher.nonce_len() + cipher.tag_len() + cipher.nonce_len().0 + cipher.tag_len().0 } } } @@ -167,28 +179,30 @@ impl Cipher { &self, nonce: &Nonce, aad: AAD, - data: &mut Data, + data: &mut [u8], ) -> Result<(), Error> { - // No need to check for minimum buffer size since `Data` assures we - // already went through that + // FIXME: check minimum buffer size match self { Cipher::XChaCha20Poly1305(cipher) => { use ::chacha20poly1305::{ aead::generic_array::GenericArray, AeadInPlace, }; + let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len(); + let data_len_notag = data.len() - tag_len; // write nonce - data.get_slice_full()[..Nonce::len()] - .copy_from_slice(nonce.as_bytes()); + data[..Nonce::len()].copy_from_slice(nonce.as_bytes()); // encrypt data match cipher.cipher.encrypt_in_place_detached( nonce.as_bytes().into(), aad.0, - data.get_slice(), + &mut data[Nonce::len()..data_len_notag], ) { Ok(tag) => { - // add tag - data.get_tag_slice().copy_from_slice(tag.as_slice()); + data[data_len_notag..] + // add tag + //data.get_tag_slice() + .copy_from_slice(tag.as_slice()); Ok(()) } Err(_) => Err(Error::Encrypt), @@ -216,7 +230,7 @@ impl CipherRecv { Self(Cipher::new(kind, secret)) } /// Get the length of the nonce for this cipher - pub fn nonce_len(&self) -> usize { + pub fn nonce_len(&self) -> HeadLen { self.0.nonce_len() } /// Decrypt a paket. Nonce and Tag are taken from the packet, @@ -224,8 +238,8 @@ impl CipherRecv { pub fn decrypt<'a>( &self, aad: AAD, - data: &mut VecDeque, - ) -> Result<(), Error> { + data: &'a mut [u8], + ) -> Result<&'a [u8], Error> { self.0.decrypt(aad, data) } } @@ -289,12 +303,12 @@ impl CipherSend { pub fn make_data(&self, length: usize) -> Data { Data { data: Vec::with_capacity(length + self.cipher.overhead()), - skip_start: self.cipher.nonce_len(), - skip_end: self.cipher.tag_len(), + skip_start: self.cipher.nonce_len().0, + skip_end: self.cipher.tag_len().0, } } /// Encrypt the given data - pub fn encrypt(&self, aad: AAD, data: &mut Data) -> Result<(), Error> { + pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> { let old_nonce = self.nonce.advance(); self.cipher.encrypt(&old_nonce, aad, data)?; Ok(()) diff --git a/src/lib.rs b/src/lib.rs index b6965e8..0c75750 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,7 @@ use ::tokio::{ use crate::enc::{ asym, hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend}, + sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, }; pub use config::Config; use connection::{ @@ -67,13 +67,26 @@ pub struct AuthNeededInfo { pub cipher: CipherKind, } +/// Client information needed to fully establish the conenction +#[derive(Debug)] +pub struct ClientConnectInfo { + /// Parsed handshake + pub handshake: Handshake, + /// hkdf generated from the handshake + pub hkdf: HkdfSha3, + /// cipher to be used in both directions + pub cipher_recv: CipherRecv, +} + /// Intermediate actions to be taken while parsing the handshake -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum HandshakeAction { /// Parsing finished, all ok, nothing to do None, /// Packet parsed, now go perform authentication AuthNeeded(AuthNeededInfo), + /// the client can fully establish a connection with this info + ClientConnect(ClientConnectInfo), } // No async here struct FenrirInner { @@ -95,6 +108,7 @@ impl FenrirInner { fn recv_handshake( &self, mut handshake: Handshake, + handshake_raw: &mut [u8], ) -> Result { use connection::handshake::{ dirsync::{self, DirSync}, @@ -159,13 +173,17 @@ impl FenrirInner { let cipher_recv = CipherRecv::new(req.cipher, secret_recv); use crate::enc::sym::AAD; let aad = AAD(&mut []); // no aad for now - match cipher_recv.decrypt(aad, &mut req.data.ciphertext()) { - Ok(()) => req.data.mark_as_cleartext(), + match cipher_recv.decrypt( + aad, + &mut handshake_raw[req.encrypted_offset()..], + ) { + Ok(cleartext) => { + req.data.deserialize_as_cleartext(cleartext) + } Err(e) => { return Err(handshake::Error::Key(e).into()); } } - req.set_data(dirsync::ReqData::deserialize(&req.data)?); let cipher = req.cipher; @@ -200,17 +218,26 @@ impl FenrirInner { let cipher_recv = CipherRecv::new(hshake.cipher, secret_recv); use crate::enc::sym::AAD; - let aad = AAD(&mut []); // no aad for now - match cipher_recv.decrypt(aad, &mut resp.data.ciphertext()) - { - Ok(()) => resp.data.mark_as_cleartext(), + // no aad for now + let aad = AAD(&mut []); + let mut raw_data = &mut handshake_raw[resp + .encrypted_offset() + ..(resp.encrypted_offset() + resp.encrypted_length())]; + match cipher_recv.decrypt(aad, &mut raw_data) { + Ok(cleartext) => { + resp.data.deserialize_as_cleartext(&cleartext) + } Err(e) => { return Err(handshake::Error::Key(e).into()); } } - resp.set_data(dirsync::RespData::deserialize(&resp.data)?); - - todo!(); + return Ok(HandshakeAction::ClientConnect( + ClientConnectInfo { + handshake, + hkdf: hshake.hkdf, + cipher_recv, + }, + )); } }, } @@ -615,7 +642,7 @@ impl Fenrir { const MIN_PACKET_BYTES: usize = 8; /// Read and do stuff with the raw udp packet - async fn recv(&self, udp: RawUdp) { + async fn recv(&self, mut udp: RawUdp) { if udp.data.len() < Self::MIN_PACKET_BYTES { return; } @@ -630,13 +657,15 @@ impl Fenrir { return; } }; - let action = match self._inner.recv_handshake(handshake) { - Ok(action) => action, - Err(err) => { - ::tracing::debug!("Handshake recv error {}", err); - return; - } - }; + let action = + match self._inner.recv_handshake(handshake, &mut udp.data[8..]) + { + Ok(action) => action, + Err(err) => { + ::tracing::debug!("Handshake recv error {}", err); + return; + } + }; match action { HandshakeAction::AuthNeeded(authinfo) => { let tk_check = match self.token_check.load_full() { @@ -657,10 +686,10 @@ impl Fenrir { DirSync::Req(req) => { use dirsync::ReqInner; let req_data = match req.data { - ReqInner::Data(req_data) => req_data, + ReqInner::ClearText(req_data) => req_data, _ => { ::tracing::error!( - "token_check: expected Data" + "token_check: expected ClearText" ); return; } @@ -697,6 +726,8 @@ impl Fenrir { connection::ID::new_rand(&self.rand); let srv_secret = enc::sym::Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); let raw_conn = Connection::new( authinfo.hkdf, @@ -711,9 +742,6 @@ impl Fenrir { lock.reserve_first(raw_conn) }; - // TODO: move all the next bits into - // dirsync::Req::respond(...) - let resp_data = dirsync::RespData { client_nonce: req_data.nonce, id: auth_conn.id, @@ -721,25 +749,18 @@ impl Fenrir { service_key: srv_secret, }; use crate::enc::sym::AAD; - let aad = AAD(&mut []); // no aad for now - let mut data = auth_conn - .cipher_send - .make_data(dirsync::RespData::len()); + // no aad for now + let aad = AAD(&mut []); - if let Err(e) = auth_conn - .cipher_send - .encrypt(aad, &mut data) - { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } use dirsync::RespInner; let resp = dirsync::Resp { client_key_id: req_data.client_key_id, - data: RespInner::CipherText( - data.get_raw().into(), - ), + data: RespInner::ClearText(resp_data), }; + let offset_to_encrypt = resp.encrypted_offset(); + let encrypt_until = offset_to_encrypt + + resp.encrypted_length() + + tag_len.0; let resp_handshake = Handshake::new( HandshakeData::DirSync(DirSync::Resp(resp)), ); @@ -750,7 +771,20 @@ impl Fenrir { }; let mut raw_out = Vec::::with_capacity(packet.len()); - packet.serialize(&mut raw_out); + packet.serialize( + head_len, + tag_len, + &mut raw_out, + ); + + if let Err(e) = auth_conn.cipher_send.encrypt( + aad, + &mut raw_out + [offset_to_encrypt..encrypt_until], + ) { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } self.send_packet(raw_out, udp.src, udp.dst) .await; } From ace56f32e70dda44a2f3550dfd23871c0c3537d0 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 17 May 2023 12:05:13 +0200 Subject: [PATCH 03/34] refactor lib.rs in other files Signed-off-by: Luca Fulchir --- src/connection/mod.rs | 69 ++++++- src/connection/socket.rs | 108 +++++++++++ src/inner/mod.rs | 213 ++++++++++++++++++++++ src/lib.rs | 375 +++------------------------------------ 4 files changed, 408 insertions(+), 357 deletions(-) create mode 100644 src/connection/socket.rs create mode 100644 src/inner/mod.rs diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 8db39c4..b247cbc 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -2,12 +2,14 @@ pub mod handshake; mod packet; +pub mod socket; -use ::std::vec::Vec; +use ::std::{sync::Arc, vec::Vec}; -pub use handshake::Handshake; -pub use packet::ConnectionID as ID; -pub use packet::{Packet, PacketData}; +pub use crate::connection::{ + handshake::Handshake, + packet::{ConnectionID as ID, Packet, PacketData}, +}; use crate::enc::{ hkdf::HkdfSha3, @@ -83,3 +85,62 @@ impl Connection { } } } + +// PERF: Arc> loks a bit too much, need to find +// faster ways to do this +pub(crate) struct ConnList { + connections: Vec>>, + /// Bitmap to track which connection ids are used or free + ids_used: Vec<::bitmaps::Bitmap<1024>>, +} + +impl ConnList { + pub(crate) fn new() -> Self { + let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + bitmap_id.set(0, true); // ID(0) == handshake + Self { + connections: Vec::with_capacity(128), + ids_used: vec![bitmap_id], + } + } + pub(crate) fn reserve_first( + &mut self, + mut conn: Connection, + ) -> Arc { + // uhm... bad things are going on here: + // * id must be initialized, but only because: + // * rust does not understand that after the `!found` id is always + // initialized + // * `ID::new_u64` is really safe only with >0, but here it always is + // ...we should probably rewrite it in better, safer rust + let mut id: u64 = 0; + let mut found = false; + for (i, b) in self.ids_used.iter_mut().enumerate() { + match b.first_false_index() { + Some(idx) => { + b.set(idx, true); + id = ((i as u64) * 1024) + (idx as u64); + found = true; + break; + } + None => {} + } + } + if !found { + let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); + new_bitmap.set(0, true); + id = (self.ids_used.len() as u64) * 1024; + self.ids_used.push(new_bitmap); + } + let new_id = ID::new_u64(id); + conn.id = new_id; + let conn = Arc::new(conn); + if (self.connections.len() as u64) < id { + self.connections.push(Some(conn.clone())); + } else { + // very probably redundant + self.connections[id as usize] = Some(conn.clone()); + } + conn + } +} diff --git a/src/connection/socket.rs b/src/connection/socket.rs new file mode 100644 index 0000000..455c567 --- /dev/null +++ b/src/connection/socket.rs @@ -0,0 +1,108 @@ +//! Socket related types and functions + +use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; +use ::std::{ + net::SocketAddr, + sync::Arc, + vec::{self, Vec}, +}; +use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; + +// PERF: move to arcswap: +// We use multiple UDP sockets. +// We need a list of them all and to scan this list. +// But we want to be able to update this list without locking everyone +// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and +// `from_raw_fd` to update the list. +// This means that we will have multiple `UdpSocket` per actual socket +// so we have to handle `drop()` manually, and garbage-collect the ones we +// are no longer using in the background. sigh. +// Just go with a ArcSwapAny, Arc>>); + +/// async free socket list +pub(crate) struct SocketList { + pub list: ArcSwap>, +} +impl SocketList { + pub(crate) fn new() -> Self { + Self { + list: ArcSwap::new(Arc::new(Vec::new())), + } + } + // TODO: fn rm_socket() + pub(crate) fn rm_all(&self) -> Self { + let new_list = Arc::new(Vec::new()); + let old_list = self.list.swap(new_list); + Self { + list: old_list.into(), + } + } + pub(crate) async fn add_socket( + &self, + socket: Arc, + handle: JoinHandle<::std::io::Result<()>>, + ) { + // we could simplify this into just a `.swap` instead of `.rcu` but + // it is not yet guaranteed that only one thread will call this fn + // ...we don't need performance here anyway + let arc_handle = Arc::new(handle); + self.list.rcu(|old_list| { + let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); + new_list = old_list.to_vec().into(); + Arc::get_mut(&mut new_list) + .unwrap() + .push((socket.clone(), arc_handle.clone())); + new_list + }); + } + /// This method assumes no other `add_sockets` are being run + pub(crate) async fn stop_all(mut self) { + let mut arc_list = self.list.into_inner(); + let list = loop { + match Arc::try_unwrap(arc_list) { + Ok(list) => break list, + Err(arc_retry) => { + arc_list = arc_retry; + ::tokio::time::sleep(::core::time::Duration::from_millis( + 50, + )) + .await; + } + } + }; + for (_socket, mut handle) in list.into_iter() { + Arc::get_mut(&mut handle).unwrap().await; + } + } + pub(crate) fn lock(&self) -> SocketListRef { + SocketListRef { + list: self.list.load_full(), + } + } +} + +/// Reference to a locked SocketList +// TODO: impl Drop for SocketList +pub(crate) struct SocketListRef { + list: Arc>, +} +impl SocketListRef { + pub(crate) fn find(&self, sock: UdpServer) -> Option> { + match self + .list + .iter() + .find(|&(s, _)| s.local_addr().unwrap() == sock.0) + { + Some((sock_srv, _)) => Some(sock_srv.clone()), + None => None, + } + } +} + +/// Strong typedef for a client socket address +#[derive(Debug, Copy, Clone)] +pub(crate) struct UdpClient(pub SocketAddr); +/// Strong typedef for a server socket address +#[derive(Debug, Copy, Clone)] +pub(crate) struct UdpServer(pub SocketAddr); diff --git a/src/inner/mod.rs b/src/inner/mod.rs new file mode 100644 index 0000000..22cafc8 --- /dev/null +++ b/src/inner/mod.rs @@ -0,0 +1,213 @@ +//! Inner Fenrir tracking +//! This is meant to be **async-free** so that others might use it +//! without the tokio runtime + +use crate::{ + connection::{ + self, + handshake::{self, Handshake, HandshakeClient, HandshakeServer}, + Connection, + }, + enc::{ + self, asym, + hkdf::HkdfSha3, + sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, + }, + Error, +}; +use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; +use ::std::{sync::Arc, vec::Vec}; + +/// Information needed to reply after the key exchange +#[derive(Debug, Clone)] +pub struct AuthNeededInfo { + /// Parsed handshake + pub handshake: Handshake, + /// hkdf generated from the handshake + pub hkdf: HkdfSha3, + /// cipher to be used in both directions + pub cipher: CipherKind, +} + +/// Client information needed to fully establish the conenction +#[derive(Debug)] +pub struct ClientConnectInfo { + /// Parsed handshake + pub handshake: Handshake, + /// hkdf generated from the handshake + pub hkdf: HkdfSha3, + /// cipher to be used in both directions + pub cipher_recv: CipherRecv, +} +/// Intermediate actions to be taken while parsing the handshake +#[derive(Debug)] +pub enum HandshakeAction { + /// Parsing finished, all ok, nothing to do + None, + /// Packet parsed, now go perform authentication + AuthNeeded(AuthNeededInfo), + /// the client can fully establish a connection with this info + ClientConnect(ClientConnectInfo), +} + +/// Async free but thread safe tracking of handhsakes and conenctions +pub struct Tracker { + key_exchanges: ArcSwapAny>>, + ciphers: ArcSwapAny>>, + /// ephemeral keys used server side in key exchange + keys_srv: ArcSwapAny>>, + /// ephemeral keys used client side in key exchange + hshake_cli: ArcSwapAny>>, +} +#[allow(unsafe_code)] +unsafe impl Send for Tracker {} +#[allow(unsafe_code)] +unsafe impl Sync for Tracker {} + +impl Tracker { + pub fn new() -> Self { + Self { + ciphers: ArcSwapAny::new(Arc::new(Vec::new())), + key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), + keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), + hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), + } + } + pub(crate) fn recv_handshake( + &self, + mut handshake: Handshake, + handshake_raw: &mut [u8], + ) -> Result { + use connection::handshake::{ + dirsync::{self, DirSync}, + HandshakeData, + }; + match handshake.data { + HandshakeData::DirSync(ref mut ds) => match ds { + DirSync::Req(ref mut req) => { + let ephemeral_key = { + // Keep this block short to avoid contention + // on self.keys_srv + let keys = self.keys_srv.load(); + if let Some(h_k) = + keys.iter().find(|k| k.id == req.key_id) + { + use enc::asym::PrivKey; + // Directory synchronized can only use keys + // for key exchange, not signing keys + if let PrivKey::Exchange(k) = &h_k.key { + Some(k.clone()) + } else { + None + } + } else { + None + } + }; + if ephemeral_key.is_none() { + ::tracing::debug!( + "No such server key id: {:?}", + req.key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + let ephemeral_key = ephemeral_key.unwrap(); + { + let exchanges = self.key_exchanges.load(); + if None + == exchanges.iter().find(|&x| { + *x == (ephemeral_key.kind(), req.exchange) + }) + { + return Err( + enc::Error::UnsupportedKeyExchange.into() + ); + } + } + { + let ciphers = self.ciphers.load(); + if None == ciphers.iter().find(|&x| *x == req.cipher) { + return Err(enc::Error::UnsupportedCipher.into()); + } + } + let shared_key = match ephemeral_key + .key_exchange(req.exchange, req.exchange_key) + { + Ok(shared_key) => shared_key, + Err(e) => return Err(handshake::Error::Key(e).into()), + }; + let hkdf = HkdfSha3::new(b"fenrir", shared_key); + let secret_recv = hkdf.get_secret(b"to_server"); + let cipher_recv = CipherRecv::new(req.cipher, secret_recv); + use crate::enc::sym::AAD; + let aad = AAD(&mut []); // no aad for now + match cipher_recv.decrypt( + aad, + &mut handshake_raw[req.encrypted_offset()..], + ) { + Ok(cleartext) => { + req.data.deserialize_as_cleartext(cleartext) + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + + let cipher = req.cipher; + + return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { + handshake, + hkdf, + cipher, + })); + } + DirSync::Resp(resp) => { + let hshake = { + // Keep this block short to avoid contention + // on self.hshake_cli + let hshake_cli_lock = self.hshake_cli.load(); + match hshake_cli_lock + .iter() + .find(|h| h.id == resp.client_key_id) + { + Some(h) => Some(h.clone()), + None => None, + } + }; + if hshake.is_none() { + ::tracing::debug!( + "No such client key id: {:?}", + resp.client_key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + let hshake = hshake.unwrap(); + let secret_recv = hshake.hkdf.get_secret(b"to_client"); + let cipher_recv = + CipherRecv::new(hshake.cipher, secret_recv); + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + let mut raw_data = &mut handshake_raw[resp + .encrypted_offset() + ..(resp.encrypted_offset() + resp.encrypted_length())]; + match cipher_recv.decrypt(aad, &mut raw_data) { + Ok(cleartext) => { + resp.data.deserialize_as_cleartext(&cleartext) + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + return Ok(HandshakeAction::ClientConnect( + ClientConnectInfo { + handshake, + hkdf: hshake.hkdf, + cipher_recv, + }, + )); + } + }, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 0c75750..6972d8b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,23 +18,33 @@ mod config; pub mod connection; pub mod dnssec; pub mod enc; +mod inner; use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{net::SocketAddr, pin::Pin, sync::Arc, vec, vec::Vec}; +use ::std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + vec::{self, Vec}, +}; use ::tokio::{ macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle, }; -use crate::enc::{ - asym, - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, +use crate::{ + connection::{ + handshake::{self, Handshake, HandshakeClient, HandshakeServer}, + socket::{SocketList, UdpClient, UdpServer}, + ConnList, Connection, + }, + enc::{ + asym, + hkdf::HkdfSha3, + sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, + }, + inner::HandshakeAction, }; pub use config::Config; -use connection::{ - handshake::{self, Handshake, HandshakeClient, HandshakeServer}, - Connection, -}; /// Main fenrir library errors #[derive(::thiserror::Error, Debug)] @@ -56,194 +66,6 @@ pub enum Error { Key(#[from] crate::enc::Error), } -/// Information needed to reply after the key exchange -#[derive(Debug, Clone)] -pub struct AuthNeededInfo { - /// Parsed handshake - pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: HkdfSha3, - /// cipher to be used in both directions - pub cipher: CipherKind, -} - -/// Client information needed to fully establish the conenction -#[derive(Debug)] -pub struct ClientConnectInfo { - /// Parsed handshake - pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: HkdfSha3, - /// cipher to be used in both directions - pub cipher_recv: CipherRecv, -} - -/// Intermediate actions to be taken while parsing the handshake -#[derive(Debug)] -pub enum HandshakeAction { - /// Parsing finished, all ok, nothing to do - None, - /// Packet parsed, now go perform authentication - AuthNeeded(AuthNeededInfo), - /// the client can fully establish a connection with this info - ClientConnect(ClientConnectInfo), -} -// No async here -struct FenrirInner { - key_exchanges: ArcSwapAny>>, - ciphers: ArcSwapAny>>, - /// ephemeral keys used server side in key exchange - keys_srv: ArcSwapAny>>, - /// ephemeral keys used client side in key exchange - hshake_cli: ArcSwapAny>>, -} - -#[allow(unsafe_code)] -unsafe impl Send for FenrirInner {} -#[allow(unsafe_code)] -unsafe impl Sync for FenrirInner {} - -// No async here -impl FenrirInner { - fn recv_handshake( - &self, - mut handshake: Handshake, - handshake_raw: &mut [u8], - ) -> Result { - use connection::handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; - match handshake.data { - HandshakeData::DirSync(ref mut ds) => match ds { - DirSync::Req(ref mut req) => { - let ephemeral_key = { - // Keep this block short to avoid contention - // on self.keys_srv - let keys = self.keys_srv.load(); - if let Some(h_k) = - keys.iter().find(|k| k.id == req.key_id) - { - use enc::asym::PrivKey; - // Directory synchronized can only use keys - // for key exchange, not signing keys - if let PrivKey::Exchange(k) = &h_k.key { - Some(k.clone()) - } else { - None - } - } else { - None - } - }; - if ephemeral_key.is_none() { - ::tracing::debug!( - "No such server key id: {:?}", - req.key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - let ephemeral_key = ephemeral_key.unwrap(); - { - let exchanges = self.key_exchanges.load(); - if None - == exchanges.iter().find(|&x| { - *x == (ephemeral_key.kind(), req.exchange) - }) - { - return Err( - enc::Error::UnsupportedKeyExchange.into() - ); - } - } - { - let ciphers = self.ciphers.load(); - if None == ciphers.iter().find(|&x| *x == req.cipher) { - return Err(enc::Error::UnsupportedCipher.into()); - } - } - let shared_key = match ephemeral_key - .key_exchange(req.exchange, req.exchange_key) - { - Ok(shared_key) => shared_key, - Err(e) => return Err(handshake::Error::Key(e).into()), - }; - let hkdf = HkdfSha3::new(b"fenrir", shared_key); - let secret_recv = hkdf.get_secret(b"to_server"); - let cipher_recv = CipherRecv::new(req.cipher, secret_recv); - use crate::enc::sym::AAD; - let aad = AAD(&mut []); // no aad for now - match cipher_recv.decrypt( - aad, - &mut handshake_raw[req.encrypted_offset()..], - ) { - Ok(cleartext) => { - req.data.deserialize_as_cleartext(cleartext) - } - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - - let cipher = req.cipher; - - return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { - handshake, - hkdf, - cipher, - })); - } - DirSync::Resp(resp) => { - let hshake = { - // Keep this block short to avoid contention - // on self.hshake_cli - let hshake_cli_lock = self.hshake_cli.load(); - match hshake_cli_lock - .iter() - .find(|h| h.id == resp.client_key_id) - { - Some(h) => Some(h.clone()), - None => None, - } - }; - if hshake.is_none() { - ::tracing::debug!( - "No such client key id: {:?}", - resp.client_key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - let hshake = hshake.unwrap(); - let secret_recv = hshake.hkdf.get_secret(b"to_client"); - let cipher_recv = - CipherRecv::new(hshake.cipher, secret_recv); - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - let mut raw_data = &mut handshake_raw[resp - .encrypted_offset() - ..(resp.encrypted_offset() + resp.encrypted_length())]; - match cipher_recv.decrypt(aad, &mut raw_data) { - Ok(cleartext) => { - resp.data.deserialize_as_cleartext(&cleartext) - } - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - return Ok(HandshakeAction::ClientConnect( - ClientConnectInfo { - handshake, - hkdf: hshake.hkdf, - cipher_recv, - }, - )); - } - }, - } - } -} - type TokenChecker = fn( user: auth::UserID, @@ -252,99 +74,7 @@ type TokenChecker = domain: auth::Domain, ) -> ::futures::future::BoxFuture<'static, Result>; -// PERF: move to arcswap: -// We use multiple UDP sockets. -// We need a list of them all and to scan this list. -// But we want to be able to update this list without locking everyone -// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and -// `from_raw_fd` to update the list. -// This means that we will have multiple `UdpSocket` per actual socket -// so we have to handle `drop()` manually, and garbage-collect the ones we -// are no longer using in the background. sigh. -// Just go with a ArcSwapAny, Arc>>); -struct SocketList { - list: ArcSwap>, -} -impl SocketList { - fn new() -> Self { - Self { - list: ArcSwap::new(Arc::new(Vec::new())), - } - } - // TODO: fn rm_socket() - fn rm_all(&self) -> Self { - let new_list = Arc::new(Vec::new()); - let old_list = self.list.swap(new_list); - Self { - list: old_list.into(), - } - } - async fn add_socket( - &self, - socket: Arc, - handle: JoinHandle<::std::io::Result<()>>, - ) { - // we could simplify this into just a `.swap` instead of `.rcu` but - // it is not yet guaranteed that only one thread will call this fn - // ...we don't need performance here anyway - let arc_handle = Arc::new(handle); - self.list.rcu(|old_list| { - let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); - new_list = old_list.to_vec().into(); - Arc::get_mut(&mut new_list) - .unwrap() - .push((socket.clone(), arc_handle.clone())); - new_list - }); - } - /// This method assumes no other `add_sockets` are being run - async fn stop_all(mut self) { - let mut arc_list = self.list.into_inner(); - let list = loop { - match Arc::try_unwrap(arc_list) { - Ok(list) => break list, - Err(arc_retry) => { - arc_list = arc_retry; - ::tokio::time::sleep(::core::time::Duration::from_millis( - 50, - )) - .await; - } - } - }; - for (_socket, mut handle) in list.into_iter() { - Arc::get_mut(&mut handle).unwrap().await; - } - } - fn lock(&self) -> SocketListRef { - SocketListRef { - list: self.list.load_full(), - } - } -} -// TODO: impl Drop for SocketList -struct SocketListRef { - list: Arc>, -} -impl SocketListRef { - fn find(&self, sock: UdpServer) -> Option> { - match self - .list - .iter() - .find(|&(s, _)| s.local_addr().unwrap() == sock.0) - { - Some((sock_srv, _)) => Some(sock_srv.clone()), - None => None, - } - } -} - -#[derive(Debug, Copy, Clone)] -struct UdpClient(SocketAddr); -#[derive(Debug, Copy, Clone)] -struct UdpServer(SocketAddr); - +/// Track a raw Udp packet struct RawUdp { data: Vec, src: UdpClient, @@ -355,62 +85,6 @@ enum Work { Recv(RawUdp), } -// PERF: Arc> loks a bit too much, need to find -// faster ways to do this -struct ConnList { - connections: Vec>>, - /// Bitmap to track which connection ids are used or free - ids_used: Vec<::bitmaps::Bitmap<1024>>, -} - -impl ConnList { - fn new() -> Self { - let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); - bitmap_id.set(0, true); // ID(0) == handshake - Self { - connections: Vec::with_capacity(128), - ids_used: vec![bitmap_id], - } - } - fn reserve_first(&mut self, mut conn: Connection) -> Arc { - // uhm... bad things are going on here: - // * id must be initialized, but only because: - // * rust does not understand that after the `!found` id is always - // initialized - // * `ID::new_u64` is really safe only with >0, but here it always is - // ...we should probably rewrite it in better, safer rust - let mut id: u64 = 0; - let mut found = false; - for (i, b) in self.ids_used.iter_mut().enumerate() { - match b.first_false_index() { - Some(idx) => { - b.set(idx, true); - id = ((i as u64) * 1024) + (idx as u64); - found = true; - break; - } - None => {} - } - } - if !found { - let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); - new_bitmap.set(0, true); - id = (self.ids_used.len() as u64) * 1024; - self.ids_used.push(new_bitmap); - } - let new_id = connection::ID::new_u64(id); - conn.id = new_id; - let conn = Arc::new(conn); - if (self.connections.len() as u64) < id { - self.connections.push(Some(conn.clone())); - } else { - // very probably redundant - self.connections[id as usize] = Some(conn.clone()); - } - conn - } -} - /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { @@ -423,7 +97,7 @@ pub struct Fenrir { /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, /// Private keys used in the handshake - _inner: Arc, + _inner: Arc, /// where to ask for token check token_check: Arc>, /// MPMC work queue. sender @@ -455,12 +129,7 @@ impl Fenrir { sockets: SocketList::new(), dnssec: None, stop_working: sender, - _inner: Arc::new(FenrirInner { - ciphers: ArcSwapAny::new(Arc::new(Vec::new())), - key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), - keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), - hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), - }), + _inner: Arc::new(inner::Tracker::new()), token_check: Arc::new(ArcSwapOption::from(None)), work_send: Arc::new(work_send), work_recv: Arc::new(work_recv), From 28cbe2ae20b6b5ce6a1ede8c45df4b8922292277 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Mon, 22 May 2023 15:05:17 +0200 Subject: [PATCH 04/34] more refactoring Signed-off-by: Luca Fulchir --- src/connection/handshake/mod.rs | 7 +- src/connection/mod.rs | 20 ++- src/connection/socket.rs | 42 +++++ src/inner/mod.rs | 17 +- src/lib.rs | 299 ++++++++++++++------------------ 5 files changed, 200 insertions(+), 185 deletions(-) diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 74d3be1..99eab75 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -2,12 +2,12 @@ pub mod dirsync; -use ::num_traits::FromPrimitive; - use crate::{ connection::{self, ProtocolVersion}, enc::sym::{HeadLen, TagLen}, }; +use ::num_traits::FromPrimitive; +use ::std::sync::Arc; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -36,8 +36,7 @@ pub(crate) struct HandshakeServer { pub(crate) struct HandshakeClient { pub id: crate::enc::asym::KeyID, pub key: crate::enc::asym::PrivKey, - pub hkdf: crate::enc::hkdf::HkdfSha3, - pub cipher: crate::enc::sym::CipherKind, + pub connection: Arc, } /// Parsed handshake diff --git a/src/connection/mod.rs b/src/connection/mod.rs index b247cbc..3c74c94 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -16,6 +16,13 @@ use crate::enc::{ sym::{CipherKind, CipherRecv, CipherSend}, }; +/// strong typedef for receiving connection id +#[derive(Debug, Copy, Clone)] +pub struct IDRecv(pub ID); +/// strong typedef for sending connection id +#[derive(Debug, Copy, Clone)] +pub struct IDSend(pub ID); + /// Version of the fenrir protocol in use #[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)] #[repr(u8)] @@ -37,8 +44,10 @@ impl ProtocolVersion { /// A single connection and its data #[derive(Debug)] pub struct Connection { - /// Connection ID - pub id: ID, + /// Receiving Connection ID + pub id_recv: IDRecv, + /// Sending Connection ID + pub id_send: IDSend, /// The main hkdf used for all secrets in this connection pub hkdf: HkdfSha3, /// Cipher for decrypting data @@ -78,7 +87,8 @@ impl Connection { let mut cipher_send = CipherSend::new(cipher, secret_send, rand); Self { - id: ID::Handshake, + id_recv: IDRecv(ID::Handshake), + id_send: IDSend(ID::Handshake), hkdf, cipher_recv, cipher_send, @@ -132,8 +142,8 @@ impl ConnList { id = (self.ids_used.len() as u64) * 1024; self.ids_used.push(new_bitmap); } - let new_id = ID::new_u64(id); - conn.id = new_id; + let new_id = IDRecv(ID::new_u64(id)); + conn.id_recv = new_id; let conn = Arc::new(conn); if (self.connections.len() as u64) < id { self.connections.push(Some(conn.clone())); diff --git a/src/connection/socket.rs b/src/connection/socket.rs index 455c567..b95e7c9 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -106,3 +106,45 @@ pub(crate) struct UdpClient(pub SocketAddr); /// Strong typedef for a server socket address #[derive(Debug, Copy, Clone)] pub(crate) struct UdpServer(pub SocketAddr); + +/// Enable some common socket options. This is just the unsafe part +fn enable_sock_opt( + fd: ::std::os::fd::RawFd, + option: ::libc::c_int, + value: ::libc::c_int, +) -> ::std::io::Result<()> { + #[allow(unsafe_code)] + unsafe { + #[allow(trivial_casts)] + let val = &value as *const _ as *const ::libc::c_void; + let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; + // always clear the error bit before doing a new syscall + let _ = ::std::io::Error::last_os_error(); + let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); + if ret != 0 { + return Err(::std::io::Error::last_os_error()); + } + } + Ok(()) +} +/// Add an async udp listener +pub async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { + let socket = UdpSocket::bind(sock).await?; + + use ::std::os::fd::AsRawFd; + let fd = socket.as_raw_fd(); + // can be useful later on for reloads + enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; + enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; + + // We will do path MTU discovery by ourselves, + // always set the "don't fragment" bit + if sock.is_ipv6() { + enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; + } else { + // FIXME: linux only + enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)?; + } + + Ok(socket) +} diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 22cafc8..6007868 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -21,7 +21,7 @@ use ::std::{sync::Arc, vec::Vec}; /// Information needed to reply after the key exchange #[derive(Debug, Clone)] pub struct AuthNeededInfo { - /// Parsed handshake + /// Parsed handshake packet pub handshake: Handshake, /// hkdf generated from the handshake pub hkdf: HkdfSha3, @@ -32,12 +32,10 @@ pub struct AuthNeededInfo { /// Client information needed to fully establish the conenction #[derive(Debug)] pub struct ClientConnectInfo { - /// Parsed handshake + /// Parsed handshake packet pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: HkdfSha3, - /// cipher to be used in both directions - pub cipher_recv: CipherRecv, + /// Connection + pub connection: Arc, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] @@ -182,9 +180,7 @@ impl Tracker { return Err(handshake::Error::UnknownKeyID.into()); } let hshake = hshake.unwrap(); - let secret_recv = hshake.hkdf.get_secret(b"to_client"); - let cipher_recv = - CipherRecv::new(hshake.cipher, secret_recv); + let cipher_recv = &hshake.connection.cipher_recv; use crate::enc::sym::AAD; // no aad for now let aad = AAD(&mut []); @@ -202,8 +198,7 @@ impl Tracker { return Ok(HandshakeAction::ClientConnect( ClientConnectInfo { handshake, - hkdf: hshake.hkdf, - cipher_recv, + connection: hshake.connection, }, )); } diff --git a/src/lib.rs b/src/lib.rs index 6972d8b..a4faa07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,9 +33,12 @@ use ::tokio::{ use crate::{ connection::{ - handshake::{self, Handshake, HandshakeClient, HandshakeServer}, + handshake::{ + self, dirsync::DirSync, Handshake, HandshakeClient, HandshakeData, + HandshakeServer, + }, socket::{SocketList, UdpClient, UdpServer}, - ConnList, Connection, + ConnList, Connection, IDSend, }, enc::{ asym, @@ -167,34 +170,13 @@ impl Fenrir { self.dnssec = None; } - /// Enable some common socket options. This is just the unsafe part - fn enable_sock_opt( - fd: ::std::os::fd::RawFd, - option: ::libc::c_int, - value: ::libc::c_int, - ) -> ::std::io::Result<()> { - #[allow(unsafe_code)] - unsafe { - #[allow(trivial_casts)] - let val = &value as *const _ as *const ::libc::c_void; - let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; - // always clear the error bit before doing a new syscall - let _ = ::std::io::Error::last_os_error(); - let ret = - ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); - if ret != 0 { - return Err(::std::io::Error::last_os_error()); - } - } - Ok(()) - } - /// Add all UDP sockets found in config /// and start listening for packets async fn add_sockets(&self) -> ::std::io::Result<()> { let sockets = self.cfg.listen.iter().map(|s_addr| async { let socket = - ::tokio::spawn(Self::bind_udp(s_addr.clone())).await??; + ::tokio::spawn(connection::socket::bind_udp(s_addr.clone())) + .await??; Ok(socket) }); let sockets = ::futures::future::join_all(sockets).await; @@ -218,32 +200,6 @@ impl Fenrir { Ok(()) } - /// Add an async udp listener - async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { - let socket = UdpSocket::bind(sock).await?; - - use ::std::os::fd::AsRawFd; - let fd = socket.as_raw_fd(); - // can be useful later on for reloads - Self::enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; - Self::enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; - - // We will do path MTU discovery by ourselves, - // always set the "don't fragment" bit - if sock.is_ipv6() { - Self::enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; - } else { - // FIXME: linux only - Self::enable_sock_opt( - fd, - ::libc::IP_MTU_DISCOVER, - ::libc::IP_PMTUDISC_DO, - )?; - } - - Ok(socket) - } - /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( mut stop_working: ::tokio::sync::broadcast::Receiver, @@ -350,121 +306,134 @@ impl Fenrir { dirsync::{self, DirSync}, HandshakeData, }; - match authinfo.handshake.data { - HandshakeData::DirSync(ds) => match ds { - DirSync::Req(req) => { - use dirsync::ReqInner; - let req_data = match req.data { - ReqInner::ClearText(req_data) => req_data, - _ => { - ::tracing::error!( - "token_check: expected ClearText" - ); - return; - } - }; - let is_authenticated = match tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await - { - Ok(is_authenticated) => is_authenticated, - Err(_) => { - ::tracing::error!( - "error in token auth" - ); - // TODO: retry? - return; - } - }; - if !is_authenticated { - ::tracing::warn!( - "Wrong authentication for user {:?}", - req_data.auth.user - ); - // TODO: error response - return; - } - // Client has correctly authenticated - // TODO: contact the service, get the key and - // connection ID - let srv_conn_id = - connection::ID::new_rand(&self.rand); - let srv_secret = - enc::sym::Secret::new_rand(&self.rand); - let head_len = req.cipher.nonce_len(); - let tag_len = req.cipher.tag_len(); - - let raw_conn = Connection::new( - authinfo.hkdf, - req.cipher, - connection::Role::Server, - &self.rand, - ); - // track connection - let auth_conn = { - let mut lock = - self.connections.write().await; - lock.reserve_first(raw_conn) - }; - - let resp_data = dirsync::RespData { - client_nonce: req_data.nonce, - id: auth_conn.id, - service_id: srv_conn_id, - service_key: srv_secret, - }; - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - - use dirsync::RespInner; - let resp = dirsync::Resp { - client_key_id: req_data.client_key_id, - data: RespInner::ClearText(resp_data), - }; - let offset_to_encrypt = resp.encrypted_offset(); - let encrypt_until = offset_to_encrypt - + resp.encrypted_length() - + tag_len.0; - let resp_handshake = Handshake::new( - HandshakeData::DirSync(DirSync::Resp(resp)), - ); - use connection::{Packet, PacketData, ID}; - let packet = Packet { - id: ID::new_handshake(), - data: PacketData::Handshake(resp_handshake), - }; - let mut raw_out = - Vec::::with_capacity(packet.len()); - packet.serialize( - head_len, - tag_len, - &mut raw_out, - ); - - if let Err(e) = auth_conn.cipher_send.encrypt( - aad, - &mut raw_out - [offset_to_encrypt..encrypt_until], - ) { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } - self.send_packet(raw_out, udp.src, udp.dst) - .await; - } - DirSync::Resp(resp) => { - todo!() - } - _ => { - todo!() - } - }, + let req; + if let HandshakeData::DirSync(DirSync::Req(r)) = + authinfo.handshake.data + { + req = r; + } else { + ::tracing::error!("AuthInfo on non DS::Req"); + return; } + use dirsync::ReqInner; + let req_data = match req.data { + ReqInner::ClearText(req_data) => req_data, + _ => { + ::tracing::error!( + "token_check: expected ClearText" + ); + return; + } + }; + let is_authenticated = match tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await + { + Ok(is_authenticated) => is_authenticated, + Err(_) => { + ::tracing::error!("error in token auth"); + // TODO: retry? + return; + } + }; + if !is_authenticated { + ::tracing::warn!( + "Wrong authentication for user {:?}", + req_data.auth.user + ); + // TODO: error response + return; + } + // Client has correctly authenticated + // TODO: contact the service, get the key and + // connection ID + let srv_conn_id = connection::ID::new_rand(&self.rand); + let srv_secret = enc::sym::Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); + + let mut raw_conn = Connection::new( + authinfo.hkdf, + req.cipher, + connection::Role::Server, + &self.rand, + ); + raw_conn.id_send = IDSend(req_data.id); + // track connection + let auth_conn = { + let mut lock = self.connections.write().await; + lock.reserve_first(raw_conn) + }; + + let resp_data = dirsync::RespData { + client_nonce: req_data.nonce, + id: auth_conn.id_recv.0, + service_id: srv_conn_id, + service_key: srv_secret, + }; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + + use dirsync::RespInner; + let resp = dirsync::Resp { + client_key_id: req_data.client_key_id, + data: RespInner::ClearText(resp_data), + }; + let offset_to_encrypt = resp.encrypted_offset(); + let encrypt_until = + offset_to_encrypt + resp.encrypted_length() + tag_len.0; + let resp_handshake = Handshake::new( + HandshakeData::DirSync(DirSync::Resp(resp)), + ); + use connection::{Packet, PacketData, ID}; + let packet = Packet { + id: ID::new_handshake(), + data: PacketData::Handshake(resp_handshake), + }; + let mut raw_out = Vec::::with_capacity(packet.len()); + packet.serialize(head_len, tag_len, &mut raw_out); + + if let Err(e) = auth_conn.cipher_send.encrypt( + aad, + &mut raw_out[offset_to_encrypt..encrypt_until], + ) { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + self.send_packet(raw_out, udp.src, udp.dst).await; + return; + } + HandshakeAction::ClientConnect(mut cci) => { + let ds_resp; + if let HandshakeData::DirSync(DirSync::Resp(resp)) = + cci.handshake.data + { + ds_resp = resp; + } else { + ::tracing::error!("ClientConnect on non DS::Resp"); + return; + } + // track connection + use handshake::dirsync; + let resp_data; + if let dirsync::RespInner::ClearText(r_data) = ds_resp.data + { + resp_data = r_data; + } else { + ::tracing::error!( + "ClientConnect on non DS::Resp::ClearText" + ); + return; + } + // FIXME: conn tracking and arc counting + let conn = Arc::get_mut(&mut cci.connection).unwrap(); + conn.id_send = IDSend(resp_data.id); + todo!(); } _ => {} }; From c0d6cf182452660c82292c52cb3579af28d0a8fd Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Tue, 23 May 2023 18:20:08 +0200 Subject: [PATCH 05/34] Per-thread work loop This will let us have a lot less locking. We can do better in the future with ebpf and pinning connection to a specific CPU with multiple listen() points on the same address, but good enough for now Signed-off-by: Luca Fulchir --- Cargo.toml | 1 + src/connection/mod.rs | 2 +- src/connection/packet.rs | 20 +++++ src/lib.rs | 159 ++++++++++++++++++++++++++++++++------- 4 files changed, 152 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 906053a..8142aa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ bitmaps = { version = "3.2" } chacha20poly1305 = { version = "0.10" } futures = { version = "0.3" } hkdf = { version = "0.12" } +hwloc2 = {version = "2.2" } libc = { version = "0.2" } num-traits = { version = "0.2" } num-derive = { version = "0.3" } diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 3c74c94..fe75884 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -1,7 +1,7 @@ //! Connection handling and send/receive queues pub mod handshake; -mod packet; +pub mod packet; pub mod socket; use ::std::{sync::Arc, vec::Vec}; diff --git a/src/connection/packet.rs b/src/connection/packet.rs index 9e5eed3..400f900 100644 --- a/src/connection/packet.rs +++ b/src/connection/packet.rs @@ -91,6 +91,8 @@ impl From<[u8; 8]> for ConnectionID { pub enum PacketData { /// A parsed handshake packet Handshake(super::Handshake), + /// Raw packet. we only have the connection ID and packet length + Raw(usize), } impl PacketData { @@ -98,6 +100,7 @@ impl PacketData { pub fn len(&self) -> usize { match self { PacketData::Handshake(h) => h.len(), + PacketData::Raw(len) => *len } } /// serialize data into bytes @@ -111,10 +114,15 @@ impl PacketData { assert!(self.len() == out.len(), "PacketData: wrong buffer length"); match self { PacketData::Handshake(h) => h.serialize(head_len, tag_len, out), + PacketData::Raw(_) => { + ::tracing::error!("Tried to serialize a raw PacketData!"); + } } } } +const MIN_PACKET_BYTES: usize = 16; + /// Fenrir packet structure #[derive(Debug, Clone)] pub struct Packet { @@ -125,6 +133,18 @@ pub struct Packet { } impl Packet { + /// New recevied packet, yet unparsed + pub fn deserialize_id(raw: &[u8]) -> Result { + // TODO: proper min_packet length. 16 is too conservative. + if raw.len() < MIN_PACKET_BYTES { + return Err(()); + } + let raw_id: [u8; 8] = (raw[..8]).try_into().expect("unreachable"); + Ok(Packet { + id: raw_id.into(), + data: PacketData::Raw(raw.len()), + }) + } /// get the total length of the packet pub fn len(&self) -> usize { ConnectionID::len() + self.data.len() diff --git a/src/lib.rs b/src/lib.rs index a4faa07..94e67e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::std::{ net::SocketAddr, pin::Pin, - sync::Arc, + sync::{Arc, Weak}, vec::{self, Vec}, }; use ::tokio::{ @@ -38,7 +38,7 @@ use crate::{ HandshakeServer, }, socket::{SocketList, UdpClient, UdpServer}, - ConnList, Connection, IDSend, + ConnList, Connection, IDSend, Packet, }, enc::{ asym, @@ -79,9 +79,10 @@ type TokenChecker = /// Track a raw Udp packet struct RawUdp { - data: Vec, src: UdpClient, dst: UdpServer, + data: Vec, + packet: Packet, } enum Work { @@ -103,14 +104,15 @@ pub struct Fenrir { _inner: Arc, /// where to ask for token check token_check: Arc>, - /// MPMC work queue. sender - work_send: Arc<::async_channel::Sender>, - /// MPMC work queue. receiver - work_recv: Arc<::async_channel::Receiver>, // PERF: rand uses syscalls. should we do that async? rand: ::ring::rand::SystemRandom, /// list of Established connections connections: Arc>, + _myself: Weak, + // TODO: find a way to both increase and decrease these two in a thread-safe + // manner + _thread_pool: Vec<::std::thread::JoinHandle<()>>, + _thread_work: Arc>>, } // TODO: graceful vs immediate stop @@ -123,22 +125,23 @@ impl Drop for Fenrir { impl Fenrir { /// Create a new Fenrir endpoint - pub fn new(config: &Config) -> Result { + pub fn new(config: &Config) -> Result, Error> { let listen_num = config.listen.len(); let (sender, _) = ::tokio::sync::broadcast::channel(1); let (work_send, work_recv) = ::async_channel::unbounded::(); - let endpoint = Fenrir { + let endpoint = Arc::new_cyclic(|myself| Fenrir { cfg: config.clone(), sockets: SocketList::new(), dnssec: None, stop_working: sender, _inner: Arc::new(inner::Tracker::new()), token_check: Arc::new(ArcSwapOption::from(None)), - work_send: Arc::new(work_send), - work_recv: Arc::new(work_recv), rand: ::ring::rand::SystemRandom::new(), connections: Arc::new(RwLock::new(ConnList::new())), - }; + _myself: myself.clone(), + _thread_pool: Vec::new(), + _thread_work: Arc::new(Vec::new()), + }); Ok(endpoint) } @@ -169,7 +172,6 @@ impl Fenrir { toempty_sockets.stop_all().await; self.dnssec = None; } - /// Add all UDP sockets found in config /// and start listening for packets async fn add_sockets(&self) -> ::std::io::Result<()> { @@ -187,7 +189,7 @@ impl Fenrir { let arc_s = Arc::new(s); let join = ::tokio::spawn(Self::listen_udp( stop_working, - self.work_send.clone(), + self._thread_work.clone(), arc_s.clone(), )); self.sockets.add_socket(arc_s, join); @@ -203,12 +205,13 @@ impl Fenrir { /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( mut stop_working: ::tokio::sync::broadcast::Receiver, - work_queue: Arc<::async_channel::Sender>, + work_queues: Arc>>, socket: Arc, ) -> ::std::io::Result<()> { // jumbo frames are 9K max let sock_receiver = UdpServer(socket.local_addr()?); let mut buffer: [u8; 9000] = [0; 9000]; + let queues_num = work_queues.len() as u64; loop { let (bytes, sock_sender) = ::tokio::select! { _done = stop_working.recv() => { @@ -219,10 +222,33 @@ impl Fenrir { } }; let data: Vec = buffer[..bytes].to_vec(); - work_queue.send(Work::Recv(RawUdp { - data, + + // we very likely have multiple threads, pinned to different cpus. + // use the ConnectionID to send the same connection + // to the same thread. + // Handshakes have conenction ID 0, so we use the sender's UDP port + + let packet = match Packet::deserialize_id(&data) { + Ok(packet) => packet, + Err(_) => continue, // packet way too short, ignore. + }; + let thread_idx: usize = { + use connection::packet::ConnectionID; + match packet.id { + ConnectionID::Handshake => { + let send_port = sock_sender.port() as u64; + ((send_port % queues_num) - 1) as usize + } + ConnectionID::ID(id) => { + ((id.get() % queues_num) - 1) as usize + } + } + }; + work_queues[thread_idx].send(Work::Recv(RawUdp { src: UdpClient(sock_sender), dst: sock_receiver, + packet, + data, })); } Ok(()) @@ -242,15 +268,97 @@ impl Fenrir { Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } - /// Loop continuously and parse packets and other work - pub async fn work_loop(&self) { + /// Start one working thread for each physical cpu + /// threads are pinned to each cpu core. + /// Work will be divided and rerouted so that there is no need to lock + pub async fn start_work_threads_pinned( + &mut self, + tokio_rt: Arc<::tokio::runtime::Runtime>, + ) -> ::std::result::Result<(), ()> { + use ::std::sync::Mutex; + let hw_topology = match ::hwloc2::Topology::new() { + Some(hw_topology) => Arc::new(Mutex::new(hw_topology)), + None => return Err(()), + }; + let cores; + { + let topology_lock = hw_topology.lock().unwrap(); + let all_cores = match topology_lock + .objects_with_type(&::hwloc2::ObjectType::Core) + { + Ok(all_cores) => all_cores, + Err(_) => return Err(()), + }; + cores = all_cores.len(); + if cores <= 0 || !topology_lock.support().cpu().set_thread() { + ::tracing::error!("No support for CPU pinning"); + return Err(()); + } + } + for core in 0..cores { + ::tracing::debug!("Spawning thread {}", core); + let th_topology = hw_topology.clone(); + let th_tokio_rt = tokio_rt.clone(); + let th_myself = self._myself.upgrade().unwrap(); + let (work_send, work_recv) = ::async_channel::unbounded::(); + let join_handle = ::std::thread::spawn(move || { + // bind to a specific core + let th_pinning; + { + let mut th_topology_lock = th_topology.lock().unwrap(); + let th_cores = th_topology_lock + .objects_with_type(&::hwloc2::ObjectType::Core) + .unwrap(); + let cpuset = th_cores.get(core).unwrap().cpuset().unwrap(); + th_pinning = th_topology_lock.set_cpubind( + cpuset, + ::hwloc2::CpuBindFlags::CPUBIND_THREAD, + ); + } + match th_pinning { + Ok(_) => {} + Err(_) => { + ::tracing::error!("Can't bind thread to cpu"); + return; + } + } + // finally run the main listener. make sure things stay on this + // thread + let tk_local = ::tokio::task::LocalSet::new(); + let _ = tk_local.block_on( + &th_tokio_rt, + Self::work_loop_thread(th_myself, work_recv), + ); + }); + loop { + let queues_lock = match Arc::get_mut(&mut self._thread_work) { + Some(queues_lock) => queues_lock, + None => { + ::tokio::time::sleep( + ::std::time::Duration::from_millis(50), + ) + .await; + continue; + } + }; + queues_lock.push(work_send); + break; + } + self._thread_pool.push(join_handle); + } + Ok(()) + } + async fn work_loop_thread( + self: Arc, + work_recv: ::async_channel::Receiver, + ) { let mut stop_working = self.stop_working.subscribe(); loop { let work = ::tokio::select! { _done = stop_working.recv() => { break; } - maybe_work = self.work_recv.recv() => { + maybe_work = work_recv.recv() => { match maybe_work { Ok(work) => work, Err(_) => break, @@ -265,16 +373,9 @@ impl Fenrir { } } - const MIN_PACKET_BYTES: usize = 8; /// Read and do stuff with the raw udp packet async fn recv(&self, mut udp: RawUdp) { - if udp.data.len() < Self::MIN_PACKET_BYTES { - return; - } - use connection::ID; - let raw_id: [u8; 8] = (udp.data[..8]).try_into().expect("unreachable"); - if ID::from(raw_id).is_handshake() { - use connection::handshake::Handshake; + if udp.packet.id.is_handshake() { let handshake = match Handshake::deserialize(&udp.data[8..]) { Ok(handshake) => handshake, Err(e) => { @@ -390,7 +491,7 @@ impl Fenrir { let resp_handshake = Handshake::new( HandshakeData::DirSync(DirSync::Resp(resp)), ); - use connection::{Packet, PacketData, ID}; + use connection::{PacketData, ID}; let packet = Packet { id: ID::new_handshake(), data: PacketData::Handshake(resp_handshake), From 9b33ed882877995e605ed13bb393bc7a0605a1b6 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 24 May 2023 15:45:37 +0200 Subject: [PATCH 06/34] Refactor, more pinned-thread work Signed-off-by: Luca Fulchir --- src/auth/mod.rs | 12 +- src/connection/socket.rs | 13 +- src/inner/mod.rs | 11 +- src/inner/worker.rs | 307 +++++++++++++++++++++++++++++++++++ src/lib.rs | 335 +++++++-------------------------------- 5 files changed, 385 insertions(+), 293 deletions(-) create mode 100644 src/inner/worker.rs diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 17f43b9..52b51e0 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,4 +1,4 @@ -//! Authentication reslated struct definitions +//! Authentication related struct definitions use ::ring::rand::SecureRandom; use ::zeroize::Zeroize; @@ -53,6 +53,16 @@ impl ::core::fmt::Debug for Token { } } +/// Type of the function used to check the validity of the tokens +/// Reimplement this to use whatever database you want +pub type TokenChecker = + fn( + user: UserID, + token: Token, + service_id: ServiceID, + domain: Domain, + ) -> ::futures::future::BoxFuture<'static, Result>; + /// domain representation /// Security notice: internal representation is utf8, but we will /// further limit to a "safe" subset of utf8 diff --git a/src/connection/socket.rs b/src/connection/socket.rs index b95e7c9..1cc570d 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -8,17 +8,8 @@ use ::std::{ }; use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; -// PERF: move to arcswap: -// We use multiple UDP sockets. -// We need a list of them all and to scan this list. -// But we want to be able to update this list without locking everyone -// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and -// `from_raw_fd` to update the list. -// This means that we will have multiple `UdpSocket` per actual socket -// so we have to handle `drop()` manually, and garbage-collect the ones we -// are no longer using in the background. sigh. -// Just go with a ArcSwapAny, Arc>>); +/// Pair to easily track the socket and its async listening handle +pub type SocketTracker = (Arc, Arc>>); /// async free socket list pub(crate) struct SocketList { diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 6007868..73490f8 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -2,7 +2,10 @@ //! This is meant to be **async-free** so that others might use it //! without the tokio runtime +pub(crate) mod worker; + use crate::{ + auth, connection::{ self, handshake::{self, Handshake, HandshakeClient, HandshakeServer}, @@ -49,7 +52,7 @@ pub enum HandshakeAction { } /// Async free but thread safe tracking of handhsakes and conenctions -pub struct Tracker { +pub struct HandshakeTracker { key_exchanges: ArcSwapAny>>, ciphers: ArcSwapAny>>, /// ephemeral keys used server side in key exchange @@ -58,11 +61,11 @@ pub struct Tracker { hshake_cli: ArcSwapAny>>, } #[allow(unsafe_code)] -unsafe impl Send for Tracker {} +unsafe impl Send for HandshakeTracker {} #[allow(unsafe_code)] -unsafe impl Sync for Tracker {} +unsafe impl Sync for HandshakeTracker {} -impl Tracker { +impl HandshakeTracker { pub fn new() -> Self { Self { ciphers: ArcSwapAny::new(Arc::new(Vec::new())), diff --git a/src/inner/worker.rs b/src/inner/worker.rs new file mode 100644 index 0000000..32e2148 --- /dev/null +++ b/src/inner/worker.rs @@ -0,0 +1,307 @@ +//! Worker thread implementation +use crate::{ + auth::TokenChecker, + connection::{ + self, + handshake::{ + self, + dirsync::{self, DirSync}, + Handshake, HandshakeClient, HandshakeData, + }, + socket::{UdpClient, UdpServer}, + ConnList, Connection, IDSend, Packet, ID, + }, + enc::sym::Secret, + inner::{HandshakeAction, HandshakeTracker}, +}; +use ::std::{sync::Arc, vec::Vec}; +/// This worker must be cpu-pinned +use ::tokio::{net::UdpSocket, sync::Mutex}; +use std::net::SocketAddr; + +/// Track a raw Udp packet +pub(crate) struct RawUdp { + pub src: UdpClient, + pub dst: UdpServer, + pub data: Vec, + pub packet: Packet, +} + +pub(crate) enum Work { + Recv(RawUdp), +} + +/// Actual worker implementation. +pub(crate) struct Worker { + // PERF: rand uses syscalls. how to do that async? + rand: ::ring::rand::SystemRandom, + stop_working: ::tokio::sync::broadcast::Receiver, + token_check: Option>>, + sockets: Vec, + queue: ::async_channel::Receiver, + thread_channels: Vec<::async_channel::Sender>, + connections: ConnList, + handshakes: HandshakeTracker, +} + +impl Worker { + pub(crate) async fn new( + stop_working: ::tokio::sync::broadcast::Receiver, + token_check: Option>>, + socket_addrs: Vec<::std::net::SocketAddr>, + queue: ::async_channel::Receiver, + ) -> ::std::io::Result { + // bind all sockets again so that we can easily + // send without sharing resources + // in the future we will want to have a thread-local listener too, + // but before that we need ebpf to pin a connection to a thread + // directly from the kernel + let socket_binding = + socket_addrs.into_iter().map(|s_addr| async move { + let socket = ::tokio::spawn(connection::socket::bind_udp( + s_addr.clone(), + )) + .await??; + Ok(socket) + }); + let sockets_bind_res = + ::futures::future::join_all(socket_binding).await; + let sockets: Result, ::std::io::Error> = + sockets_bind_res + .into_iter() + .map(|s_res| match s_res { + Ok(s) => Ok(s), + Err(e) => { + ::tracing::error!("Worker can't bind on socket: {}", e); + Err(e) + } + }) + .collect(); + let sockets = match sockets { + Ok(sockets) => sockets, + Err(e) => { + return Err(e); + } + }; + + Ok(Self { + rand: ::ring::rand::SystemRandom::new(), + stop_working, + token_check, + sockets, + queue, + thread_channels: Vec::new(), + connections: ConnList::new(), + handshakes: HandshakeTracker::new(), + }) + } + pub(crate) async fn work_loop(&mut self) { + loop { + let work = ::tokio::select! { + _done = self.stop_working.recv() => { + break; + } + maybe_work = self.queue.recv() => { + match maybe_work { + Ok(work) => work, + Err(_) => break, + } + } + }; + match work { + //TODO: reconf message to add channels + Work::Recv(pkt) => { + self.recv(pkt).await; + } + } + } + } + /// Read and do stuff with the raw udp packet + async fn recv(&mut self, mut udp: RawUdp) { + if udp.packet.id.is_handshake() { + let handshake = match Handshake::deserialize(&udp.data[8..]) { + Ok(handshake) => handshake, + Err(e) => { + ::tracing::warn!("Handshake parsing: {}", e); + return; + } + }; + let action = match self + .handshakes + .recv_handshake(handshake, &mut udp.data[8..]) + { + Ok(action) => action, + Err(err) => { + ::tracing::debug!("Handshake recv error {}", err); + return; + } + }; + match action { + HandshakeAction::AuthNeeded(authinfo) => { + let token_check = match self.token_check.as_ref() { + Some(token_check) => token_check, + None => { + ::tracing::error!( + "Authentication requested but \ + we have no token checker" + ); + return; + } + }; + let req; + if let HandshakeData::DirSync(DirSync::Req(r)) = + authinfo.handshake.data + { + req = r; + } else { + ::tracing::error!("AuthInfo on non DS::Req"); + return; + } + use dirsync::ReqInner; + let req_data = match req.data { + ReqInner::ClearText(req_data) => req_data, + _ => { + ::tracing::error!( + "token_check: expected ClearText" + ); + return; + } + }; + let is_authenticated = { + let tk_check = token_check.lock().await; + tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await + }; + let is_authenticated = match is_authenticated { + Ok(is_authenticated) => is_authenticated, + Err(_) => { + ::tracing::error!("error in token auth"); + // TODO: retry? + return; + } + }; + if !is_authenticated { + ::tracing::warn!( + "Wrong authentication for user {:?}", + req_data.auth.user + ); + // TODO: error response + return; + } + // Client has correctly authenticated + // TODO: contact the service, get the key and + // connection ID + let srv_conn_id = ID::new_rand(&self.rand); + let srv_secret = Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); + + let mut raw_conn = Connection::new( + authinfo.hkdf, + req.cipher, + connection::Role::Server, + &self.rand, + ); + raw_conn.id_send = IDSend(req_data.id); + // track connection + let auth_conn = self.connections.reserve_first(raw_conn); + + let resp_data = dirsync::RespData { + client_nonce: req_data.nonce, + id: auth_conn.id_recv.0, + service_id: srv_conn_id, + service_key: srv_secret, + }; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + + use dirsync::RespInner; + let resp = dirsync::Resp { + client_key_id: req_data.client_key_id, + data: RespInner::ClearText(resp_data), + }; + let offset_to_encrypt = resp.encrypted_offset(); + let encrypt_until = + offset_to_encrypt + resp.encrypted_length() + tag_len.0; + let resp_handshake = Handshake::new( + HandshakeData::DirSync(DirSync::Resp(resp)), + ); + use connection::{PacketData, ID}; + let packet = Packet { + id: ID::new_handshake(), + data: PacketData::Handshake(resp_handshake), + }; + let mut raw_out = Vec::::with_capacity(packet.len()); + packet.serialize(head_len, tag_len, &mut raw_out); + + if let Err(e) = auth_conn.cipher_send.encrypt( + aad, + &mut raw_out[offset_to_encrypt..encrypt_until], + ) { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + self.send_packet(raw_out, udp.src, udp.dst).await; + return; + } + HandshakeAction::ClientConnect(mut cci) => { + let ds_resp; + if let HandshakeData::DirSync(DirSync::Resp(resp)) = + cci.handshake.data + { + ds_resp = resp; + } else { + ::tracing::error!("ClientConnect on non DS::Resp"); + return; + } + // track connection + use handshake::dirsync; + let resp_data; + if let dirsync::RespInner::ClearText(r_data) = ds_resp.data + { + resp_data = r_data; + } else { + ::tracing::error!( + "ClientConnect on non DS::Resp::ClearText" + ); + return; + } + // FIXME: conn tracking and arc counting + let conn = Arc::get_mut(&mut cci.connection).unwrap(); + conn.id_send = IDSend(resp_data.id); + todo!(); + } + _ => {} + }; + } + // copy packet, spawn + todo!(); + } + async fn send_packet( + &self, + data: Vec, + client: UdpClient, + server: UdpServer, + ) { + let src_sock = match self + .sockets + .iter() + .find(|&s| s.local_addr().unwrap() == server.0) + { + Some(src_sock) => src_sock, + None => { + ::tracing::error!( + "Can't send packet: Server changed listening ip!" + ); + return; + } + }; + src_sock.send_to(&data, client.0); + } +} diff --git a/src/lib.rs b/src/lib.rs index 94e67e7..5c8a5d7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,32 +20,21 @@ pub mod dnssec; pub mod enc; mod inner; -use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; use ::std::{ net::SocketAddr, - pin::Pin, sync::{Arc, Weak}, - vec::{self, Vec}, -}; -use ::tokio::{ - macros::support::Future, net::UdpSocket, sync::RwLock, task::JoinHandle, + vec::Vec, }; +use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; use crate::{ + auth::TokenChecker, connection::{ - handshake::{ - self, dirsync::DirSync, Handshake, HandshakeClient, HandshakeData, - HandshakeServer, - }, + handshake, socket::{SocketList, UdpClient, UdpServer}, - ConnList, Connection, IDSend, Packet, + Packet, }, - enc::{ - asym, - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, - }, - inner::HandshakeAction, + inner::worker::{RawUdp, Work, Worker}, }; pub use config::Config; @@ -55,6 +44,9 @@ pub enum Error { /// The library was not initialized (run .start()) #[error("not initialized")] NotInitialized, + /// Error in setting up worker threads + #[error("Setup err: {0}")] + Setup(String), /// General I/O error #[error("IO: {0:?}")] IO(#[from] ::std::io::Error), @@ -69,26 +61,6 @@ pub enum Error { Key(#[from] crate::enc::Error), } -type TokenChecker = - fn( - user: auth::UserID, - token: auth::Token, - service_id: auth::ServiceID, - domain: auth::Domain, - ) -> ::futures::future::BoxFuture<'static, Result>; - -/// Track a raw Udp packet -struct RawUdp { - src: UdpClient, - dst: UdpServer, - data: Vec, - packet: Packet, -} - -enum Work { - Recv(RawUdp), -} - /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { @@ -101,14 +73,11 @@ pub struct Fenrir { /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, /// Private keys used in the handshake - _inner: Arc, + _inner: Arc, /// where to ask for token check - token_check: Arc>, + token_check: Option>>, // PERF: rand uses syscalls. should we do that async? rand: ::ring::rand::SystemRandom, - /// list of Established connections - connections: Arc>, - _myself: Weak, // TODO: find a way to both increase and decrease these two in a thread-safe // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, @@ -125,28 +94,30 @@ impl Drop for Fenrir { impl Fenrir { /// Create a new Fenrir endpoint - pub fn new(config: &Config) -> Result, Error> { + pub fn new(config: &Config) -> Result { let listen_num = config.listen.len(); let (sender, _) = ::tokio::sync::broadcast::channel(1); let (work_send, work_recv) = ::async_channel::unbounded::(); - let endpoint = Arc::new_cyclic(|myself| Fenrir { + let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), dnssec: None, stop_working: sender, - _inner: Arc::new(inner::Tracker::new()), - token_check: Arc::new(ArcSwapOption::from(None)), + _inner: Arc::new(inner::HandshakeTracker::new()), + token_check: None, rand: ::ring::rand::SystemRandom::new(), - connections: Arc::new(RwLock::new(ConnList::new())), - _myself: myself.clone(), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), - }); + }; Ok(endpoint) } /// Start all workers, listeners - pub async fn start(&mut self) -> Result<(), Error> { + pub async fn start( + &mut self, + tokio_rt: Arc<::tokio::runtime::Runtime>, + ) -> Result<(), Error> { + self.start_work_threads_pinned(tokio_rt).await?; if let Err(e) = self.add_sockets().await { self.stop().await; return Err(e.into()); @@ -159,17 +130,25 @@ impl Fenrir { /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); + // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); + let mut old_thread_pool = Vec::new(); + ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); + old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); + // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; + let mut old_thread_pool = Vec::new(); + ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); + old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Add all UDP sockets found in config @@ -271,14 +250,18 @@ impl Fenrir { /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. /// Work will be divided and rerouted so that there is no need to lock - pub async fn start_work_threads_pinned( + async fn start_work_threads_pinned( &mut self, tokio_rt: Arc<::tokio::runtime::Runtime>, - ) -> ::std::result::Result<(), ()> { + ) -> ::std::result::Result<(), Error> { use ::std::sync::Mutex; let hw_topology = match ::hwloc2::Topology::new() { Some(hw_topology) => Arc::new(Mutex::new(hw_topology)), - None => return Err(()), + None => { + return Err(Error::Setup( + "Can't get hardware topology".to_owned(), + )) + } }; let cores; { @@ -287,20 +270,36 @@ impl Fenrir { .objects_with_type(&::hwloc2::ObjectType::Core) { Ok(all_cores) => all_cores, - Err(_) => return Err(()), + Err(_) => { + return Err(Error::Setup("can't list cores".to_owned())) + } }; cores = all_cores.len(); if cores <= 0 || !topology_lock.support().cpu().set_thread() { ::tracing::error!("No support for CPU pinning"); - return Err(()); + return Err(Error::Setup("No cpu pinning support".to_owned())); } } for core in 0..cores { ::tracing::debug!("Spawning thread {}", core); let th_topology = hw_topology.clone(); let th_tokio_rt = tokio_rt.clone(); - let th_myself = self._myself.upgrade().unwrap(); let (work_send, work_recv) = ::async_channel::unbounded::(); + let mut worker = match Worker::new( + self.stop_working.subscribe(), + self.token_check.clone(), + self.cfg.listen.clone(), + work_recv, + ) + .await + { + Ok(worker) => worker, + Err(e) => { + ::tracing::error!("can't start worker"); + return Err(Error::IO(e)); + } + }; + let join_handle = ::std::thread::spawn(move || { // bind to a specific core let th_pinning; @@ -322,13 +321,10 @@ impl Fenrir { return; } } - // finally run the main listener. make sure things stay on this - // thread + // finally run the main worker. + // make sure things stay on this thread let tk_local = ::tokio::task::LocalSet::new(); - let _ = tk_local.block_on( - &th_tokio_rt, - Self::work_loop_thread(th_myself, work_recv), - ); + let _ = tk_local.block_on(&th_tokio_rt, worker.work_loop()); }); loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { @@ -348,219 +344,4 @@ impl Fenrir { } Ok(()) } - async fn work_loop_thread( - self: Arc, - work_recv: ::async_channel::Receiver, - ) { - let mut stop_working = self.stop_working.subscribe(); - loop { - let work = ::tokio::select! { - _done = stop_working.recv() => { - break; - } - maybe_work = work_recv.recv() => { - match maybe_work { - Ok(work) => work, - Err(_) => break, - } - } - }; - match work { - Work::Recv(pkt) => { - self.recv(pkt).await; - } - } - } - } - - /// Read and do stuff with the raw udp packet - async fn recv(&self, mut udp: RawUdp) { - if udp.packet.id.is_handshake() { - let handshake = match Handshake::deserialize(&udp.data[8..]) { - Ok(handshake) => handshake, - Err(e) => { - ::tracing::warn!("Handshake parsing: {}", e); - return; - } - }; - let action = - match self._inner.recv_handshake(handshake, &mut udp.data[8..]) - { - Ok(action) => action, - Err(err) => { - ::tracing::debug!("Handshake recv error {}", err); - return; - } - }; - match action { - HandshakeAction::AuthNeeded(authinfo) => { - let tk_check = match self.token_check.load_full() { - Some(tk_check) => tk_check, - None => { - ::tracing::error!( - "Handshake received, but no tocken_checker" - ); - return; - } - }; - use handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; - let req; - if let HandshakeData::DirSync(DirSync::Req(r)) = - authinfo.handshake.data - { - req = r; - } else { - ::tracing::error!("AuthInfo on non DS::Req"); - return; - } - use dirsync::ReqInner; - let req_data = match req.data { - ReqInner::ClearText(req_data) => req_data, - _ => { - ::tracing::error!( - "token_check: expected ClearText" - ); - return; - } - }; - let is_authenticated = match tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await - { - Ok(is_authenticated) => is_authenticated, - Err(_) => { - ::tracing::error!("error in token auth"); - // TODO: retry? - return; - } - }; - if !is_authenticated { - ::tracing::warn!( - "Wrong authentication for user {:?}", - req_data.auth.user - ); - // TODO: error response - return; - } - // Client has correctly authenticated - // TODO: contact the service, get the key and - // connection ID - let srv_conn_id = connection::ID::new_rand(&self.rand); - let srv_secret = enc::sym::Secret::new_rand(&self.rand); - let head_len = req.cipher.nonce_len(); - let tag_len = req.cipher.tag_len(); - - let mut raw_conn = Connection::new( - authinfo.hkdf, - req.cipher, - connection::Role::Server, - &self.rand, - ); - raw_conn.id_send = IDSend(req_data.id); - // track connection - let auth_conn = { - let mut lock = self.connections.write().await; - lock.reserve_first(raw_conn) - }; - - let resp_data = dirsync::RespData { - client_nonce: req_data.nonce, - id: auth_conn.id_recv.0, - service_id: srv_conn_id, - service_key: srv_secret, - }; - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - - use dirsync::RespInner; - let resp = dirsync::Resp { - client_key_id: req_data.client_key_id, - data: RespInner::ClearText(resp_data), - }; - let offset_to_encrypt = resp.encrypted_offset(); - let encrypt_until = - offset_to_encrypt + resp.encrypted_length() + tag_len.0; - let resp_handshake = Handshake::new( - HandshakeData::DirSync(DirSync::Resp(resp)), - ); - use connection::{PacketData, ID}; - let packet = Packet { - id: ID::new_handshake(), - data: PacketData::Handshake(resp_handshake), - }; - let mut raw_out = Vec::::with_capacity(packet.len()); - packet.serialize(head_len, tag_len, &mut raw_out); - - if let Err(e) = auth_conn.cipher_send.encrypt( - aad, - &mut raw_out[offset_to_encrypt..encrypt_until], - ) { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } - self.send_packet(raw_out, udp.src, udp.dst).await; - return; - } - HandshakeAction::ClientConnect(mut cci) => { - let ds_resp; - if let HandshakeData::DirSync(DirSync::Resp(resp)) = - cci.handshake.data - { - ds_resp = resp; - } else { - ::tracing::error!("ClientConnect on non DS::Resp"); - return; - } - // track connection - use handshake::dirsync; - let resp_data; - if let dirsync::RespInner::ClearText(r_data) = ds_resp.data - { - resp_data = r_data; - } else { - ::tracing::error!( - "ClientConnect on non DS::Resp::ClearText" - ); - return; - } - // FIXME: conn tracking and arc counting - let conn = Arc::get_mut(&mut cci.connection).unwrap(); - conn.id_send = IDSend(resp_data.id); - todo!(); - } - _ => {} - }; - } - // copy packet, spawn - todo!(); - } - async fn send_packet( - &self, - data: Vec, - client: UdpClient, - server: UdpServer, - ) { - let src_sock; - { - let sockets = self.sockets.lock(); - src_sock = match sockets.find(server) { - Some(src_sock) => src_sock, - None => { - ::tracing::error!( - "Can't send packet: Server changed listening ip!" - ); - return; - } - }; - } - src_sock.send_to(&data, client.0); - } } From 810cc16ce6b016cd3e1ef063d2f05183dfd76055 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 24 May 2023 17:30:15 +0200 Subject: [PATCH 07/34] More thread-pinning work. No more Arc, Rc is better on the same thread. Track the thread number so we can generate the correct connection IDs Signed-off-by: Luca Fulchir --- src/connection/handshake/mod.rs | 4 +-- src/connection/mod.rs | 35 ++++++++++-------- src/inner/mod.rs | 64 +++++++++++++++++++-------------- src/inner/worker.rs | 34 +++++++++++++++--- src/lib.rs | 43 +++++++++++----------- 5 files changed, 109 insertions(+), 71 deletions(-) diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 99eab75..a231bce 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -7,7 +7,7 @@ use crate::{ enc::sym::{HeadLen, TagLen}, }; use ::num_traits::FromPrimitive; -use ::std::sync::Arc; +use ::std::{rc::Rc, sync::Arc}; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -36,7 +36,7 @@ pub(crate) struct HandshakeServer { pub(crate) struct HandshakeClient { pub id: crate::enc::asym::KeyID, pub key: crate::enc::asym::PrivKey, - pub connection: Arc, + pub connection: Rc, } /// Parsed handshake diff --git a/src/connection/mod.rs b/src/connection/mod.rs index fe75884..b79d910 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -4,16 +4,19 @@ pub mod handshake; pub mod packet; pub mod socket; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{rc::Rc, sync::Arc, vec::Vec}; pub use crate::connection::{ handshake::Handshake, packet::{ConnectionID as ID, Packet, PacketData}, }; -use crate::enc::{ - hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend}, +use crate::{ + enc::{ + hkdf::HkdfSha3, + sym::{CipherKind, CipherRecv, CipherSend}, + }, + inner::ThreadTracker, }; /// strong typedef for receiving connection id @@ -99,16 +102,18 @@ impl Connection { // PERF: Arc> loks a bit too much, need to find // faster ways to do this pub(crate) struct ConnList { - connections: Vec>>, + thread_id: ThreadTracker, + connections: Vec>>, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } impl ConnList { - pub(crate) fn new() -> Self { + pub(crate) fn new(thread_id: ThreadTracker) -> Self { let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); bitmap_id.set(0, true); // ID(0) == handshake Self { + thread_id, connections: Vec::with_capacity(128), ids_used: vec![bitmap_id], } @@ -116,20 +121,20 @@ impl ConnList { pub(crate) fn reserve_first( &mut self, mut conn: Connection, - ) -> Arc { + ) -> Rc { // uhm... bad things are going on here: // * id must be initialized, but only because: // * rust does not understand that after the `!found` id is always // initialized // * `ID::new_u64` is really safe only with >0, but here it always is // ...we should probably rewrite it in better, safer rust - let mut id: u64 = 0; + let mut id_in_thread: u64 = 0; let mut found = false; for (i, b) in self.ids_used.iter_mut().enumerate() { match b.first_false_index() { Some(idx) => { b.set(idx, true); - id = ((i as u64) * 1024) + (idx as u64); + id_in_thread = ((i as u64) * 1024) + (idx as u64); found = true; break; } @@ -139,17 +144,19 @@ impl ConnList { if !found { let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); new_bitmap.set(0, true); - id = (self.ids_used.len() as u64) * 1024; + id_in_thread = (self.ids_used.len() as u64) * 1024; self.ids_used.push(new_bitmap); } - let new_id = IDRecv(ID::new_u64(id)); + let actual_id = (id_in_thread * (self.thread_id.total as u64)) + + (self.thread_id.id as u64); + let new_id = IDRecv(ID::new_u64(actual_id)); conn.id_recv = new_id; - let conn = Arc::new(conn); - if (self.connections.len() as u64) < id { + let conn = Rc::new(conn); + if (self.connections.len() as u64) < id_in_thread { self.connections.push(Some(conn.clone())); } else { // very probably redundant - self.connections[id as usize] = Some(conn.clone()); + self.connections[id_in_thread as usize] = Some(conn.clone()); } conn } diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 73490f8..c2569ec 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -19,11 +19,11 @@ use crate::{ Error, }; use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{rc::Rc, sync::Arc, vec::Vec}; /// Information needed to reply after the key exchange #[derive(Debug, Clone)] -pub struct AuthNeededInfo { +pub(crate) struct AuthNeededInfo { /// Parsed handshake packet pub handshake: Handshake, /// hkdf generated from the handshake @@ -34,15 +34,15 @@ pub struct AuthNeededInfo { /// Client information needed to fully establish the conenction #[derive(Debug)] -pub struct ClientConnectInfo { +pub(crate) struct ClientConnectInfo { /// Parsed handshake packet pub handshake: Handshake, /// Connection - pub connection: Arc, + pub connection: Rc, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] -pub enum HandshakeAction { +pub(crate) enum HandshakeAction { /// Parsing finished, all ok, nothing to do None, /// Packet parsed, now go perform authentication @@ -51,14 +51,28 @@ pub enum HandshakeAction { ClientConnect(ClientConnectInfo), } +/// Track the total number of threads and our index +/// 65K cpus should be enough for anybody +#[derive(Debug, Clone, Copy)] +pub(crate) struct ThreadTracker { + pub total: u16, + /// Note: starts from 1 + pub id: u16, +} + /// Async free but thread safe tracking of handhsakes and conenctions -pub struct HandshakeTracker { - key_exchanges: ArcSwapAny>>, - ciphers: ArcSwapAny>>, +/// Note that we have multiple Handshake trackers, pinned to different cores +/// Each of them will handle a subset of all handshakes. +/// Each handshake is routed to a different tracker with: +/// (udp_src_sender_port % total_threads) - 1 +pub(crate) struct HandshakeTracker { + thread_id: ThreadTracker, + key_exchanges: Vec<(asym::Key, asym::KeyExchange)>, + ciphers: Vec, /// ephemeral keys used server side in key exchange - keys_srv: ArcSwapAny>>, + keys_srv: Vec, /// ephemeral keys used client side in key exchange - hshake_cli: ArcSwapAny>>, + hshake_cli: Vec, } #[allow(unsafe_code)] unsafe impl Send for HandshakeTracker {} @@ -66,12 +80,13 @@ unsafe impl Send for HandshakeTracker {} unsafe impl Sync for HandshakeTracker {} impl HandshakeTracker { - pub fn new() -> Self { + pub(crate) fn new(thread_id: ThreadTracker) -> Self { Self { - ciphers: ArcSwapAny::new(Arc::new(Vec::new())), - key_exchanges: ArcSwapAny::new(Arc::new(Vec::new())), - keys_srv: ArcSwapAny::new(Arc::new(Vec::new())), - hshake_cli: ArcSwapAny::new(Arc::new(Vec::new())), + thread_id, + ciphers: Vec::new(), + key_exchanges: Vec::new(), + keys_srv: Vec::new(), + hshake_cli: Vec::new(), } } pub(crate) fn recv_handshake( @@ -87,11 +102,8 @@ impl HandshakeTracker { HandshakeData::DirSync(ref mut ds) => match ds { DirSync::Req(ref mut req) => { let ephemeral_key = { - // Keep this block short to avoid contention - // on self.keys_srv - let keys = self.keys_srv.load(); if let Some(h_k) = - keys.iter().find(|k| k.id == req.key_id) + self.keys_srv.iter().find(|k| k.id == req.key_id) { use enc::asym::PrivKey; // Directory synchronized can only use keys @@ -114,9 +126,8 @@ impl HandshakeTracker { } let ephemeral_key = ephemeral_key.unwrap(); { - let exchanges = self.key_exchanges.load(); if None - == exchanges.iter().find(|&x| { + == self.key_exchanges.iter().find(|&x| { *x == (ephemeral_key.kind(), req.exchange) }) { @@ -126,8 +137,9 @@ impl HandshakeTracker { } } { - let ciphers = self.ciphers.load(); - if None == ciphers.iter().find(|&x| *x == req.cipher) { + if None + == self.ciphers.iter().find(|&x| *x == req.cipher) + { return Err(enc::Error::UnsupportedCipher.into()); } } @@ -164,10 +176,8 @@ impl HandshakeTracker { } DirSync::Resp(resp) => { let hshake = { - // Keep this block short to avoid contention - // on self.hshake_cli - let hshake_cli_lock = self.hshake_cli.load(); - match hshake_cli_lock + match self + .hshake_cli .iter() .find(|h| h.id == resp.client_key_id) { diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 32e2148..7e3bd2e 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -12,9 +12,9 @@ use crate::{ ConnList, Connection, IDSend, Packet, ID, }, enc::sym::Secret, - inner::{HandshakeAction, HandshakeTracker}, + inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; -use ::std::{sync::Arc, vec::Vec}; +use ::std::{rc::Rc, sync::Arc, vec::Vec}; /// This worker must be cpu-pinned use ::tokio::{net::UdpSocket, sync::Mutex}; use std::net::SocketAddr; @@ -33,6 +33,7 @@ pub(crate) enum Work { /// Actual worker implementation. pub(crate) struct Worker { + thread_id: ThreadTracker, // PERF: rand uses syscalls. how to do that async? rand: ::ring::rand::SystemRandom, stop_working: ::tokio::sync::broadcast::Receiver, @@ -45,7 +46,27 @@ pub(crate) struct Worker { } impl Worker { + pub(crate) async fn new_and_loop( + thread_id: ThreadTracker, + stop_working: ::tokio::sync::broadcast::Receiver, + token_check: Option>>, + socket_addrs: Vec<::std::net::SocketAddr>, + queue: ::async_channel::Receiver, + ) -> ::std::io::Result<()> { + // TODO: get a channel to send back information, and send the error + let mut worker = Self::new( + thread_id, + stop_working, + token_check, + socket_addrs, + queue, + ) + .await?; + worker.work_loop().await; + Ok(()) + } pub(crate) async fn new( + thread_id: ThreadTracker, stop_working: ::tokio::sync::broadcast::Receiver, token_check: Option>>, socket_addrs: Vec<::std::net::SocketAddr>, @@ -85,14 +106,15 @@ impl Worker { }; Ok(Self { + thread_id, rand: ::ring::rand::SystemRandom::new(), stop_working, token_check, sockets, queue, thread_channels: Vec::new(), - connections: ConnList::new(), - handshakes: HandshakeTracker::new(), + connections: ConnList::new(thread_id), + handshakes: HandshakeTracker::new(thread_id), }) } pub(crate) async fn work_loop(&mut self) { @@ -167,6 +189,8 @@ impl Worker { return; } }; + // FIXME: This part can take a while, + // we should just spawn it probably let is_authenticated = { let tk_check = token_check.lock().await; tk_check( @@ -273,7 +297,7 @@ impl Worker { return; } // FIXME: conn tracking and arc counting - let conn = Arc::get_mut(&mut cci.connection).unwrap(); + let conn = Rc::get_mut(&mut cci.connection).unwrap(); conn.id_send = IDSend(resp_data.id); todo!(); } diff --git a/src/lib.rs b/src/lib.rs index 5c8a5d7..02df264 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,10 @@ use crate::{ socket::{SocketList, UdpClient, UdpServer}, Packet, }, - inner::worker::{RawUdp, Work, Worker}, + inner::{ + worker::{RawUdp, Work, Worker}, + ThreadTracker, + }, }; pub use config::Config; @@ -72,12 +75,8 @@ pub struct Fenrir { dnssec: Option, /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, - /// Private keys used in the handshake - _inner: Arc, /// where to ask for token check token_check: Option>>, - // PERF: rand uses syscalls. should we do that async? - rand: ::ring::rand::SystemRandom, // TODO: find a way to both increase and decrease these two in a thread-safe // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, @@ -103,9 +102,7 @@ impl Fenrir { sockets: SocketList::new(), dnssec: None, stop_working: sender, - _inner: Arc::new(inner::HandshakeTracker::new()), token_check: None, - rand: ::ring::rand::SystemRandom::new(), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), }; @@ -130,7 +127,6 @@ impl Fenrir { /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); - // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); @@ -143,7 +139,6 @@ impl Fenrir { /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); - // FIXME: wait for thread pool to actually stop let mut toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; let mut old_thread_pool = Vec::new(); @@ -285,19 +280,12 @@ impl Fenrir { let th_topology = hw_topology.clone(); let th_tokio_rt = tokio_rt.clone(); let (work_send, work_recv) = ::async_channel::unbounded::(); - let mut worker = match Worker::new( - self.stop_working.subscribe(), - self.token_check.clone(), - self.cfg.listen.clone(), - work_recv, - ) - .await - { - Ok(worker) => worker, - Err(e) => { - ::tracing::error!("can't start worker"); - return Err(Error::IO(e)); - } + let th_stop_working = self.stop_working.subscribe(); + let th_token_check = self.token_check.clone(); + let th_socket_addrs = self.cfg.listen.clone(); + let thread_id = ThreadTracker { + total: cores as u16, + id: 1 + (core as u16), }; let join_handle = ::std::thread::spawn(move || { @@ -324,7 +312,16 @@ impl Fenrir { // finally run the main worker. // make sure things stay on this thread let tk_local = ::tokio::task::LocalSet::new(); - let _ = tk_local.block_on(&th_tokio_rt, worker.work_loop()); + let _ = tk_local.block_on( + &th_tokio_rt, + Worker::new_and_loop( + thread_id, + th_stop_working, + th_token_check, + th_socket_addrs, + work_recv, + ), + ); }); loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { From 428754069593e4269d23fb5e17c6a4c84f621160 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 24 May 2023 22:32:41 +0200 Subject: [PATCH 08/34] Upgrade flakes to 23.05 Signed-off-by: Luca Fulchir --- flake.lock | 20 ++++++++++---------- flake.nix | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/flake.lock b/flake.lock index ce2e37c..7a5770d 100644 --- a/flake.lock +++ b/flake.lock @@ -38,27 +38,27 @@ }, "nixpkgs": { "locked": { - "lastModified": 1683478192, - "narHash": "sha256-7f7RR71w0jRABDgBwjq3vE1yY3nrVJyXk8hDzu5kl1E=", + "lastModified": 1684922889, + "narHash": "sha256-l0WZAmln8959O7RdYUJ3gnAIM9OPKFLKHKGX4q+Blrk=", "owner": "nixos", "repo": "nixpkgs", - "rev": "c568239bcc990050b7aedadb7387832440ad8fb1", + "rev": "04aaf8511678a0d0f347fdf1e8072fe01e4a509e", "type": "github" }, "original": { "owner": "nixos", - "ref": "nixos-22.11", + "ref": "nixos-23.05", "repo": "nixpkgs", "type": "github" } }, "nixpkgs-unstable": { "locked": { - "lastModified": 1683408522, - "narHash": "sha256-9kcPh6Uxo17a3kK3XCHhcWiV1Yu1kYj22RHiymUhMkU=", + "lastModified": 1684844536, + "narHash": "sha256-M7HhXYVqAuNb25r/d3FOO0z4GxPqDIZp5UjHFbBgw0Q=", "owner": "nixos", "repo": "nixpkgs", - "rev": "897876e4c484f1e8f92009fd11b7d988a121a4e7", + "rev": "d30264c2691128adc261d7c9388033645f0e742b", "type": "github" }, "original": { @@ -98,11 +98,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1683512408, - "narHash": "sha256-QMJGp/37En+d5YocJuSU89GL14bBYkIJQ6mqhRfqkkc=", + "lastModified": 1684894917, + "narHash": "sha256-kwKCfmliHIxKuIjnM95TRcQxM/4AAEIZ+4A9nDJ6cJs=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "75b07756c3feb22cf230e75fb064c1b4c725b9bc", + "rev": "9ea38d547100edcf0da19aaebbdffa2810585495", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 37e97ab..4c4a18a 100644 --- a/flake.nix +++ b/flake.nix @@ -2,7 +2,7 @@ description = "libFenrir"; inputs = { - nixpkgs.url = "github:nixos/nixpkgs/nixos-22.11"; + nixpkgs.url = "github:nixos/nixpkgs/nixos-23.05"; nixpkgs-unstable.url = "github:nixos/nixpkgs/nixos-unstable"; rust-overlay.url = "github:oxalica/rust-overlay"; flake-utils.url = "github:numtide/flake-utils"; @@ -46,7 +46,7 @@ shellHook = '' # use zsh or other custom shell USER_SHELL="$(grep $USER /etc/passwd | cut -d ':' -f 7)" - if [ -n "$USER_SHELL" ] && [ "$USER_SHELL" != "$SHELL" ]; then + if [ -n "$USER_SHELL" ]; then exec $USER_SHELL fi ''; From e71167224c56543ff23d23c54d559859a6a8fefd Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 26 May 2023 15:02:21 +0200 Subject: [PATCH 09/34] Track auth and service connections client side Signed-off-by: Luca Fulchir --- src/auth/mod.rs | 4 ++ src/connection/handshake/dirsync.rs | 4 +- src/connection/handshake/mod.rs | 2 + src/connection/mod.rs | 70 +++++++++++++++++++++-------- src/enc/sym.rs | 15 ++++++- src/inner/mod.rs | 40 ++++++++++++----- src/inner/worker.rs | 44 +++++++++++++++--- 7 files changed, 139 insertions(+), 40 deletions(-) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 52b51e0..2d8dc65 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -103,4 +103,8 @@ impl ServiceID { pub const fn len() -> usize { 16 } + /// read the service id as bytes + pub fn as_bytes(&self) -> &[u8; 16] { + &self.0 + } } diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 721ab9c..72d6da9 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -425,7 +425,7 @@ pub struct RespData { /// Server Connection ID pub id: ID, /// Service Connection ID - pub service_id: ID, + pub service_connection_id: ID, /// Service encryption key pub service_key: Secret, } @@ -448,7 +448,7 @@ impl RespData { self.id.serialize(&mut out[start..end]); start = end; end = end + Self::NONCE_LEN; - self.service_id.serialize(&mut out[start..end]); + self.service_connection_id.serialize(&mut out[start..end]); start = end; end = end + Self::NONCE_LEN; out[start..end].copy_from_slice(self.service_key.as_ref()); diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index a231bce..63de947 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -36,6 +36,8 @@ pub(crate) struct HandshakeServer { pub(crate) struct HandshakeClient { pub id: crate::enc::asym::KeyID, pub key: crate::enc::asym::PrivKey, + pub service_id: crate::auth::ServiceID, + pub service_conn_id: connection::IDRecv, pub connection: Rc, } diff --git a/src/connection/mod.rs b/src/connection/mod.rs index b79d910..2c85af7 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -20,10 +20,10 @@ use crate::{ }; /// strong typedef for receiving connection id -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct IDRecv(pub ID); /// strong typedef for sending connection id -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct IDSend(pub ID); /// Version of the fenrir protocol in use @@ -86,8 +86,8 @@ impl Connection { (hkdf.get_secret(b"to_client"), hkdf.get_secret(b"to_server")) } }; - let mut cipher_recv = CipherRecv::new(cipher, secret_recv); - let mut cipher_send = CipherSend::new(cipher, secret_send, rand); + let cipher_recv = CipherRecv::new(cipher, secret_recv); + let cipher_send = CipherSend::new(cipher, secret_send, rand); Self { id_recv: IDRecv(ID::Handshake), @@ -111,13 +111,17 @@ pub(crate) struct ConnList { impl ConnList { pub(crate) fn new(thread_id: ThreadTracker) -> Self { let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); - bitmap_id.set(0, true); // ID(0) == handshake - Self { + const INITIAL_CAP: usize = 128; + let mut ret = Self { thread_id, - connections: Vec::with_capacity(128), + connections: Vec::with_capacity(INITIAL_CAP), ids_used: vec![bitmap_id], - } + }; + ret.connections.resize_with(INITIAL_CAP, || None); + ret } + /// Only *Reserve* a connection, + /// without actually tracking it in self.connections pub(crate) fn reserve_first( &mut self, mut conn: Connection, @@ -128,13 +132,13 @@ impl ConnList { // initialized // * `ID::new_u64` is really safe only with >0, but here it always is // ...we should probably rewrite it in better, safer rust - let mut id_in_thread: u64 = 0; + let mut id_in_thread: usize = 0; let mut found = false; for (i, b) in self.ids_used.iter_mut().enumerate() { match b.first_false_index() { Some(idx) => { b.set(idx, true); - id_in_thread = ((i as u64) * 1024) + (idx as u64); + id_in_thread = (i * 1024) + idx; found = true; break; } @@ -144,20 +148,48 @@ impl ConnList { if !found { let mut new_bitmap = ::bitmaps::Bitmap::<1024>::new(); new_bitmap.set(0, true); - id_in_thread = (self.ids_used.len() as u64) * 1024; + id_in_thread = self.ids_used.len() * 1024; self.ids_used.push(new_bitmap); } - let actual_id = (id_in_thread * (self.thread_id.total as u64)) + // make sure we have enough space in self.connections + let curr_capacity = self.connections.capacity(); + if self.connections.capacity() <= id_in_thread { + // Fill with "None", assure 64 connections without reallocations + let multiple = 64 + curr_capacity - 1; + let new_capacity = multiple - (multiple % curr_capacity); + self.connections.resize_with(new_capacity, || None); + } + // calculate the actual connection ID + let actual_id = ((id_in_thread as u64) * (self.thread_id.total as u64)) + (self.thread_id.id as u64); let new_id = IDRecv(ID::new_u64(actual_id)); conn.id_recv = new_id; - let conn = Rc::new(conn); - if (self.connections.len() as u64) < id_in_thread { - self.connections.push(Some(conn.clone())); - } else { - // very probably redundant - self.connections[id_in_thread as usize] = Some(conn.clone()); + // Return the new connection without tracking it + Rc::new(conn) + } + /// NOTE: does NOT check if the connection has been previously reserved! + pub(crate) fn track(&mut self, conn: Rc) -> Result<(), ()> { + let conn_id = match conn.id_recv { + IDRecv(ID::Handshake) => { + return Err(()); + } + IDRecv(ID::ID(conn_id)) => conn_id, + }; + let id_in_thread: usize = + (conn_id.get() / (self.thread_id.total as u64)) as usize; + self.connections[id_in_thread] = Some(conn); + Ok(()) + } + pub(crate) fn delete(&mut self, id: IDRecv) { + if let IDRecv(ID::ID(raw_id)) = id { + let id_in_thread: usize = + (raw_id.get() / (self.thread_id.total as u64)) as usize; + let vec_index = id_in_thread / 1024; + let bitmask_index = id_in_thread % 1024; + if let Some(bitmask) = self.ids_used.get_mut(vec_index) { + bitmask.set(bitmask_index, false); + self.connections[id_in_thread] = None; + } } - conn } } diff --git a/src/enc/sym.rs b/src/enc/sym.rs index e8db1d6..3de267c 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -115,6 +115,11 @@ impl Cipher { } } } + pub fn kind(&self) -> CipherKind { + match self { + Cipher::XChaCha20Poly1305(_) => CipherKind::XChaCha20Poly1305, + } + } fn nonce_len(&self) -> HeadLen { match self { Cipher::XChaCha20Poly1305(_) => { @@ -181,7 +186,7 @@ impl Cipher { aad: AAD, data: &mut [u8], ) -> Result<(), Error> { - // FIXME: check minimum buffer size + // FIXME: check minimum buffer size match self { Cipher::XChaCha20Poly1305(cipher) => { use ::chacha20poly1305::{ @@ -242,6 +247,10 @@ impl CipherRecv { ) -> Result<&'a [u8], Error> { self.0.decrypt(aad, data) } + /// return the underlying cipher id + pub fn kind(&self) -> CipherKind { + self.0.kind() + } } /// Allocate some data, with additional indexes to track @@ -313,6 +322,10 @@ impl CipherSend { self.cipher.encrypt(&old_nonce, aad, data)?; Ok(()) } + /// return the underlying cipher id + pub fn kind(&self) -> CipherKind { + self.cipher.kind() + } } /// XChaCha20Poly1305 cipher diff --git a/src/inner/mod.rs b/src/inner/mod.rs index c2569ec..72447f2 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -35,6 +35,10 @@ pub(crate) struct AuthNeededInfo { /// Client information needed to fully establish the conenction #[derive(Debug)] pub(crate) struct ClientConnectInfo { + /// The service ID that we are connecting to + pub service_id: auth::ServiceID, + /// The service ID that we are connecting to + pub service_connection_id: connection::IDRecv, /// Parsed handshake packet pub handshake: Handshake, /// Connection @@ -90,7 +94,7 @@ impl HandshakeTracker { } } pub(crate) fn recv_handshake( - &self, + &mut self, mut handshake: Handshake, handshake_raw: &mut [u8], ) -> Result { @@ -175,24 +179,28 @@ impl HandshakeTracker { })); } DirSync::Resp(resp) => { - let hshake = { + let hshake_idx = { match self .hshake_cli .iter() - .find(|h| h.id == resp.client_key_id) + .position(|h| h.id == resp.client_key_id) { Some(h) => Some(h.clone()), None => None, } }; - if hshake.is_none() { - ::tracing::debug!( - "No such client key id: {:?}", - resp.client_key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - let hshake = hshake.unwrap(); + let hshake_idx = { + if let Some(real_idx) = hshake_idx { + real_idx + } else { + ::tracing::debug!( + "No such client key id: {:?}", + resp.client_key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + }; + let hshake = &self.hshake_cli[hshake_idx]; let cipher_recv = &hshake.connection.cipher_recv; use crate::enc::sym::AAD; // no aad for now @@ -208,8 +216,18 @@ impl HandshakeTracker { return Err(handshake::Error::Key(e).into()); } } + // we can remove the handshake from the list + let hshake: HandshakeClient = { + let len = self.hshake_cli.len(); + if (hshake_idx + 1) != len { + self.hshake_cli.swap(hshake_idx, len - 1); + } + self.hshake_cli.pop().unwrap() + }; return Ok(HandshakeAction::ClientConnect( ClientConnectInfo { + service_id: hshake.service_id, + service_connection_id: hshake.service_conn_id, handshake, connection: hshake.connection, }, diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 7e3bd2e..d11efdb 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,7 @@ use crate::{ socket::{UdpClient, UdpServer}, ConnList, Connection, IDSend, Packet, ID, }, - enc::sym::Secret, + enc::{hkdf::HkdfSha3, sym::Secret}, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; use ::std::{rc::Rc, sync::Arc, vec::Vec}; @@ -238,7 +238,7 @@ impl Worker { let resp_data = dirsync::RespData { client_nonce: req_data.nonce, id: auth_conn.id_recv.0, - service_id: srv_conn_id, + service_connection_id: srv_conn_id, service_key: srv_secret, }; use crate::enc::sym::AAD; @@ -296,10 +296,40 @@ impl Worker { ); return; } - // FIXME: conn tracking and arc counting - let conn = Rc::get_mut(&mut cci.connection).unwrap(); - conn.id_send = IDSend(resp_data.id); - todo!(); + { + let conn = Rc::get_mut(&mut cci.connection).unwrap(); + conn.id_send = IDSend(resp_data.id); + } + // track the connection to the authentication server + if self.connections.track(cci.connection.clone()).is_err() { + self.connections.delete(cci.connection.id_recv); + } + if cci.connection.id_recv.0 + == resp_data.service_connection_id + { + // the user asked a single connection + // to the authentication server, without any additional + // service. No more connections to setup + return; + } + // create and track the connection to the service + //FIXME: the Secret should be XORed with the client stored + // secret (if any) + let hkdf = HkdfSha3::new( + cci.service_id.as_bytes(), + resp_data.service_key, + ); + let mut service_connection = Connection::new( + hkdf, + cci.connection.cipher_recv.kind(), + connection::Role::Client, + &self.rand, + ); + service_connection.id_recv = cci.service_connection_id; + service_connection.id_send = + IDSend(resp_data.service_connection_id); + self.connections.track(service_connection.into()); + return; } _ => {} }; @@ -326,6 +356,6 @@ impl Worker { return; } }; - src_sock.send_to(&data, client.0); + let _ = src_sock.send_to(&data, client.0).await; } } From 1259996201941f919faf0d3af32437fcbf5f3779 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sat, 27 May 2023 10:57:15 +0200 Subject: [PATCH 10/34] Connect boilerplate, cleanup Signed-off-by: Luca Fulchir --- flake.nix | 1 + src/connection/handshake/dirsync.rs | 4 +- src/connection/handshake/mod.rs | 6 +- src/connection/mod.rs | 11 +++- src/connection/socket.rs | 17 +++--- src/enc/asym.rs | 1 - src/enc/hkdf.rs | 11 ++-- src/enc/sym.rs | 57 ++---------------- src/inner/mod.rs | 10 +--- src/inner/worker.rs | 32 +++++++--- src/lib.rs | 90 ++++++++++++++++++++++------- 11 files changed, 123 insertions(+), 117 deletions(-) diff --git a/flake.nix b/flake.nix index 4c4a18a..fd0713c 100644 --- a/flake.nix +++ b/flake.nix @@ -37,6 +37,7 @@ #}) clippy cargo-watch + cargo-flamegraph cargo-license lld rust-bin.stable."1.69.0".default diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 72d6da9..c5ef626 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -19,8 +19,6 @@ use crate::{ }; use ::arrayref::array_mut_ref; -use ::std::{collections::VecDeque, num::NonZeroU64, vec::Vec}; -use trust_dns_client::rr::rdata::key::Protocol; type Nonce = [u8; 16]; @@ -304,7 +302,7 @@ impl RespInner { pub fn len(&self) -> usize { match self { RespInner::CipherText(len) => *len, - RespInner::ClearText(d) => RespData::len(), + RespInner::ClearText(_) => RespData::len(), } } /* diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 63de947..8897c1a 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -7,7 +7,7 @@ use crate::{ enc::sym::{HeadLen, TagLen}, }; use ::num_traits::FromPrimitive; -use ::std::{rc::Rc, sync::Arc}; +use ::std::rc::Rc; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -145,10 +145,6 @@ impl Handshake { self.fenrir_version.serialize(&mut out[0]); self.data.serialize(head_len, tag_len, &mut out[1..]); } - - pub(crate) fn work(&self, keys: &[HandshakeServer]) -> Result<(), Error> { - todo!() - } } trait HandshakeParsing { diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 2c85af7..9d8a4cb 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -4,7 +4,7 @@ pub mod handshake; pub mod packet; pub mod socket; -use ::std::{rc::Rc, sync::Arc, vec::Vec}; +use ::std::{rc::Rc, vec::Vec}; pub use crate::connection::{ handshake::Handshake, @@ -110,7 +110,7 @@ pub(crate) struct ConnList { impl ConnList { pub(crate) fn new(thread_id: ThreadTracker) -> Self { - let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + let bitmap_id = ::bitmaps::Bitmap::<1024>::new(); const INITIAL_CAP: usize = 128; let mut ret = Self { thread_id, @@ -120,6 +120,13 @@ impl ConnList { ret.connections.resize_with(INITIAL_CAP, || None); ret } + pub fn len(&self) -> usize { + let mut total: usize = 0; + for bitmap in self.ids_used.iter() { + total = total + bitmap.len() + } + total + } /// Only *Reserve* a connection, /// without actually tracking it in self.connections pub(crate) fn reserve_first( diff --git a/src/connection/socket.rs b/src/connection/socket.rs index 1cc570d..945dac6 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -1,15 +1,12 @@ //! Socket related types and functions -use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{ - net::SocketAddr, - sync::Arc, - vec::{self, Vec}, -}; -use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; +use ::arc_swap::ArcSwap; +use ::std::{net::SocketAddr, sync::Arc, vec::Vec}; +use ::tokio::{net::UdpSocket, task::JoinHandle}; /// Pair to easily track the socket and its async listening handle -pub type SocketTracker = (Arc, Arc>>); +pub type SocketTracker = + (Arc, Arc>>); /// async free socket list pub(crate) struct SocketList { @@ -48,7 +45,7 @@ impl SocketList { }); } /// This method assumes no other `add_sockets` are being run - pub(crate) async fn stop_all(mut self) { + pub(crate) async fn stop_all(self) { let mut arc_list = self.list.into_inner(); let list = loop { match Arc::try_unwrap(arc_list) { @@ -63,7 +60,7 @@ impl SocketList { } }; for (_socket, mut handle) in list.into_iter() { - Arc::get_mut(&mut handle).unwrap().await; + let _ = Arc::get_mut(&mut handle).unwrap().await; } } pub(crate) fn lock(&self) -> SocketListRef { diff --git a/src/enc/asym.rs b/src/enc/asym.rs index c705a5f..56027df 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -1,7 +1,6 @@ //! Asymmetric key handling and wrappers use ::num_traits::FromPrimitive; -use ::std::vec::Vec; use super::Error; use crate::enc::sym::Secret; diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index bb6ca59..15d7eca 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -51,13 +51,10 @@ impl HkdfSha3 { /// Instantiate a new HKDF with Sha3-256 pub fn new(salt: &[u8], key: Secret) -> Self { let hkdf = Hkdf::::new(Some(salt), key.as_ref()); - #[allow(unsafe_code)] - unsafe { - Self { - inner: HkdfInner { - hkdf: ::core::mem::ManuallyDrop::new(hkdf), - }, - } + Self { + inner: HkdfInner { + hkdf: ::core::mem::ManuallyDrop::new(hkdf), + }, } } /// Get a secret generated from the key and a given context diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 3de267c..e6f9e11 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -1,7 +1,6 @@ //! Symmetric cypher stuff use super::Error; -use ::std::collections::VecDeque; use ::zeroize::Zeroize; /// Secret, used for keys. @@ -174,7 +173,7 @@ impl Cipher { } fn overhead(&self) -> usize { match self { - Cipher::XChaCha20Poly1305(cipher) => { + Cipher::XChaCha20Poly1305(_) => { let cipher = CipherKind::XChaCha20Poly1305; cipher.nonce_len().0 + cipher.tag_len().0 } @@ -189,9 +188,7 @@ impl Cipher { // FIXME: check minimum buffer size match self { Cipher::XChaCha20Poly1305(cipher) => { - use ::chacha20poly1305::{ - aead::generic_array::GenericArray, AeadInPlace, - }; + use ::chacha20poly1305::AeadInPlace; let tag_len: usize = ::ring::aead::CHACHA20_POLY1305.tag_len(); let data_len_notag = data.len() - tag_len; // write nonce @@ -211,10 +208,9 @@ impl Cipher { Ok(()) } Err(_) => Err(Error::Encrypt), - }; + } } } - todo!() } } @@ -253,35 +249,6 @@ impl CipherRecv { } } -/// Allocate some data, with additional indexes to track -/// where nonce and tags are -#[derive(Debug, Clone)] -pub struct Data { - data: Vec, - skip_start: usize, - skip_end: usize, -} - -impl Data { - /// Get the slice where you will write the actual data - /// this will skip the actual nonce and AEAD tag and give you - /// only the space for the data - pub fn get_slice(&mut self) -> &mut [u8] { - &mut self.data[self.skip_start..self.skip_end] - } - fn get_tag_slice(&mut self) -> &mut [u8] { - let start = self.data.len() - self.skip_end; - &mut self.data[start..] - } - fn get_slice_full(&mut self) -> &mut [u8] { - &mut self.data - } - /// Consume the data and return the whole raw vector - pub fn get_raw(self) -> Vec { - self.data - } -} - /// Send only cipher pub struct CipherSend { nonce: NonceSync, @@ -308,14 +275,6 @@ impl CipherSend { cipher: Cipher::new(kind, secret), } } - /// Allocate the memory for the data that will be encrypted - pub fn make_data(&self, length: usize) -> Data { - Data { - data: Vec::with_capacity(length + self.cipher.overhead()), - skip_start: self.cipher.nonce_len().0, - skip_end: self.cipher.tag_len().0, - } - } /// Encrypt the given data pub fn encrypt(&self, aad: AAD, data: &mut [u8]) -> Result<(), Error> { let old_nonce = self.nonce.advance(); @@ -380,10 +339,7 @@ impl Nonce { use ring::rand::SecureRandom; let mut raw = [0; 12]; rand.fill(&mut raw); - #[allow(unsafe_code)] - unsafe { - Self { raw } - } + Self { raw } } /// Length of this nonce in bytes pub const fn len() -> usize { @@ -398,10 +354,7 @@ impl Nonce { } /// Create Nonce from array pub fn from_slice(raw: [u8; 12]) -> Self { - #[allow(unsafe_code)] - unsafe { - Self { raw } - } + Self { raw } } /// Go to the next nonce pub fn advance(&mut self) { diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 72447f2..ec09d1f 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -14,12 +14,11 @@ use crate::{ enc::{ self, asym, hkdf::HkdfSha3, - sym::{CipherKind, CipherRecv, CipherSend, HeadLen, TagLen}, + sym::{CipherKind, CipherRecv}, }, Error, }; -use ::arc_swap::{ArcSwap, ArcSwapAny, ArcSwapOption}; -use ::std::{rc::Rc, sync::Arc, vec::Vec}; +use ::std::{rc::Rc, vec::Vec}; /// Information needed to reply after the key exchange #[derive(Debug, Clone)] @@ -98,10 +97,7 @@ impl HandshakeTracker { mut handshake: Handshake, handshake_raw: &mut [u8], ) -> Result { - use connection::handshake::{ - dirsync::{self, DirSync}, - HandshakeData, - }; + use connection::handshake::{dirsync::DirSync, HandshakeData}; match handshake.data { HandshakeData::DirSync(ref mut ds) => match ds { DirSync::Req(ref mut req) => { diff --git a/src/inner/worker.rs b/src/inner/worker.rs index d11efdb..c7adf10 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -1,23 +1,25 @@ //! Worker thread implementation use crate::{ - auth::TokenChecker, + auth::{ServiceID, TokenChecker}, connection::{ self, handshake::{ - self, dirsync::{self, DirSync}, - Handshake, HandshakeClient, HandshakeData, + Handshake, HandshakeData, }, socket::{UdpClient, UdpServer}, - ConnList, Connection, IDSend, Packet, ID, + ConnList, Connection, IDSend, Packet, }, + dnssec, enc::{hkdf::HkdfSha3, sym::Secret}, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; use ::std::{rc::Rc, sync::Arc, vec::Vec}; /// This worker must be cpu-pinned -use ::tokio::{net::UdpSocket, sync::Mutex}; -use std::net::SocketAddr; +use ::tokio::{ + net::UdpSocket, + sync::{oneshot, Mutex}, +}; /// Track a raw Udp packet pub(crate) struct RawUdp { @@ -28,8 +30,15 @@ pub(crate) struct RawUdp { } pub(crate) enum Work { + /// ask the thread to report to the main thread the total number of + /// connections present + CountConnections(oneshot::Sender), + Connect((oneshot::Sender, dnssec::Record, ServiceID)), Recv(RawUdp), } +pub(crate) enum WorkAnswer { + UNUSED, +} /// Actual worker implementation. pub(crate) struct Worker { @@ -131,6 +140,13 @@ impl Worker { } }; match work { + Work::CountConnections(sender) => { + let conn_num = self.connections.len(); + let _ = sender.send(conn_num); + } + Work::Connect((send_res, dnssec_record, service_id)) => { + todo!() + } //TODO: reconf message to add channels Work::Recv(pkt) => { self.recv(pkt).await; @@ -285,7 +301,6 @@ impl Worker { return; } // track connection - use handshake::dirsync; let resp_data; if let dirsync::RespInner::ClearText(r_data) = ds_resp.data { @@ -313,6 +328,7 @@ impl Worker { return; } // create and track the connection to the service + // SECURITY: //FIXME: the Secret should be XORed with the client stored // secret (if any) let hkdf = HkdfSha3::new( @@ -328,7 +344,7 @@ impl Worker { service_connection.id_recv = cci.service_connection_id; service_connection.id_send = IDSend(resp_data.service_connection_id); - self.connections.track(service_connection.into()); + let _ = self.connections.track(service_connection.into()); return; } _ => {} diff --git a/src/lib.rs b/src/lib.rs index 02df264..1aed2e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,12 +20,9 @@ pub mod dnssec; pub mod enc; mod inner; -use ::std::{ - net::SocketAddr, - sync::{Arc, Weak}, - vec::Vec, -}; -use ::tokio::{net::UdpSocket, sync::RwLock, task::JoinHandle}; +use ::std::{sync::Arc, vec::Vec}; +use ::tokio::net::UdpSocket; +use auth::ServiceID; use crate::{ auth::TokenChecker, @@ -94,9 +91,7 @@ impl Drop for Fenrir { impl Fenrir { /// Create a new Fenrir endpoint pub fn new(config: &Config) -> Result { - let listen_num = config.listen.len(); let (sender, _) = ::tokio::sync::broadcast::channel(1); - let (work_send, work_recv) = ::async_channel::unbounded::(); let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), @@ -127,23 +122,23 @@ impl Fenrir { /// asyncronous version for Drop fn stop_sync(&mut self) { let _ = self.stop_working.send(true); - let mut toempty_sockets = self.sockets.rm_all(); + let toempty_sockets = self.sockets.rm_all(); let task = ::tokio::task::spawn(toempty_sockets.stop_all()); let _ = ::futures::executor::block_on(task); let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); - old_thread_pool.into_iter().map(|th| th.join()); + let _ = old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Stop all workers, listeners pub async fn stop(&mut self) { let _ = self.stop_working.send(true); - let mut toempty_sockets = self.sockets.rm_all(); + let toempty_sockets = self.sockets.rm_all(); toempty_sockets.stop_all().await; let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); - old_thread_pool.into_iter().map(|th| th.join()); + let _ = old_thread_pool.into_iter().map(|th| th.join()); self.dnssec = None; } /// Add all UDP sockets found in config @@ -166,7 +161,7 @@ impl Fenrir { self._thread_work.clone(), arc_s.clone(), )); - self.sockets.add_socket(arc_s, join); + self.sockets.add_socket(arc_s, join).await; } Err(e) => { return Err(e); @@ -218,18 +213,19 @@ impl Fenrir { } } }; - work_queues[thread_idx].send(Work::Recv(RawUdp { - src: UdpClient(sock_sender), - dst: sock_receiver, - packet, - data, - })); + let _ = work_queues[thread_idx] + .send(Work::Recv(RawUdp { + src: UdpClient(sock_sender), + dst: sock_receiver, + packet, + data, + })) + .await; } Ok(()) } - /// Get the raw TXT record of a Fenrir domain - pub async fn resolv_str(&self, domain: &str) -> Result { + pub async fn resolv_txt(&self, domain: &str) -> Result { match &self.dnssec { Some(dnssec) => Ok(dnssec.resolv(domain).await?), None => Err(Error::NotInitialized), @@ -238,10 +234,60 @@ impl Fenrir { /// Get the raw TXT record of a Fenrir domain pub async fn resolv(&self, domain: &str) -> Result { - let record_str = self.resolv_str(domain).await?; + let record_str = self.resolv_txt(domain).await?; Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } + /// Connect to a service + pub async fn connect( + &self, + domain: &str, + service: ServiceID, + ) -> Result<(), Error> { + let resolved = self.resolv(domain).await?; + + // find the thread with less connections + + let th_num = self._thread_work.len(); + let mut conn_count = Vec::::with_capacity(th_num); + let mut wait_res = + Vec::<::tokio::sync::oneshot::Receiver>::with_capacity( + th_num, + ); + for th in self._thread_work.iter() { + let (send, recv) = ::tokio::sync::oneshot::channel(); + wait_res.push(recv); + let _ = th.send(Work::CountConnections(send)).await; + } + for ch in wait_res.into_iter() { + if let Ok(conn_num) = ch.await { + conn_count.push(conn_num); + } + } + if conn_count.len() != th_num { + return Err(Error::IO(::std::io::Error::new( + ::std::io::ErrorKind::NotConnected, + "can't connect to a thread", + ))); + } + let thread_idx = conn_count + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.cmp(b)) + .map(|(index, _)| index) + .unwrap(); + + // and tell that thread to connect somewhere + let (send, recv) = ::tokio::sync::oneshot::channel(); + let _ = self._thread_work[thread_idx] + .send(Work::Connect((send, resolved, service))) + .await; + + let _conn_res = recv.await; + + todo!() + } + /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. /// Work will be divided and rerouted so that there is no need to lock From e581cb064aa11d95a69adcc0365b473317fe3a49 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sat, 27 May 2023 11:10:29 +0200 Subject: [PATCH 11/34] Update architecture.md Signed-off-by: Luca Fulchir --- src/Architecture.md | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/Architecture.md b/src/Architecture.md index b9f0c66..9e547d0 100644 --- a/src/Architecture.md +++ b/src/Architecture.md @@ -1,8 +1,17 @@ # Architecture -For now we will keep things as easy as possible, not caring about the performance. +The current architecture is based on tokio. +In the future we might want something more general purpose (and no-std), +but good enough for now. -This means we will use only ::tokio and spawn one job per connection +Tokio has its own thread pool, and we spawn one async loop per listening socket. + +Then we spawn our own thread pool. +These threads are pinned to the cpu cores. +We spawn one async worker per thread, making sure it remains pinned to that core. +This is done to avoid *a lot* of locking and multithread syncronizations. + +We do connection sharding on the connection id. # Future @@ -12,8 +21,6 @@ scaling and work queues: https://tokio.rs/blog/2019-10-scheduler What we want to do is: -* one thread per core -* thread pinning * ebpf BBR packet pacing * need to support non-ebpf BBR for mobile ios, too * connection sharding From 110a346551309e1044824e3ea009677c4bc2554f Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sun, 28 May 2023 15:26:17 +0200 Subject: [PATCH 12/34] dnssec: use the proper enc::asym types Signed-off-by: Luca Fulchir --- src/connection/handshake/dirsync.rs | 4 +- src/connection/mod.rs | 9 ++ src/dnssec/record.rs | 117 ++++----------------- src/enc/asym.rs | 154 +++++++++++++++++++++++++--- src/enc/errors.rs | 9 +- src/inner/mod.rs | 2 +- 6 files changed, 172 insertions(+), 123 deletions(-) diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index c5ef626..775f301 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -78,7 +78,7 @@ impl Req { + KeyID::len() + KeyExchange::len() + CipherKind::len() - + self.exchange_key.len() + + self.exchange_key.kind().pub_len() } /// return the total length of the cleartext data pub fn encrypted_length(&self) -> usize { @@ -92,7 +92,7 @@ impl Req { KeyID::len() + KeyExchange::len() + CipherKind::len() - + self.exchange_key.len() + + self.exchange_key.kind().pub_len() + self.data.len() } /// Serialize into raw bytes diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 9d8a4cb..e264cf7 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -200,3 +200,12 @@ impl ConnList { } } } + +/* +use ::std::collections::HashMap; + +pub(crate) struct AuthServerConnections { + conn_map : HashMap< + pub id: IDSend, +} +*/ diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 5f5e09f..b4e9bb9 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -23,10 +23,10 @@ //! * Y bytes: pubkey //! ] +use crate::enc::{self, asym::PubKey}; use ::core::num::NonZeroU16; use ::num_traits::FromPrimitive; use ::std::{net::IpAddr, vec::Vec}; - /* * Public key data */ @@ -48,92 +48,6 @@ impl TryFrom<&str> for PublicKeyID { } } -/// Public Key Type -#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)] -// public enum: use non_exhaustive to force users to add a default case -// so in the future we can expand this easily -#[non_exhaustive] -#[repr(u8)] -pub enum PublicKeyType { - /// ed25519 asymmetric key - Ed25519 = 0, - /// Ephemeral X25519 (Curve25519) key. - /// Used in the directory synchronized handshake - X25519, -} - -impl PublicKeyType { - /// Get the size of a public key of this kind - pub fn key_len(&self) -> usize { - match &self { - PublicKeyType::Ed25519 => 32, - PublicKeyType::X25519 => 32, // FIXME: hopefully... - } - } -} - -impl TryFrom<&str> for PublicKeyType { - type Error = ::std::io::Error; - fn try_from(raw: &str) -> Result { - if let Ok(type_u8) = raw.parse::() { - if let Some(kind) = PublicKeyType::from_u8(type_u8) { - return Ok(kind); - } - } - return Err(::std::io::Error::new( - ::std::io::ErrorKind::InvalidData, - "Public Key Type 0 is the only one supported", - )); - } -} - -/// Public Key, with its type and id -#[derive(Debug, Clone)] -pub struct PublicKey { - /// public key raw data - pub raw: Vec, - /// type of public key - pub kind: PublicKeyType, - /// id of public key - pub id: PublicKeyID, -} - -impl PublicKey { - fn raw_len(&self) -> usize { - let size = 2; // Public Key Type + ID - size + self.raw.len() - } - fn encode_into(&self, raw: &mut Vec) { - raw.push(self.kind as u8); - raw.push(self.id.0); - raw.extend_from_slice(&self.raw); - } - fn decode_raw(raw: &[u8]) -> Result<(Self, usize), Error> { - if raw.len() < 4 { - return Err(Error::NotEnoughData(0)); - } - - let kind = PublicKeyType::from_u8(raw[0]).unwrap(); - let id = PublicKeyID(raw[1]); - if raw.len() < 2 + kind.key_len() { - return Err(Error::NotEnoughData(2)); - } - - let mut raw_key = Vec::with_capacity(kind.key_len()); - let total_length = 2 + kind.key_len(); - raw_key.extend_from_slice(&raw[2..total_length]); - - Ok(( - Self { - raw: raw_key, - kind, - id, - }, - total_length, - )) - } -} - /* * Address data */ @@ -358,7 +272,7 @@ impl Address { let raw_port = u16::from_le_bytes([raw[1], raw[2]]); - // Add publi key ids + // Add publickey ids let num_pubkey_ids = raw[3] as usize; if raw.len() < 3 + num_pubkey_ids { return Err(Error::NotEnoughData(3)); @@ -428,14 +342,14 @@ impl Address { } /* - * Actual record puuting it all toghether + * Actual record putting it all toghether */ /// All informations found in the DNSSEC record #[derive(Debug, Clone)] pub struct Record { /// Public keys used by any authentication server - pub public_keys: Vec, + pub public_keys: Vec<(PublicKeyID, PubKey)>, /// List of all authentication servers' addresses. /// Multiple ones can point to the same authentication server pub addresses: Vec
, @@ -461,7 +375,11 @@ impl Record { let total_size: usize = 1 + self.addresses.iter().map(|a| a.raw_len()).sum::() - + self.public_keys.iter().map(|a| a.raw_len()).sum::(); + + self + .public_keys + .iter() + .map(|(_, key)| 1 + key.kind().pub_len()) + .sum::(); let mut raw = Vec::with_capacity(total_size); @@ -475,8 +393,9 @@ impl Record { for address in self.addresses.iter() { address.encode_into(&mut raw); } - for public_key in self.public_keys.iter() { - public_key.encode_into(&mut raw); + for (public_key_id, public_key) in self.public_keys.iter() { + raw.push(public_key_id.0); + public_key.serialize_into(&mut raw); } Ok(::base85::encode(&raw)) @@ -514,19 +433,21 @@ impl Record { num_addresses = num_addresses - 1; } while num_public_keys > 0 { + let id = PublicKeyID(raw[bytes_parsed]); + bytes_parsed = bytes_parsed + 1; let (public_key, bytes) = - match PublicKey::decode_raw(&raw[bytes_parsed..]) { + match PubKey::deserialize(&raw[bytes_parsed..]) { Ok(public_key) => public_key, - Err(Error::UnsupportedData(b)) => { + Err(enc::Error::UnsupportedKey(b)) => { return Err(Error::UnsupportedData(bytes_parsed + b)) } - Err(Error::NotEnoughData(b)) => { + Err(enc::Error::NotEnoughData(b)) => { return Err(Error::NotEnoughData(bytes_parsed + b)) } - Err(e) => return Err(e), + _ => return Err(Error::UnknownData(bytes_parsed)), }; bytes_parsed = bytes_parsed + bytes; - result.public_keys.push(public_key); + result.public_keys.push((id, public_key)); num_public_keys = num_public_keys - 1; } if bytes_parsed != raw.len() { diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 56027df..160ee12 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -20,18 +20,52 @@ impl KeyID { } } +/// Capabilities of each key +#[derive(Debug, Clone, Copy)] +pub enum KeyCapabilities { + /// signing *only* + Sign, + /// encrypt *only* + Encrypt, + /// key exchange *only* + Exchange, + /// both sign and encrypt + SignEncrypt, + /// both signing and key exchange + SignExchange, + /// both encrypt and key exchange + EncryptExchange, + /// All: sign, encrypt, Key Exchange + SignEncryptExchage, +} + /// Kind of key used in the handshake #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[non_exhaustive] #[repr(u8)] -pub enum Key { - /// X25519 Public key - X25519 = 0, +pub enum KeyKind { + /// Ed25519 Public key (sign only) + Ed25519 = 0, + /// X25519 Public key (key exchange) + X25519, } -impl Key { - fn pub_len(&self) -> usize { +// FIXME: actually check this +const MIN_KEY_SIZE: usize = 32; +impl KeyKind { + /// return the expected length of the public key + pub fn pub_len(&self) -> usize { match self { // FIXME: 99% wrong size - Key::X25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, + KeyKind::Ed25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, + // FIXME: 99% wrong size + KeyKind::X25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, + } + } + /// Get the capabilities of this key type + pub fn capabilities(&self) -> KeyCapabilities { + match self { + KeyKind::Ed25519 => KeyCapabilities::Sign, + KeyKind::X25519 => KeyCapabilities::Exchange, } } } @@ -50,9 +84,80 @@ impl KeyExchange { } } -/// Kind of key in the handshake +/// Kind of public key in the handshake +#[derive(Debug, Copy, Clone)] +#[allow(missing_debug_implementations)] +#[non_exhaustive] +pub enum PubKey { + /// Keys to be used only in key exchanges, not for signing + Exchange(ExchangePubKey), + /// Keys to be used only for signing + Signing, +} + +impl PubKey { + /// return the kind of public key + pub fn kind(&self) -> KeyKind { + match self { + // FIXME: lie, we don't fully support this + PubKey::Signing => KeyKind::Ed25519, + PubKey::Exchange(ex) => ex.kind(), + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + pub fn serialize_into(&self, out: &mut [u8]) { + assert!( + out.len() >= 1 + self.kind().pub_len(), + "Not enough out buffer", + ); + out[0] = self.kind() as u8; + match self { + PubKey::Signing => { + ::tracing::error!("serializing ed25519 not supported"); + return; + } + PubKey::Exchange(ex) => ex.serialize_into(&mut out[1..]), + } + } + /// Try to deserialize the pubkey from raw bytes + /// on success returns the public key and the number of parsed bytes + pub fn deserialize(raw: &[u8]) -> Result<(Self, usize), Error> { + if raw.len() < 1 + MIN_KEY_SIZE { + return Err(Error::NotEnoughData(0)); + } + let kind: KeyKind = match KeyKind::from_u8(raw[0]) { + Some(kind) => kind, + None => return Err(Error::UnsupportedKey(1)), + }; + if raw.len() < 1 + kind.pub_len() { + return Err(Error::NotEnoughData(1)); + } + match kind { + KeyKind::Ed25519 => { + ::tracing::error!("ed25519 keys are not yet supported"); + return Err(Error::Parsing); + } + KeyKind::X25519 => { + let pub_key: ::x25519_dalek::PublicKey = + match ::bincode::deserialize(&raw[1..(1 + kind.pub_len())]) + { + Ok(pub_key) => pub_key, + Err(_) => return Err(Error::Parsing), + }; + Ok(( + PubKey::Exchange(ExchangePubKey::X25519(pub_key)), + kind.pub_len(), + )) + } + } + } +} + +/// Kind of private key in the handshake #[derive(Clone)] #[allow(missing_debug_implementations)] +#[non_exhaustive] pub enum PrivKey { /// Keys to be used only in key exchanges, not for signing Exchange(ExchangePrivKey), @@ -63,6 +168,7 @@ pub enum PrivKey { /// Ephemeral private keys #[derive(Clone)] #[allow(missing_debug_implementations)] +#[non_exhaustive] pub enum ExchangePrivKey { /// X25519(Curve25519) used for key exchange X25519(::x25519_dalek::StaticSecret), @@ -70,9 +176,9 @@ pub enum ExchangePrivKey { impl ExchangePrivKey { /// Get the kind of key - pub fn kind(&self) -> Key { + pub fn kind(&self) -> KeyKind { match self { - ExchangePrivKey::X25519(_) => Key::X25519, + ExchangePrivKey::X25519(_) => KeyKind::X25519, } } /// Run the key exchange between two keys of the same kind @@ -99,30 +205,44 @@ impl ExchangePrivKey { /// all Ephemeral Public keys #[derive(Debug, Copy, Clone)] +#[non_exhaustive] pub enum ExchangePubKey { /// X25519(Curve25519) used for key exchange X25519(::x25519_dalek::PublicKey), } impl ExchangePubKey { - /// length of the public key used for key exchange - pub fn len(&self) -> usize { + /// Get the kind of key + pub fn kind(&self) -> KeyKind { match self { - ExchangePubKey::X25519(_) => 32, + ExchangePubKey::X25519(_) => KeyKind::X25519, + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + fn serialize_into(&self, out: &mut [u8]) { + match self { + ExchangePubKey::X25519(pk) => { + let bytes = pk.as_bytes(); + assert!(bytes.len() == 32, "x25519 should have been 32 bytes"); + out[..32].copy_from_slice(bytes); + } } } /// Load public key used for key exchange from it raw bytes /// The riesult is "unparsed" since we don't verify /// the actual key pub fn from_slice(raw: &[u8]) -> Result<(Self, usize), Error> { - // FIXME: get *real* minimum key size - const MIN_KEY_SIZE: usize = 32; if raw.len() < 1 + MIN_KEY_SIZE { - return Err(Error::NotEnoughData); + return Err(Error::NotEnoughData(0)); } - match Key::from_u8(raw[0]) { + match KeyKind::from_u8(raw[0]) { Some(kind) => match kind { - Key::X25519 => { + KeyKind::Ed25519 => { + ::tracing::error!("ed25519 keys are not yet supported"); + return Err(Error::Parsing); + } + KeyKind::X25519 => { let pub_key: ::x25519_dalek::PublicKey = match ::bincode::deserialize( &raw[1..(1 + kind.pub_len())], diff --git a/src/enc/errors.rs b/src/enc/errors.rs index ded1fc0..c0591b0 100644 --- a/src/enc/errors.rs +++ b/src/enc/errors.rs @@ -8,14 +8,13 @@ pub enum Error { Parsing, /// Not enough data #[error("not enough data")] - NotEnoughData, + NotEnoughData(usize), /// buffer too small #[error("buffer too small")] InsufficientBuffer, - /// Wrong Key type found. - /// You might have passed rsa keys where x25519 was expected - #[error("wrong key type")] - WrongKey, + /// Unsupported Key type found. + #[error("unsupported key type")] + UnsupportedKey(usize), /// Unsupported key exchange for this key #[error("unsupported key exchange")] UnsupportedKeyExchange, diff --git a/src/inner/mod.rs b/src/inner/mod.rs index ec09d1f..9ee6942 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -70,7 +70,7 @@ pub(crate) struct ThreadTracker { /// (udp_src_sender_port % total_threads) - 1 pub(crate) struct HandshakeTracker { thread_id: ThreadTracker, - key_exchanges: Vec<(asym::Key, asym::KeyExchange)>, + key_exchanges: Vec<(asym::KeyKind, asym::KeyExchange)>, ciphers: Vec, /// ephemeral keys used server side in key exchange keys_srv: Vec, From a3430f18135f2fb9b7b76f9667af6df85864143b Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sun, 28 May 2023 18:23:14 +0200 Subject: [PATCH 13/34] Initial connections: share auth.server connection Signed-off-by: Luca Fulchir --- src/connection/mod.rs | 95 ++++++++++++++++++++++++++++++++++++++++--- src/dnssec/record.rs | 2 +- src/enc/asym.rs | 4 +- src/inner/worker.rs | 9 +++- src/lib.rs | 65 ++++++++++++++++++++++++++--- 5 files changed, 160 insertions(+), 15 deletions(-) diff --git a/src/connection/mod.rs b/src/connection/mod.rs index e264cf7..6d3325d 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -12,7 +12,9 @@ pub use crate::connection::{ }; use crate::{ + dnssec, enc::{ + asym::PubKey, hkdf::HkdfSha3, sym::{CipherKind, CipherRecv, CipherSend}, }, @@ -201,11 +203,94 @@ impl ConnList { } } -/* use ::std::collections::HashMap; -pub(crate) struct AuthServerConnections { - conn_map : HashMap< - pub id: IDSend, +enum MapEntry { + Present(IDSend), + Reserved, +} + +/// return wether we already have a connection, we are waiting for one, or you +/// can start one +#[derive(Debug, Clone, Copy)] +pub(crate) enum Reservation { + /// we already have a connection. use this ID. + Present(IDSend), + /// we don't have a connection, but we are waiting for one to be established. + Waiting, + /// we have reserved a spot for your connection. + Reserved, +} + +/// Link the public key of the authentication server to a connection id +/// so that we can reuse that connection to ask for more authentications +/// +/// Note that a server can have multiple public keys, +/// and the handshake will only ever verify one. +/// To avoid malicious publication fo keys that are not yours, +/// on connection we: +/// * reserve all public keys of the server +/// * wait for the connection to finish +/// * remove all those reservations, exept the one key that actually succeded +/// While searching, we return a connection ID if just one key is a match +pub(crate) struct AuthServerConnections { + conn_map: HashMap, + next_reservation: u64, +} + +impl AuthServerConnections { + pub(crate) fn new() -> Self { + Self { + conn_map: HashMap::with_capacity(32), + next_reservation: 0, + } + } + /// add an ID to the reserved spot, + /// and unlock the other pubkeys which have not been verified + pub(crate) fn add( + &mut self, + pubkey: &PubKey, + id: IDSend, + record: &dnssec::Record, + ) { + let _ = self.conn_map.insert(*pubkey, MapEntry::Present(id)); + for (_, pk) in record.public_keys.iter() { + if pk == pubkey { + continue; + } + let _ = self.conn_map.remove(pk); + } + } + /// remove a dropped connection + pub(crate) fn remove_reserved(&mut self, record: &dnssec::Record) { + for (_, pk) in record.public_keys.iter() { + let _ = self.conn_map.remove(pk); + } + } + /// remove a dropped connection + pub(crate) fn remove_conn(&mut self, pubkey: &PubKey) { + let _ = self.conn_map.remove(pubkey); + } + + /// each dnssec::Record has multiple Pubkeys. reserve and ID for them all. + /// later on, when `add` is called we will delete + /// those that have not actually benn used + pub(crate) fn get_or_reserve( + &mut self, + record: &dnssec::Record, + ) -> Reservation { + for (_, pk) in record.public_keys.iter() { + match self.conn_map.get(pk) { + None => {} + Some(MapEntry::Reserved) => return Reservation::Waiting, + Some(MapEntry::Present(id)) => { + return Reservation::Present(id.clone()) + } + } + } + for (_, pk) in record.public_keys.iter() { + let _ = self.conn_map.insert(*pk, MapEntry::Reserved); + } + Reservation::Reserved + } } -*/ diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index b4e9bb9..6709751 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -18,8 +18,8 @@ //! * X bytes: IP //! ] //! [ # list of pubkeys -//! * 1 byte: pubkey type //! * 1 byte: pubkey id +//! * 1 byte: pubkey type //! * Y bytes: pubkey //! ] diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 160ee12..6fd7901 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -85,7 +85,7 @@ impl KeyExchange { } /// Kind of public key in the handshake -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)] #[allow(missing_debug_implementations)] #[non_exhaustive] pub enum PubKey { @@ -204,7 +204,7 @@ impl ExchangePrivKey { } /// all Ephemeral Public keys -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, Hash, PartialEq)] #[non_exhaustive] pub enum ExchangePubKey { /// X25519(Curve25519) used for key exchange diff --git a/src/inner/worker.rs b/src/inner/worker.rs index c7adf10..512c2f6 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,7 @@ use crate::{ ConnList, Connection, IDSend, Packet, }, dnssec, - enc::{hkdf::HkdfSha3, sym::Secret}, + enc::{asym::PubKey, hkdf::HkdfSha3, sym::Secret}, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; use ::std::{rc::Rc, sync::Arc, vec::Vec}; @@ -29,11 +29,16 @@ pub(crate) struct RawUdp { pub packet: Packet, } +pub(crate) enum ConnectionResult { + Failed(crate::Error), + Established((PubKey, IDSend)), +} + pub(crate) enum Work { /// ask the thread to report to the main thread the total number of /// connections present CountConnections(oneshot::Sender), - Connect((oneshot::Sender, dnssec::Record, ServiceID)), + Connect((oneshot::Sender, dnssec::Record, ServiceID)), Recv(RawUdp), } pub(crate) enum WorkAnswer { diff --git a/src/lib.rs b/src/lib.rs index 1aed2e2..6f48ec5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ pub mod enc; mod inner; use ::std::{sync::Arc, vec::Vec}; -use ::tokio::net::UdpSocket; +use ::tokio::{net::UdpSocket, sync::Mutex}; use auth::ServiceID; use crate::{ @@ -29,7 +29,7 @@ use crate::{ connection::{ handshake, socket::{SocketList, UdpClient, UdpServer}, - Packet, + AuthServerConnections, Packet, }, inner::{ worker::{RawUdp, Work, Worker}, @@ -74,6 +74,9 @@ pub struct Fenrir { stop_working: ::tokio::sync::broadcast::Sender, /// where to ask for token check token_check: Option>>, + /// tracks the connections to authentication servers + /// so that we can reuse them + conn_auth_srv: Mutex, // TODO: find a way to both increase and decrease these two in a thread-safe // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, @@ -98,6 +101,7 @@ impl Fenrir { dnssec: None, stop_working: sender, token_check: None, + conn_auth_srv: Mutex::new(AuthServerConnections::new()), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), }; @@ -245,6 +249,30 @@ impl Fenrir { service: ServiceID, ) -> Result<(), Error> { let resolved = self.resolv(domain).await?; + loop { + // check if we already have a connection to that auth. srv + let is_reserved = { + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.get_or_reserve(&resolved) + }; + use connection::Reservation; + match is_reserved { + Reservation::Waiting => { + use ::std::time::Duration; + use ::tokio::time::sleep; + // PERF: exponential backoff. + // or we can have a broadcast channel + sleep(Duration::from_millis(50)).await; + continue; + } + Reservation::Reserved => break, + Reservation::Present(id_send) => { + //TODO: reuse connection + todo!() + } + } + } + // Spot reserved for the connection // find the thread with less connections @@ -280,12 +308,39 @@ impl Fenrir { // and tell that thread to connect somewhere let (send, recv) = ::tokio::sync::oneshot::channel(); let _ = self._thread_work[thread_idx] - .send(Work::Connect((send, resolved, service))) + .send(Work::Connect((send, resolved.clone(), service))) .await; - let _conn_res = recv.await; + match recv.await { + Ok(res) => { + use crate::inner::worker::ConnectionResult; + match res { + ConnectionResult::Failed(e) => { + let mut conn_auth_lock = + self.conn_auth_srv.lock().await; + conn_auth_lock.remove_reserved(&resolved); + Err(e) + } + ConnectionResult::Established((pubkey, id_send)) => { + let mut conn_auth_lock = + self.conn_auth_srv.lock().await; + conn_auth_lock.add(&pubkey, id_send, &resolved); - todo!() + //FIXME: user needs to somehow track the connection + Ok(()) + } + } + } + Err(e) => { + // Thread dropped the sender. no more thread? + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.remove_reserved(&resolved); + Err(Error::IO(::std::io::Error::new( + ::std::io::ErrorKind::Interrupted, + "recv failure on connect: ".to_owned() + &e.to_string(), + ))) + } + } } /// Start one working thread for each physical cpu From c6a3bf08202e4aca75961b451ba23a452a33401c Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Tue, 30 May 2023 10:52:54 +0200 Subject: [PATCH 14/34] More work on connect(), use our own Random We use ::ring::rand::SystemRandom, but we need to wrap it for a couple of traits needed by ::x25519_dalek Signed-off-by: Luca Fulchir --- Cargo.toml | 1 + src/auth/mod.rs | 4 +- src/connection/handshake/mod.rs | 1 + src/connection/mod.rs | 3 +- src/connection/packet.rs | 12 +++--- src/dnssec/record.rs | 2 +- src/enc/asym.rs | 26 ++++++++++++- src/enc/mod.rs | 66 ++++++++++++++++++++++++++++++++ src/enc/sym.rs | 15 +++----- src/inner/worker.rs | 67 +++++++++++++++++++++++++++------ src/lib.rs | 13 ++++--- 11 files changed, 173 insertions(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8142aa8..12ed78d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ hwloc2 = {version = "2.2" } libc = { version = "0.2" } num-traits = { version = "0.2" } num-derive = { version = "0.3" } +rand_core = {version = "0.6" } ring = { version = "0.16" } bincode = { version = "1.3" } sha3 = { version = "0.10" } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 2d8dc65..955cb12 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,6 +1,6 @@ //! Authentication related struct definitions -use ::ring::rand::SecureRandom; +use crate::enc::Random; use ::zeroize::Zeroize; /// User identifier. 16 bytes for easy uuid conversion @@ -15,7 +15,7 @@ impl From<[u8; 16]> for UserID { impl UserID { /// New random user id - pub fn new(rand: &::ring::rand::SystemRandom) -> Self { + pub fn new(rand: &Random) -> Self { let mut ret = Self([0; 16]); rand.fill(&mut ret.0); ret diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 8897c1a..76eb353 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -39,6 +39,7 @@ pub(crate) struct HandshakeClient { pub service_id: crate::auth::ServiceID, pub service_conn_id: connection::IDRecv, pub connection: Rc, + pub timeout: Rc, } /// Parsed handshake diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 6d3325d..2f8e3c9 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -17,6 +17,7 @@ use crate::{ asym::PubKey, hkdf::HkdfSha3, sym::{CipherKind, CipherRecv, CipherSend}, + Random, }, inner::ThreadTracker, }; @@ -78,7 +79,7 @@ impl Connection { hkdf: HkdfSha3, cipher: CipherKind, role: Role, - rand: &::ring::rand::SystemRandom, + rand: &Random, ) -> Self { let (secret_recv, secret_send) = match role { Role::Server => { diff --git a/src/connection/packet.rs b/src/connection/packet.rs index 400f900..b051594 100644 --- a/src/connection/packet.rs +++ b/src/connection/packet.rs @@ -1,7 +1,10 @@ // //! Raw packet handling, encryption, decryption, parsing -use crate::enc::sym::{HeadLen, TagLen}; +use crate::enc::{ + sym::{HeadLen, TagLen}, + Random, +}; /// Fenrir Connection id /// 0 is special as it represents the handshake @@ -33,8 +36,7 @@ impl ConnectionID { } } /// New random service ID - pub fn new_rand(rand: &::ring::rand::SystemRandom) -> Self { - use ::ring::rand::SecureRandom; + pub fn new_rand(rand: &Random) -> Self { let mut raw = [0; 8]; let mut num = 0; while num == 0 { @@ -100,7 +102,7 @@ impl PacketData { pub fn len(&self) -> usize { match self { PacketData::Handshake(h) => h.len(), - PacketData::Raw(len) => *len + PacketData::Raw(len) => *len, } } /// serialize data into bytes @@ -134,7 +136,7 @@ pub struct Packet { impl Packet { /// New recevied packet, yet unparsed - pub fn deserialize_id(raw: &[u8]) -> Result { + pub fn deserialize_id(raw: &[u8]) -> Result { // TODO: proper min_packet length. 16 is too conservative. if raw.len() < MIN_PACKET_BYTES { return Err(()); diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 6709751..48e8edd 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -32,7 +32,7 @@ use ::std::{net::IpAddr, vec::Vec}; */ /// Public Key ID -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct PublicKeyID(u8); impl TryFrom<&str> for PublicKeyID { diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 6fd7901..ff8dda1 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -3,7 +3,7 @@ use ::num_traits::FromPrimitive; use super::Error; -use crate::enc::sym::Secret; +use crate::enc::{sym::Secret, Random}; /// Public key ID #[derive(Debug, Copy, Clone, PartialEq)] @@ -72,6 +72,7 @@ impl KeyKind { /// Kind of key exchange #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[non_exhaustive] #[repr(u8)] pub enum KeyExchange { /// X25519 Public key @@ -82,6 +83,23 @@ impl KeyExchange { pub fn len() -> usize { 1 } + /// Build a new keypair for key exchange + pub fn new_keypair( + &self, + rnd: &Random, + ) -> Result<(ExchangePrivKey, ExchangePubKey), Error> { + match self { + KeyExchange::X25519DiffieHellman => { + let raw_priv = ::x25519_dalek::StaticSecret::new(rnd); + let pub_key = ExchangePubKey::X25519( + ::x25519_dalek::PublicKey::from(&raw_priv), + ); + let priv_key = ExchangePrivKey::X25519(raw_priv); + Ok((priv_key, pub_key)) + } + _ => Err(Error::UnsupportedKeyExchange), + } + } } /// Kind of public key in the handshake @@ -162,6 +180,7 @@ pub enum PrivKey { /// Keys to be used only in key exchanges, not for signing Exchange(ExchangePrivKey), /// Keys to be used only for signing + // TODO: implement ed25519 Signing, } @@ -259,3 +278,8 @@ impl ExchangePubKey { } } } + +/// Build a new pair of private/public key pair +pub fn new_keypair(kind: KeyKind, rnd: &Random) -> (PrivKey, PubKey) { + todo!() +} diff --git a/src/enc/mod.rs b/src/enc/mod.rs index eda3385..4da9a0c 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -6,3 +6,69 @@ pub mod hkdf; pub mod sym; pub use errors::Error; + +use ::ring::rand::SecureRandom; + +/// wrapper where we implement whatever random traint stuff each library needs +pub struct Random { + /// actual source of randomness + rnd: ::ring::rand::SystemRandom, +} + +impl Random { + /// Build a nre Random source + pub fn new() -> Self { + Self { + rnd: ::ring::rand::SystemRandom::new(), + } + } + /// Fill a buffer with randomness + pub fn fill(&self, out: &mut [u8]) { + self.rnd.fill(out); + } +} + +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for Random { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden randomness]", f) + } +} + +// ::rand_core::{RngCore, CryptoRng} needed for ::x25519::dalek +impl ::rand_core::RngCore for &Random { + fn next_u32(&mut self) -> u32 { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 4]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = self.rnd.fill(out.assume_init_mut()); + u32::from_le_bytes(out.assume_init()) + } + } + fn next_u64(&mut self) -> u64 { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 8]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = self.rnd.fill(out.assume_init_mut()); + u64::from_le_bytes(out.assume_init()) + } + } + fn fill_bytes(&mut self, dest: &mut [u8]) { + let _ = self.rnd.fill(dest); + } + fn try_fill_bytes( + &mut self, + dest: &mut [u8], + ) -> Result<(), ::rand_core::Error> { + match self.rnd.fill(dest) { + Ok(()) => Ok(()), + Err(e) => Err(::rand_core::Error::new(e)), + } + } +} +impl ::rand_core::CryptoRng for &Random {} diff --git a/src/enc/sym.rs b/src/enc/sym.rs index e6f9e11..d9673a0 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -1,6 +1,7 @@ //! Symmetric cypher stuff use super::Error; +use crate::enc::Random; use ::zeroize::Zeroize; /// Secret, used for keys. @@ -20,8 +21,7 @@ impl ::core::fmt::Debug for Secret { impl Secret { /// New randomly generated secret - pub fn new_rand(rand: &::ring::rand::SystemRandom) -> Self { - use ::ring::rand::SecureRandom; + pub fn new_rand(rand: &Random) -> Self { let mut ret = Self([0; 32]); rand.fill(&mut ret.0); ret @@ -265,11 +265,7 @@ impl ::core::fmt::Debug for CipherSend { impl CipherSend { /// Build a new Cipher - pub fn new( - kind: CipherKind, - secret: Secret, - rand: &::ring::rand::SystemRandom, - ) -> Self { + pub fn new(kind: CipherKind, secret: Secret, rand: &Random) -> Self { Self { nonce: NonceSync::new(rand), cipher: Cipher::new(kind, secret), @@ -335,8 +331,7 @@ impl ::core::fmt::Debug for Nonce { impl Nonce { /// Generate a new random Nonce - pub fn new(rand: &::ring::rand::SystemRandom) -> Self { - use ring::rand::SecureRandom; + pub fn new(rand: &Random) -> Self { let mut raw = [0; 12]; rand.fill(&mut raw); Self { raw } @@ -376,7 +371,7 @@ pub struct NonceSync { } impl NonceSync { /// Create a new thread safe nonce - pub fn new(rand: &::ring::rand::SystemRandom) -> Self { + pub fn new(rand: &Random) -> Self { Self { nonce: ::std::sync::Mutex::new(Nonce::new(rand)), } diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 512c2f6..fcc69ea 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,7 @@ use crate::{ ConnList, Connection, IDSend, Packet, }, dnssec, - enc::{asym::PubKey, hkdf::HkdfSha3, sym::Secret}, + enc::{asym::PubKey, hkdf::HkdfSha3, sym::Secret, Random}, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; use ::std::{rc::Rc, sync::Arc, vec::Vec}; @@ -29,16 +29,17 @@ pub(crate) struct RawUdp { pub packet: Packet, } -pub(crate) enum ConnectionResult { - Failed(crate::Error), - Established((PubKey, IDSend)), -} - pub(crate) enum Work { /// ask the thread to report to the main thread the total number of /// connections present CountConnections(oneshot::Sender), - Connect((oneshot::Sender, dnssec::Record, ServiceID)), + Connect( + ( + oneshot::Sender>, + dnssec::Record, + ServiceID, + ), + ), Recv(RawUdp), } pub(crate) enum WorkAnswer { @@ -49,7 +50,7 @@ pub(crate) enum WorkAnswer { pub(crate) struct Worker { thread_id: ThreadTracker, // PERF: rand uses syscalls. how to do that async? - rand: ::ring::rand::SystemRandom, + rand: Random, stop_working: ::tokio::sync::broadcast::Receiver, token_check: Option>>, sockets: Vec, @@ -121,7 +122,7 @@ impl Worker { Ok(Self { thread_id, - rand: ::ring::rand::SystemRandom::new(), + rand: Random::new(), stop_working, token_check, sockets, @@ -132,7 +133,7 @@ impl Worker { }) } pub(crate) async fn work_loop(&mut self) { - loop { + 'mainloop: loop { let work = ::tokio::select! { _done = self.stop_working.recv() => { break; @@ -149,7 +150,51 @@ impl Worker { let conn_num = self.connections.len(); let _ = sender.send(conn_num); } - Work::Connect((send_res, dnssec_record, service_id)) => { + Work::Connect((send_res, dnssec_record, _service_id)) => { + let destination = + dnssec_record.addresses.iter().find_map(|addr| { + let maybe_key = + dnssec_record.public_keys.iter().find( + |(id, _)| addr.public_key_ids.contains(id), + ); + match maybe_key { + Some(key) => Some((addr, key)), + None => None, + } + }); + let (addr, key) = match destination { + Some((addr, key)) => (addr, key), + None => { + let _ = + send_res.send(Err(crate::Error::Resolution( + "No selectable address and key combination" + .to_owned(), + ))); + continue 'mainloop; + } + }; + use crate::enc::asym; + let exchange = asym::KeyExchange::X25519DiffieHellman; + let (priv_key, pub_key) = + match exchange.new_keypair(&self.rand) { + Ok(pair) => pair, + Err(_) => todo!(), + }; + // build request + /* + let req = dirsync::Req { + key_id: key.0, + exchange: exchange, + cipher: 42, + exchange_key: client_pub_key, + data: 42, + }; + */ + + // start timeout + + // send packet + todo!() } //TODO: reconf message to add channels diff --git a/src/lib.rs b/src/lib.rs index 6f48ec5..d09d39e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -59,6 +59,9 @@ pub enum Error { /// Key error #[error("key: {0:?}")] Key(#[from] crate::enc::Error), + /// Resolution problems. wrong or incomplete DNSSEC data + #[error("DNSSEC resolution: {0}")] + Resolution(String), } /// Instance of a fenrir endpoint @@ -199,7 +202,7 @@ impl Fenrir { // we very likely have multiple threads, pinned to different cpus. // use the ConnectionID to send the same connection // to the same thread. - // Handshakes have conenction ID 0, so we use the sender's UDP port + // Handshakes have connection ID 0, so we use the sender's UDP port let packet = match Packet::deserialize_id(&data) { Ok(packet) => packet, @@ -266,7 +269,7 @@ impl Fenrir { continue; } Reservation::Reserved => break, - Reservation::Present(id_send) => { + Reservation::Present(_id_send) => { //TODO: reuse connection todo!() } @@ -275,7 +278,6 @@ impl Fenrir { // Spot reserved for the connection // find the thread with less connections - let th_num = self._thread_work.len(); let mut conn_count = Vec::::with_capacity(th_num); let mut wait_res = @@ -313,15 +315,14 @@ impl Fenrir { match recv.await { Ok(res) => { - use crate::inner::worker::ConnectionResult; match res { - ConnectionResult::Failed(e) => { + Err(e) => { let mut conn_auth_lock = self.conn_auth_srv.lock().await; conn_auth_lock.remove_reserved(&resolved); Err(e) } - ConnectionResult::Established((pubkey, id_send)) => { + Ok((pubkey, id_send)) => { let mut conn_auth_lock = self.conn_auth_srv.lock().await; conn_auth_lock.add(&pubkey, id_send, &resolved); From 1bae4c9953e2fc4a89dda8bb8226da7729d39636 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Tue, 30 May 2023 13:47:08 +0200 Subject: [PATCH 15/34] DNSSEC: add ciphers/key exchanges/hkdfs Signed-off-by: Luca Fulchir --- src/connection/mod.rs | 6 +- src/dnssec/record.rs | 259 +++++++++++++++++++++++++++++++++--------- src/enc/hkdf.rs | 48 +++++++- src/inner/mod.rs | 6 +- src/inner/worker.rs | 10 +- 5 files changed, 262 insertions(+), 67 deletions(-) diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 2f8e3c9..7c1d733 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -15,7 +15,7 @@ use crate::{ dnssec, enc::{ asym::PubKey, - hkdf::HkdfSha3, + hkdf::Hkdf, sym::{CipherKind, CipherRecv, CipherSend}, Random, }, @@ -55,7 +55,7 @@ pub struct Connection { /// Sending Connection ID pub id_send: IDSend, /// The main hkdf used for all secrets in this connection - pub hkdf: HkdfSha3, + pub hkdf: Hkdf, /// Cipher for decrypting data pub cipher_recv: CipherRecv, /// Cipher for encrypting data @@ -76,7 +76,7 @@ pub enum Role { impl Connection { pub(crate) fn new( - hkdf: HkdfSha3, + hkdf: Hkdf, cipher: CipherKind, role: Role, rand: &Random, diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 48e8edd..51f7fd3 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -3,27 +3,51 @@ //! //! Encoding and decoding in base85, RFC1924 //! -//! //! Basic encoding idea: -//! * 1 byte: half-bytes +//! * 1 byte: divided in two: //! * half: num of addresses //! * half: num of pubkeys +//! * 1 byte: divided in half: +//! * half: number of key exchanges +//! * half: number of Hkdfs +//! * 1 byte: divided in half: +//! * half: number of ciphers +//! * half: nothing //! [ # list of addresses //! * 1 byte: bitfield //! * 0..1 ipv4/ipv6 //! * 2..4 priority (for failover) //! * 5..7 weight between priority -//! * 1 byte: public key id +//! * 1 byte: divided in half: +//! * half: num of public key ids +//! * half: num of handhskae ids //! * 2 bytes: UDP port +//! * [ 1 byte per public key id ] +//! * [ 1 byte per handshake id ] //! * X bytes: IP //! ] //! [ # list of pubkeys //! * 1 byte: pubkey id +//! * 1 byte: pubkey length //! * 1 byte: pubkey type //! * Y bytes: pubkey //! ] +//! [ # list of supported key exchanges +//! * 1 byte for each cipher +//! ] +//! [ # list of supported HDKFs +//! * 1 byte for each hkdf +//! ] +//! [ # list of supported ciphers +//! * 1 byte for each cipher +//! ] -use crate::enc::{self, asym::PubKey}; +use crate::enc::{ + self, + asym::{KeyExchange, PubKey}, + hkdf::HkdfKind, + sym::CipherKind, +}; use ::core::num::NonZeroU16; use ::num_traits::FromPrimitive; use ::std::{net::IpAddr, vec::Vec}; @@ -220,6 +244,10 @@ impl Address { bitfield |= self.weight as u8; raw.push(bitfield); + let len_combined: u8 = self.public_key_ids.len() as u8; + let len_combined = len_combined << 4; + let len_combined = len_combined | self.handshake_ids.len() as u8; + raw.push(len_combined); raw.extend_from_slice( &(match self.port { @@ -228,10 +256,12 @@ impl Address { }), ); - raw.push(self.public_key_ids.len() as u8); for id in self.public_key_ids.iter() { raw.push(id.0); } + for id in self.handshake_ids.iter() { + raw.push(*id as u8); + } match self.ip { IpAddr::V4(ip) => { @@ -250,19 +280,11 @@ impl Address { } let ip_type = raw[0] >> 6; let is_ipv6: bool; - let total_length: usize; match ip_type { 0 => { is_ipv6 = false; - total_length = 8; - } - 1 => { - total_length = 20; - if raw.len() < total_length { - return Err(Error::NotEnoughData(1)); - } - is_ipv6 = true } + 1 => is_ipv6 = true, _ => return Err(Error::UnsupportedData(0)), } let raw_priority = (raw[0] << 2) >> 5; @@ -270,28 +292,33 @@ impl Address { let priority = AddressPriority::from_u8(raw_priority).unwrap(); let weight = AddressWeight::from_u8(raw_weight).unwrap(); + // UDP port let raw_port = u16::from_le_bytes([raw[1], raw[2]]); + let port = if raw_port == 0 { + None + } else { + Some(NonZeroU16::new(raw_port).unwrap()) + }; // Add publickey ids - let num_pubkey_ids = raw[3] as usize; - if raw.len() < 3 + num_pubkey_ids { + let num_pubkey_ids = (raw[3] >> 4) as usize; + let num_handshake_ids = (raw[3] & 0x0F) as usize; + if raw.len() <= 3 + num_pubkey_ids + num_handshake_ids { return Err(Error::NotEnoughData(3)); } + let mut bytes_parsed = 4; let mut public_key_ids = Vec::with_capacity(num_pubkey_ids); - - for raw_pubkey_id in raw[4..num_pubkey_ids].iter() { + for raw_pubkey_id in + raw[bytes_parsed..(bytes_parsed + num_pubkey_ids)].iter() + { public_key_ids.push(PublicKeyID(*raw_pubkey_id)); } // add handshake ids - let next_ptr = 3 + num_pubkey_ids; - let num_handshake_ids = raw[next_ptr] as usize; - if raw.len() < next_ptr + num_handshake_ids { - return Err(Error::NotEnoughData(next_ptr)); - } + bytes_parsed = bytes_parsed + num_pubkey_ids; let mut handshake_ids = Vec::with_capacity(num_handshake_ids); for raw_handshake_id in - raw[next_ptr..(next_ptr + num_pubkey_ids)].iter() + raw[bytes_parsed..(bytes_parsed + num_handshake_ids)].iter() { match HandshakeID::from_u8(*raw_handshake_id) { Some(h_id) => handshake_ids.push(h_id), @@ -304,26 +331,24 @@ impl Address { } } } - let next_ptr = next_ptr + num_pubkey_ids; + bytes_parsed = bytes_parsed + num_handshake_ids; - let port = if raw_port == 0 { - None - } else { - Some(NonZeroU16::new(raw_port).unwrap()) - }; let ip = if is_ipv6 { - let ip_end = next_ptr + 16; + let ip_end = bytes_parsed + 16; if raw.len() < ip_end { - return Err(Error::NotEnoughData(next_ptr)); + return Err(Error::NotEnoughData(bytes_parsed)); } - let raw_ip: [u8; 16] = raw[next_ptr..ip_end].try_into().unwrap(); + let raw_ip: [u8; 16] = + raw[bytes_parsed..ip_end].try_into().unwrap(); + bytes_parsed = bytes_parsed + 16; IpAddr::from(raw_ip) } else { - let ip_end = next_ptr + 4; + let ip_end = bytes_parsed + 4; if raw.len() < ip_end { - return Err(Error::NotEnoughData(next_ptr)); + return Err(Error::NotEnoughData(bytes_parsed)); } - let raw_ip: [u8; 4] = raw[next_ptr..ip_end].try_into().unwrap(); + let raw_ip: [u8; 4] = raw[bytes_parsed..ip_end].try_into().unwrap(); + bytes_parsed = bytes_parsed + 4; IpAddr::from(raw_ip) }; @@ -336,7 +361,7 @@ impl Address { public_key_ids, handshake_ids, }, - total_length, + bytes_parsed, )) } } @@ -353,6 +378,12 @@ pub struct Record { /// List of all authentication servers' addresses. /// Multiple ones can point to the same authentication server pub addresses: Vec
, + /// List of supported key exchanges + pub key_exchanges: Vec, + /// List of supported key exchanges + pub hkdfs: Vec, + /// List of supported ciphers + pub ciphers: Vec, } impl Record { @@ -371,15 +402,27 @@ impl Record { if self.addresses.len() > 16 { return Err(Error::Max16Addresses); } + if self.key_exchanges.len() > 16 { + return Err(Error::Max16KeyExchanges); + } + if self.hkdfs.len() > 16 { + return Err(Error::Max16Hkdfs); + } + if self.ciphers.len() > 16 { + return Err(Error::Max16Ciphers); + } // everything else is all good - let total_size: usize = 1 + let total_size: usize = 3 + self.addresses.iter().map(|a| a.raw_len()).sum::() + self .public_keys .iter() - .map(|(_, key)| 1 + key.kind().pub_len()) - .sum::(); + .map(|(_, key)| 3 + key.kind().pub_len()) + .sum::() + + self.key_exchanges.len() + + self.hkdfs.len() + + self.ciphers.len(); let mut raw = Vec::with_capacity(total_size); @@ -387,33 +430,56 @@ impl Record { let len_combined: u8 = self.addresses.len() as u8; let len_combined = len_combined << 4; let len_combined = len_combined | self.public_keys.len() as u8; - raw.push(len_combined); + // number of key exchanges and hkdfs + let len_combined: u8 = self.key_exchanges.len() as u8; + let len_combined = len_combined << 4; + let len_combined = len_combined | self.hkdfs.len() as u8; + raw.push(len_combined); + let num_of_ciphers: u8 = (self.ciphers.len() as u8) << 4; + raw.push(num_of_ciphers); for address in self.addresses.iter() { address.encode_into(&mut raw); } for (public_key_id, public_key) in self.public_keys.iter() { raw.push(public_key_id.0); + raw.push(public_key.kind().pub_len() as u8); + raw.push(public_key.kind() as u8); public_key.serialize_into(&mut raw); } + for k_x in self.key_exchanges.iter() { + raw.push(*k_x as u8); + } + for h in self.hkdfs.iter() { + raw.push(*h as u8); + } + for c in self.ciphers.iter() { + raw.push(*c as u8); + } Ok(::base85::encode(&raw)) } /// Decode from base85 to the actual object pub fn decode(raw: &[u8]) -> Result { - // bare minimum for 1 address and key - const MIN_RAW_LENGTH: usize = 1 + 8 + 8; + // bare minimum for 1 address, 1 key, 1 key exchange and 1 cipher + const MIN_RAW_LENGTH: usize = 1 + 1 + 1 + 8 + 9 + 1 + 1; if raw.len() <= MIN_RAW_LENGTH { return Err(Error::NotEnoughData(0)); } let mut num_addresses = (raw[0] >> 4) as usize; let mut num_public_keys = (raw[0] & 0x0F) as usize; - let mut bytes_parsed = 1; + let mut num_key_exchanges = (raw[1] >> 4) as usize; + let mut num_hkdfs = (raw[1] & 0x0F) as usize; + let mut num_ciphers = (raw[2] >> 4) as usize; + let mut bytes_parsed = 3; let mut result = Self { addresses: Vec::with_capacity(num_addresses), public_keys: Vec::with_capacity(num_public_keys), + key_exchanges: Vec::with_capacity(num_key_exchanges), + hkdfs: Vec::with_capacity(num_hkdfs), + ciphers: Vec::with_capacity(num_ciphers), }; while num_addresses > 0 { @@ -433,23 +499,97 @@ impl Record { num_addresses = num_addresses - 1; } while num_public_keys > 0 { + if bytes_parsed + 2 >= raw.len() { + return Err(Error::NotEnoughData(bytes_parsed)); + } let id = PublicKeyID(raw[bytes_parsed]); bytes_parsed = bytes_parsed + 1; - let (public_key, bytes) = - match PubKey::deserialize(&raw[bytes_parsed..]) { - Ok(public_key) => public_key, - Err(enc::Error::UnsupportedKey(b)) => { - return Err(Error::UnsupportedData(bytes_parsed + b)) - } - Err(enc::Error::NotEnoughData(b)) => { - return Err(Error::NotEnoughData(bytes_parsed + b)) - } - _ => return Err(Error::UnknownData(bytes_parsed)), - }; + let pubkey_length = raw[bytes_parsed] as usize; + bytes_parsed = bytes_parsed + 1; + if pubkey_length + bytes_parsed >= raw.len() { + return Err(Error::NotEnoughData(bytes_parsed)); + } + let (public_key, bytes) = match PubKey::deserialize( + &raw[bytes_parsed..(bytes_parsed + pubkey_length)], + ) { + Ok(public_key_and_bytes) => public_key_and_bytes, + Err(enc::Error::UnsupportedKey(_)) => { + // continue parsing. This could be a new pubkey type + // that is not supported by an older client + ::tracing::warn!("Unsupported public key type"); + bytes_parsed = bytes_parsed + pubkey_length; + continue; + } + Err(_) => { + return Err(Error::UnsupportedData(bytes_parsed)); + } + }; + if bytes != 1 + pubkey_length { + return Err(Error::UnsupportedData(bytes_parsed)); + } bytes_parsed = bytes_parsed + bytes; result.public_keys.push((id, public_key)); num_public_keys = num_public_keys - 1; } + if bytes_parsed + num_key_exchanges + num_hkdfs + num_ciphers + != raw.len() + { + return Err(Error::NotEnoughData(bytes_parsed)); + } + while num_key_exchanges > 0 { + let key_exchange = match KeyExchange::from_u8(raw[bytes_parsed]) { + Some(key_exchange) => key_exchange, + None => { + // continue parsing. This could be a new key exchange type + // that is not supported by an older client + ::tracing::warn!( + "Unknown Key exchange {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.key_exchanges.push(key_exchange); + num_key_exchanges = num_key_exchanges - 1; + } + while num_hkdfs > 0 { + let hkdf = match HkdfKind::from_u8(raw[bytes_parsed]) { + Some(hkdf) => hkdf, + None => { + // continue parsing. This could be a new hkdf type + // that is not supported by an older client + ::tracing::warn!( + "Unknown hkdf {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.hkdfs.push(hkdf); + num_hkdfs = num_hkdfs - 1; + } + while num_ciphers > 0 { + let cipher = match CipherKind::from_u8(raw[bytes_parsed]) { + Some(cipher) => cipher, + None => { + // continue parsing. This could be a new cipher type + // that is not supported by an older client + ::tracing::warn!( + "Unknown Cipher {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.ciphers.push(cipher); + num_ciphers = num_ciphers - 1; + } if bytes_parsed != raw.len() { Err(Error::UnknownData(bytes_parsed)) } else { @@ -470,6 +610,15 @@ pub enum Error { /// Too many addresses (max 16) #[error("can't encode more than 16 addresses")] Max16Addresses, + /// Too many key exchanges (max 16) + #[error("can't encode more than 16 key exchanges")] + Max16KeyExchanges, + /// Too many Hkdfs (max 16) + #[error("can't encode more than 16 Hkdfs")] + Max16Hkdfs, + /// Too many ciphers (max 16) + #[error("can't encode more than 16 Ciphers")] + Max16Ciphers, /// We need at least one public key #[error("no public keys found")] NoPublicKeyFound, diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index 15d7eca..a4b6868 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -1,12 +1,52 @@ //! Hash-based Key Derivation Function //! We just repackage other crates -use ::hkdf::Hkdf; use ::sha3::Sha3_256; use ::zeroize::Zeroize; use crate::enc::sym::Secret; +/// Kind of HKDF +#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[non_exhaustive] +#[repr(u8)] +pub enum HkdfKind { + /// Sha3 + Sha3 = 0, +} + +/// Generic wrapper on Hkdfs +#[derive(Clone)] +pub enum Hkdf { + /// Sha3 based + Sha3(HkdfSha3), +} + +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for Hkdf { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden hkdf]", f) + } +} + +impl Hkdf { + /// New Hkdf + pub fn new(kind: HkdfKind, salt: &[u8], key: Secret) -> Self { + match kind { + HkdfKind::Sha3 => Self::Sha3(HkdfSha3::new(salt, key)), + } + } + /// Get a secret generated from the key and a given context + pub fn get_secret(&self, context: &[u8]) -> Secret { + match self { + Hkdf::Sha3(sha3) => sha3.get_secret(context), + } + } +} + // Hack & tricks: // HKDF are pretty important, but this lib don't zero out the data. // we can't use #[derive(Zeroing)] either. @@ -14,10 +54,10 @@ use crate::enc::sym::Secret; #[derive(Zeroize)] #[zeroize(drop)] -struct Zeroable([u8; ::core::mem::size_of::>()]); +struct Zeroable([u8; ::core::mem::size_of::<::hkdf::Hkdf>()]); union HkdfInner { - hkdf: ::core::mem::ManuallyDrop>, + hkdf: ::core::mem::ManuallyDrop<::hkdf::Hkdf>, zeroable: ::core::mem::ManuallyDrop, } @@ -50,7 +90,7 @@ pub struct HkdfSha3 { impl HkdfSha3 { /// Instantiate a new HKDF with Sha3-256 pub fn new(salt: &[u8], key: Secret) -> Self { - let hkdf = Hkdf::::new(Some(salt), key.as_ref()); + let hkdf = ::hkdf::Hkdf::::new(Some(salt), key.as_ref()); Self { inner: HkdfInner { hkdf: ::core::mem::ManuallyDrop::new(hkdf), diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 9ee6942..b782258 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -13,7 +13,7 @@ use crate::{ }, enc::{ self, asym, - hkdf::HkdfSha3, + hkdf::{Hkdf, HkdfKind}, sym::{CipherKind, CipherRecv}, }, Error, @@ -26,7 +26,7 @@ pub(crate) struct AuthNeededInfo { /// Parsed handshake packet pub handshake: Handshake, /// hkdf generated from the handshake - pub hkdf: HkdfSha3, + pub hkdf: Hkdf, /// cipher to be used in both directions pub cipher: CipherKind, } @@ -149,7 +149,7 @@ impl HandshakeTracker { Ok(shared_key) => shared_key, Err(e) => return Err(handshake::Error::Key(e).into()), }; - let hkdf = HkdfSha3::new(b"fenrir", shared_key); + let hkdf = Hkdf::new(HkdfKind::Sha3, b"fenrir", shared_key); let secret_recv = hkdf.get_secret(b"to_server"); let cipher_recv = CipherRecv::new(req.cipher, secret_recv); use crate::enc::sym::AAD; diff --git a/src/inner/worker.rs b/src/inner/worker.rs index fcc69ea..1e34b56 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,12 @@ use crate::{ ConnList, Connection, IDSend, Packet, }, dnssec, - enc::{asym::PubKey, hkdf::HkdfSha3, sym::Secret, Random}, + enc::{ + asym::PubKey, + hkdf::{Hkdf, HkdfKind}, + sym::Secret, + Random, + }, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; use ::std::{rc::Rc, sync::Arc, vec::Vec}; @@ -381,7 +386,8 @@ impl Worker { // SECURITY: //FIXME: the Secret should be XORed with the client stored // secret (if any) - let hkdf = HkdfSha3::new( + let hkdf = Hkdf::new( + HkdfKind::Sha3, cci.service_id.as_bytes(), resp_data.service_key, ); From ac213a6528dc65ab6f857612946217f62b21db4d Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 1 Jun 2023 11:41:10 +0200 Subject: [PATCH 16/34] More work on key exhcnage negotiation Signed-off-by: Luca Fulchir --- src/config/mod.rs | 16 +++ src/connection/handshake/dirsync.rs | 52 ++++++++-- src/connection/handshake/mod.rs | 31 ++++++ src/dnssec/record.rs | 145 +++++++++++++--------------- src/enc/asym.rs | 43 ++++++++- src/enc/hkdf.rs | 40 +++++++- src/enc/sym.rs | 29 +++++- src/inner/worker.rs | 97 ++++++++++++++++--- src/lib.rs | 5 + 9 files changed, 350 insertions(+), 108 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index e3fd4c8..3e489b9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,6 +1,10 @@ //! //! Configuration to initialize the Fenrir networking library +use crate::{ + connection::handshake::HandshakeID, + enc::{asym::KeyExchange, hkdf::HkdfKind, sym::CipherKind}, +}; use ::std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, num::NonZeroUsize, @@ -18,6 +22,14 @@ pub struct Config { pub listen: Vec, /// List of DNS resolvers to use pub resolvers: Vec, + /// Supported handshakes + pub handshakes: Vec, + /// Supported key exchanges + pub key_exchanges: Vec, + /// Supported Hkdfs + pub hkdfs: Vec, + /// Supported Ciphers + pub ciphers: Vec, } impl Default for Config { @@ -34,6 +46,10 @@ impl Default for Config { ), ], resolvers: Vec::new(), + handshakes: [HandshakeID::DirectorySynchronized].to_vec(), + key_exchanges: [KeyExchange::X25519DiffieHellman].to_vec(), + hkdfs: [HkdfKind::Sha3].to_vec(), + ciphers: [CipherKind::XChaCha20Poly1305].to_vec(), } } } diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 775f301..02bb097 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -14,13 +14,41 @@ use crate::{ connection::{ProtocolVersion, ID}, enc::{ asym::{ExchangePubKey, KeyExchange, KeyID}, + hkdf::HkdfKind, sym::{CipherKind, HeadLen, Secret, TagLen}, + Random, }, }; use ::arrayref::array_mut_ref; -type Nonce = [u8; 16]; +// TODO: merge with crate::enc::sym::Nonce +/// random nonce +#[derive(Debug, Clone, Copy)] +pub struct Nonce([u8; 16]); + +impl Nonce { + /// Create a new random Nonce + pub fn new(rnd: &Random) -> Self { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 16]>; + #[allow(unsafe_code)] + unsafe { + out = MaybeUninit::uninit(); + let _ = rnd.fill(out.assume_init_mut()); + Self(out.assume_init()) + } + } + /// Length of the serialized Nonce + pub const fn len() -> usize { + 16 + } +} +impl From<&[u8; 16]> for Nonce { + fn from(raw: &[u8; 16]) -> Self { + Self(raw.clone()) + } +} /// Parsed handshake #[derive(Debug, Clone)] @@ -61,6 +89,8 @@ pub struct Req { pub key_id: KeyID, /// Selected key exchange pub exchange: KeyExchange, + /// Selected hkdf + pub hkdf: HkdfKind, /// Selected cipher pub cipher: CipherKind, /// Client ephemeral public key used for key exchanges @@ -77,6 +107,7 @@ impl Req { ProtocolVersion::len() + KeyID::len() + KeyExchange::len() + + HkdfKind::len() + CipherKind::len() + self.exchange_key.kind().pub_len() } @@ -91,6 +122,7 @@ impl Req { pub fn len(&self) -> usize { KeyID::len() + KeyExchange::len() + + HkdfKind::len() + CipherKind::len() + self.exchange_key.kind().pub_len() + self.data.len() @@ -121,18 +153,23 @@ impl super::HandshakeParsing for Req { Some(exchange) => exchange, None => return Err(Error::Parsing), }; - let cipher: CipherKind = match CipherKind::from_u8(raw[3]) { + let hkdf: HkdfKind = match HkdfKind::from_u8(raw[3]) { + Some(exchange) => exchange, + None => return Err(Error::Parsing), + }; + let cipher: CipherKind = match CipherKind::from_u8(raw[4]) { Some(cipher) => cipher, None => return Err(Error::Parsing), }; - let (exchange_key, len) = match ExchangePubKey::from_slice(&raw[4..]) { + let (exchange_key, len) = match ExchangePubKey::from_slice(&raw[5..]) { Ok(exchange_key) => exchange_key, Err(e) => return Err(e.into()), }; - let data = ReqInner::CipherText(raw.len() - (4 + len)); + let data = ReqInner::CipherText(raw.len() - (5 + len)); Ok(HandshakeData::DirSync(DirSync::Req(Self { key_id, exchange, + hkdf, cipher, exchange_key, data, @@ -253,7 +290,7 @@ pub struct ReqData { impl ReqData { /// actual length of the request data pub fn len(&self) -> usize { - self.nonce.len() + KeyID::len() + ID::len() + self.auth.len() + Nonce::len() + KeyID::len() + ID::len() + self.auth.len() } /// Minimum byte length of the request data pub const MIN_PKT_LEN: usize = @@ -265,7 +302,8 @@ impl ReqData { } let mut start = 0; let mut end = 16; - let nonce: Nonce = raw[start..end].try_into().unwrap(); + let raw_sized: &[u8; 16] = raw[start..end].try_into().unwrap(); + let nonce: Nonce = raw_sized.into(); start = end; end = end + KeyID::len(); let client_key_id = @@ -440,7 +478,7 @@ impl RespData { assert!(out.len() == Self::len(), "wrong buffer size"); let mut start = 0; let mut end = Self::NONCE_LEN; - out[start..end].copy_from_slice(&self.client_nonce); + out[start..end].copy_from_slice(&self.client_nonce.0); start = end; end = end + Self::NONCE_LEN; self.id.serialize(&mut out[start..end]); diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 76eb353..778a003 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -27,6 +27,37 @@ pub enum Error { NotEnoughData, } +/// List of possible handshakes +#[derive(::num_derive::FromPrimitive, Debug, Clone, Copy, PartialEq)] +#[repr(u8)] +pub enum HandshakeID { + /// 1-RTT Directory synchronized handshake. Fast, no forward secrecy + DirectorySynchronized = 0, + /// 2-RTT Stateful exchange. Little DDos protection + Stateful, + /// 3-RTT stateless exchange. Forward secrecy and ddos protection + Stateless, +} + +impl TryFrom<&str> for HandshakeID { + type Error = ::std::io::Error; + // TODO: from actual names, not only numeric + fn try_from(raw: &str) -> Result { + if let Ok(handshake_u8) = raw.parse::() { + if handshake_u8 >= 1 { + if let Some(handshake) = HandshakeID::from_u8(handshake_u8 - 1) + { + return Ok(handshake); + } + } + } + return Err(::std::io::Error::new( + ::std::io::ErrorKind::InvalidData, + "Unknown handshake ID", + )); + } +} + pub(crate) struct HandshakeServer { pub id: crate::enc::asym::KeyID, pub key: crate::enc::asym::PrivKey, diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 51f7fd3..2457c0b 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -19,15 +19,15 @@ //! * 2..4 priority (for failover) //! * 5..7 weight between priority //! * 1 byte: divided in half: -//! * half: num of public key ids -//! * half: num of handhskae ids +//! * half: num of public key indexes +//! * half: num of handshake ids //! * 2 bytes: UDP port -//! * [ 1 byte per public key id ] +//! * [ HALF byte per public key idx ] (index on the list of public keys) //! * [ 1 byte per handshake id ] //! * X bytes: IP //! ] -//! [ # list of pubkeys -//! * 1 byte: pubkey id +//! [ # list of pubkeys (max: 16) +//! * 2 byte: pubkey id //! * 1 byte: pubkey length //! * 1 byte: pubkey type //! * Y bytes: pubkey @@ -42,11 +42,14 @@ //! * 1 byte for each cipher //! ] -use crate::enc::{ - self, - asym::{KeyExchange, PubKey}, - hkdf::HkdfKind, - sym::CipherKind, +use crate::{ + connection::handshake::HandshakeID, + enc::{ + self, + asym::{KeyExchange, KeyID, PubKey}, + hkdf::HkdfKind, + sym::CipherKind, + }, }; use ::core::num::NonZeroU16; use ::num_traits::FromPrimitive; @@ -55,22 +58,12 @@ use ::std::{net::IpAddr, vec::Vec}; * Public key data */ -/// Public Key ID +/// Public Key Index. +/// this points to the index of the public key in the array of public keys. +/// needed to have one byte less in the list of public keys +/// supported by and address #[derive(Debug, Copy, Clone, PartialEq)] -pub struct PublicKeyID(u8); - -impl TryFrom<&str> for PublicKeyID { - type Error = ::std::io::Error; - fn try_from(raw: &str) -> Result { - if let Ok(id_u8) = raw.parse::() { - return Ok(PublicKeyID(id_u8)); - } - return Err(::std::io::Error::new( - ::std::io::ErrorKind::InvalidData, - "Public Key ID must be between 0 and 256", - )); - } -} +pub struct PubKeyIdx(pub u8); /* * Address data @@ -168,37 +161,6 @@ impl TryFrom<&str> for AddressWeight { } } -/// List of possible handshakes -#[derive(::num_derive::FromPrimitive, Debug, Clone, Copy)] -#[repr(u8)] -pub enum HandshakeID { - /// 1-RTT Directory synchronized handshake. Fast, no forward secrecy - DirectorySynchronized = 0, - /// 2-RTT Stateful exchange. Little DDos protection - Stateful, - /// 3-RTT stateless exchange. Forward secrecy and ddos protection - Stateless, -} - -impl TryFrom<&str> for HandshakeID { - type Error = ::std::io::Error; - // TODO: from actual names, not only numeric - fn try_from(raw: &str) -> Result { - if let Ok(handshake_u8) = raw.parse::() { - if handshake_u8 >= 1 { - if let Some(handshake) = HandshakeID::from_u8(handshake_u8 - 1) - { - return Ok(handshake); - } - } - } - return Err(::std::io::Error::new( - ::std::io::ErrorKind::InvalidData, - "Unknown handshake ID", - )); - } -} - /// Authentication server address information: /// * ip /// * udp port @@ -220,14 +182,16 @@ pub struct Address { /// List of supported handshakes pub handshake_ids: Vec, /// Public key IDs used by this address - pub public_key_ids: Vec, + pub public_key_idx: Vec, } impl Address { fn raw_len(&self) -> usize { // UDP port + Priority + Weight + pubkey_len + handshake_len let mut size = 6; - size = size + self.public_key_ids.len() + self.handshake_ids.len(); + let num_pubkey_idx = self.public_key_idx.len(); + let idx_bytes = (num_pubkey_idx / 2) + (num_pubkey_idx % 2); + size = size + idx_bytes + self.handshake_ids.len(); size + match self.ip { IpAddr::V4(_) => size + 4, IpAddr::V6(_) => size + 16, @@ -244,7 +208,7 @@ impl Address { bitfield |= self.weight as u8; raw.push(bitfield); - let len_combined: u8 = self.public_key_ids.len() as u8; + let len_combined: u8 = self.public_key_idx.len() as u8; let len_combined = len_combined << 4; let len_combined = len_combined | self.handshake_ids.len() as u8; raw.push(len_combined); @@ -256,8 +220,18 @@ impl Address { }), ); - for id in self.public_key_ids.iter() { - raw.push(id.0); + // pair every idx, since the max is 16 + for chunk in self.public_key_idx.chunks(2) { + let second = { + if chunk.len() == 2 { + chunk[1].0 + } else { + 0 + } + }; + let tmp = chunk[0].0 << 4; + let tmp = tmp | second; + raw.push(tmp); } for id in self.handshake_ids.iter() { raw.push(*id as u8); @@ -301,21 +275,29 @@ impl Address { }; // Add publickey ids - let num_pubkey_ids = (raw[3] >> 4) as usize; + let num_pubkey_idx = (raw[3] >> 4) as usize; let num_handshake_ids = (raw[3] & 0x0F) as usize; - if raw.len() <= 3 + num_pubkey_ids + num_handshake_ids { + if raw.len() <= 3 + num_pubkey_idx + num_handshake_ids { return Err(Error::NotEnoughData(3)); } let mut bytes_parsed = 4; - let mut public_key_ids = Vec::with_capacity(num_pubkey_ids); - for raw_pubkey_id in - raw[bytes_parsed..(bytes_parsed + num_pubkey_ids)].iter() + let mut public_key_idx = Vec::with_capacity(num_pubkey_idx); + let idx_bytes = (num_pubkey_idx / 2) + (num_pubkey_idx % 2); + let mut idx_added = 0; + for raw_pubkey_idx_pair in + raw[bytes_parsed..(bytes_parsed + idx_bytes)].iter() { - public_key_ids.push(PublicKeyID(*raw_pubkey_id)); + let first = PubKeyIdx(raw_pubkey_idx_pair >> 4); + let second = PubKeyIdx(raw_pubkey_idx_pair & 0x0F); + public_key_idx.push(first); + if num_pubkey_idx - idx_added >= 2 { + public_key_idx.push(second); + } + idx_added = idx_added + 2; } // add handshake ids - bytes_parsed = bytes_parsed + num_pubkey_ids; + bytes_parsed = bytes_parsed + idx_bytes; let mut handshake_ids = Vec::with_capacity(num_handshake_ids); for raw_handshake_id in raw[bytes_parsed..(bytes_parsed + num_handshake_ids)].iter() @@ -358,7 +340,7 @@ impl Address { port, priority, weight, - public_key_ids, + public_key_idx, handshake_ids, }, bytes_parsed, @@ -374,7 +356,7 @@ impl Address { #[derive(Debug, Clone)] pub struct Record { /// Public keys used by any authentication server - pub public_keys: Vec<(PublicKeyID, PubKey)>, + pub public_keys: Vec<(KeyID, PubKey)>, /// List of all authentication servers' addresses. /// Multiple ones can point to the same authentication server pub addresses: Vec
, @@ -443,7 +425,8 @@ impl Record { address.encode_into(&mut raw); } for (public_key_id, public_key) in self.public_keys.iter() { - raw.push(public_key_id.0); + let key_id_bytes = public_key_id.0.to_le_bytes(); + raw.extend_from_slice(&key_id_bytes); raw.push(public_key.kind().pub_len() as u8); raw.push(public_key.kind() as u8); public_key.serialize_into(&mut raw); @@ -462,8 +445,8 @@ impl Record { } /// Decode from base85 to the actual object pub fn decode(raw: &[u8]) -> Result { - // bare minimum for 1 address, 1 key, 1 key exchange and 1 cipher - const MIN_RAW_LENGTH: usize = 1 + 1 + 1 + 8 + 9 + 1 + 1; + // bare minimum for lengths, (1 address), (1 key), cipher negotiation + const MIN_RAW_LENGTH: usize = 3 + (6 + 4) + (4 + 32) + 1 + 1 + 1; if raw.len() <= MIN_RAW_LENGTH { return Err(Error::NotEnoughData(0)); } @@ -499,11 +482,14 @@ impl Record { num_addresses = num_addresses - 1; } while num_public_keys > 0 { - if bytes_parsed + 2 >= raw.len() { + if bytes_parsed + 3 >= raw.len() { return Err(Error::NotEnoughData(bytes_parsed)); } - let id = PublicKeyID(raw[bytes_parsed]); - bytes_parsed = bytes_parsed + 1; + + let raw_key_id = + u16::from_le_bytes([raw[bytes_parsed], raw[bytes_parsed + 1]]); + let id = KeyID(raw_key_id); + bytes_parsed = bytes_parsed + 2; let pubkey_length = raw[bytes_parsed] as usize; bytes_parsed = bytes_parsed + 1; if pubkey_length + bytes_parsed >= raw.len() { @@ -590,6 +576,13 @@ impl Record { result.ciphers.push(cipher); num_ciphers = num_ciphers - 1; } + for addr in result.addresses.iter() { + for idx in addr.public_key_idx.iter() { + if idx.0 as usize >= result.public_keys.len() { + return Err(Error::Max16PublicKeys); + } + } + } if bytes_parsed != raw.len() { Err(Error::UnknownData(bytes_parsed)) } else { diff --git a/src/enc/asym.rs b/src/enc/asym.rs index ff8dda1..49f530c 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -3,7 +3,10 @@ use ::num_traits::FromPrimitive; use super::Error; -use crate::enc::{sym::Secret, Random}; +use crate::{ + config::Config, + enc::{sym::Secret, Random}, +}; /// Public key ID #[derive(Debug, Copy, Clone, PartialEq)] @@ -68,8 +71,19 @@ impl KeyKind { KeyKind::X25519 => KeyCapabilities::Exchange, } } + /// Returns the key exchanges supported by this key + pub fn key_exchanges(&self) -> &'static [KeyExchange] { + const EMPTY: [KeyExchange; 0] = []; + const X25519_KEY_EXCHANGES: [KeyExchange; 1] = + [KeyExchange::X25519DiffieHellman]; + match self { + KeyKind::Ed25519 => &EMPTY, + KeyKind::X25519 => &X25519_KEY_EXCHANGES, + } + } } +// FIXME: rename in KeyExchangeKind /// Kind of key exchange #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] #[non_exhaustive] @@ -279,7 +293,28 @@ impl ExchangePubKey { } } -/// Build a new pair of private/public key pair -pub fn new_keypair(kind: KeyKind, rnd: &Random) -> (PrivKey, PubKey) { - todo!() +/// Select the best key exchange from our supported list +/// and the other endpoint supported list. +/// Give priority to our list +pub fn server_select_key_exchange( + cfg: &Config, + client_supported: &Vec, +) -> Option { + cfg.key_exchanges + .iter() + .find(|k| client_supported.contains(k)) + .copied() +} +/// Select the best key exchange from our supported list +/// and the other endpoint supported list. +/// Give priority to the server list +/// This is used only in the Directory Synchronized handshake +pub fn client_select_key_exchange( + cfg: &Config, + server_supported: &Vec, +) -> Option { + server_supported + .iter() + .find(|k| cfg.key_exchanges.contains(k)) + .copied() } diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index a4b6868..872b7f0 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -4,7 +4,7 @@ use ::sha3::Sha3_256; use ::zeroize::Zeroize; -use crate::enc::sym::Secret; +use crate::{config::Config, enc::sym::Secret}; /// Kind of HKDF #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] @@ -14,6 +14,12 @@ pub enum HkdfKind { /// Sha3 Sha3 = 0, } +impl HkdfKind { + /// Length of the serialized type + pub const fn len() -> usize { + 1 + } +} /// Generic wrapper on Hkdfs #[derive(Clone)] @@ -52,6 +58,8 @@ impl Hkdf { // we can't use #[derive(Zeroing)] either. // So we craete a union with a Zeroing object, and drop both manually. +// TODO: move this to Hkdf instead of Sha3 + #[derive(Zeroize)] #[zeroize(drop)] struct Zeroable([u8; ::core::mem::size_of::<::hkdf::Hkdf>()]); @@ -89,7 +97,7 @@ pub struct HkdfSha3 { impl HkdfSha3 { /// Instantiate a new HKDF with Sha3-256 - pub fn new(salt: &[u8], key: Secret) -> Self { + pub(crate) fn new(salt: &[u8], key: Secret) -> Self { let hkdf = ::hkdf::Hkdf::::new(Some(salt), key.as_ref()); Self { inner: HkdfInner { @@ -98,7 +106,7 @@ impl HkdfSha3 { } } /// Get a secret generated from the key and a given context - pub fn get_secret(&self, context: &[u8]) -> Secret { + pub(crate) fn get_secret(&self, context: &[u8]) -> Secret { let mut out: [u8; 32] = [0; 32]; #[allow(unsafe_code)] unsafe { @@ -117,3 +125,29 @@ impl ::core::fmt::Debug for HkdfSha3 { ::core::fmt::Debug::fmt("[hidden hkdf]", f) } } + +/// Select the best hkdf from our supported list +/// and the other endpoint supported list. +/// Give priority to our list +pub fn server_select_hkdf( + cfg: &Config, + client_supported: &Vec, +) -> Option { + cfg.hkdfs + .iter() + .find(|h| client_supported.contains(h)) + .copied() +} +/// Select the best hkdf from our supported list +/// and the other endpoint supported list. +/// Give priority to the server list +/// this is used only in the directory synchronized handshake +pub fn client_select_hkdf( + cfg: &Config, + server_supported: &Vec, +) -> Option { + server_supported + .iter() + .find(|h| cfg.hkdfs.contains(h)) + .copied() +} diff --git a/src/enc/sym.rs b/src/enc/sym.rs index d9673a0..f8d76e3 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -1,7 +1,7 @@ //! Symmetric cypher stuff use super::Error; -use crate::enc::Random; +use crate::{config::Config, enc::Random}; use ::zeroize::Zeroize; /// Secret, used for keys. @@ -299,7 +299,7 @@ impl XChaCha20Poly1305 { } // -// TODO: For efficiency "Nonce" should become a reference. +// TODO: Merge crate::{enc::sym::Nonce, connection::handshake::dirsync::Nonce} // #[derive(Debug, Copy, Clone)] @@ -387,3 +387,28 @@ impl NonceSync { old_nonce } } +/// Select the best cipher from our supported list +/// and the other endpoint supported list. +/// Give priority to our list +pub fn server_select_cipher( + cfg: &Config, + client_supported: &Vec, +) -> Option { + cfg.ciphers + .iter() + .find(|c| client_supported.contains(c)) + .copied() +} +/// Select the best cipher from our supported list +/// and the other endpoint supported list. +/// Give priority to the server list +/// This is used only in the Directory synchronized handshake +pub fn client_select_cipher( + cfg: &Config, + server_supported: &Vec, +) -> Option { + server_supported + .iter() + .find(|c| cfg.ciphers.contains(c)) + .copied() +} diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 1e34b56..5ce3201 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -1,6 +1,7 @@ //! Worker thread implementation use crate::{ auth::{ServiceID, TokenChecker}, + config::Config, connection::{ self, handshake::{ @@ -12,9 +13,9 @@ use crate::{ }, dnssec, enc::{ - asym::PubKey, - hkdf::{Hkdf, HkdfKind}, - sym::Secret, + asym::{self, PubKey}, + hkdf::{self, Hkdf, HkdfKind}, + sym::{self, Secret}, Random, }, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, @@ -53,6 +54,7 @@ pub(crate) enum WorkAnswer { /// Actual worker implementation. pub(crate) struct Worker { + cfg: Config, thread_id: ThreadTracker, // PERF: rand uses syscalls. how to do that async? rand: Random, @@ -67,6 +69,7 @@ pub(crate) struct Worker { impl Worker { pub(crate) async fn new_and_loop( + cfg: Config, thread_id: ThreadTracker, stop_working: ::tokio::sync::broadcast::Receiver, token_check: Option>>, @@ -75,6 +78,7 @@ impl Worker { ) -> ::std::io::Result<()> { // TODO: get a channel to send back information, and send the error let mut worker = Self::new( + cfg, thread_id, stop_working, token_check, @@ -86,6 +90,7 @@ impl Worker { Ok(()) } pub(crate) async fn new( + cfg: Config, thread_id: ThreadTracker, stop_working: ::tokio::sync::broadcast::Receiver, token_check: Option>>, @@ -126,6 +131,7 @@ impl Worker { }; Ok(Self { + cfg, thread_id, rand: Random::new(), stop_working, @@ -156,19 +162,52 @@ impl Worker { let _ = sender.send(conn_num); } Work::Connect((send_res, dnssec_record, _service_id)) => { + // PERF: geolocation + + // Find the first destination with a coherent + // pubkey/key exchange let destination = dnssec_record.addresses.iter().find_map(|addr| { - let maybe_key = - dnssec_record.public_keys.iter().find( - |(id, _)| addr.public_key_ids.contains(id), - ); - match maybe_key { - Some(key) => Some((addr, key)), - None => None, + if addr + .handshake_ids + .iter() + .find(|h_srv| { + self.cfg.handshakes.contains(h_srv) + }) + .is_none() + { + // skip servers with no corresponding + // handshake types + return None; } + + for idx in addr.public_key_idx.iter() { + let key_supported_k_x = + dnssec_record.public_keys[idx.0 as usize] + .1 + .kind() + .key_exchanges(); + match self + .cfg + .key_exchanges + .iter() + .find(|x| key_supported_k_x.contains(x)) + { + Some(exchange) => { + return Some(( + addr, + dnssec_record.public_keys + [idx.0 as usize], + exchange, + )) + } + None => return None, + } + } + return None; }); - let (addr, key) = match destination { - Some((addr, key)) => (addr, key), + let (addr, key, exchange) = match destination { + Some((addr, key, exchange)) => (addr, key, exchange), None => { let _ = send_res.send(Err(crate::Error::Resolution( @@ -178,8 +217,29 @@ impl Worker { continue 'mainloop; } }; - use crate::enc::asym; - let exchange = asym::KeyExchange::X25519DiffieHellman; + let hkdf = match hkdf::client_select_hkdf( + &self.cfg, + &dnssec_record.hkdfs, + ) { + Some(hkdf) => hkdf, + None => { + let _ = send_res + .send(Err(crate::Error::HandshakeNegotiation)); + continue 'mainloop; + } + }; + let cipher = match sym::client_select_cipher( + &self.cfg, + &dnssec_record.ciphers, + ) { + Some(cipher) => cipher, + None => { + let _ = send_res + .send(Err(crate::Error::HandshakeNegotiation)); + continue 'mainloop; + } + }; + let (priv_key, pub_key) = match exchange.new_keypair(&self.rand) { Ok(pair) => pair, @@ -187,10 +247,15 @@ impl Worker { }; // build request /* + let req_data = dirsync::ReqData { + nonce: dirsync::Nonce::new(&self.rand), + client_key_id: + }; let req = dirsync::Req { key_id: key.0, - exchange: exchange, - cipher: 42, + exchange, + hkdf, + cipher, exchange_key: client_pub_key, data: 42, }; diff --git a/src/lib.rs b/src/lib.rs index d09d39e..6156f21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,6 +62,9 @@ pub enum Error { /// Resolution problems. wrong or incomplete DNSSEC data #[error("DNSSEC resolution: {0}")] Resolution(String), + /// No common cryptographic primitives + #[error("No common cryptographic primitives")] + HandshakeNegotiation, } /// Instance of a fenrir endpoint @@ -381,6 +384,7 @@ impl Fenrir { ::tracing::debug!("Spawning thread {}", core); let th_topology = hw_topology.clone(); let th_tokio_rt = tokio_rt.clone(); + let th_config = self.cfg.clone(); let (work_send, work_recv) = ::async_channel::unbounded::(); let th_stop_working = self.stop_working.subscribe(); let th_token_check = self.token_check.clone(); @@ -417,6 +421,7 @@ impl Fenrir { let _ = tk_local.block_on( &th_tokio_rt, Worker::new_and_loop( + th_config, thread_id, th_stop_working, th_token_check, From 08d2755656eb7bd98cc271bad644c4e1e4d1d6b8 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 1 Jun 2023 11:48:32 +0200 Subject: [PATCH 17/34] KeyExchange->KeyExchangeKind for consistency Signed-off-by: Luca Fulchir --- src/config/mod.rs | 6 +++--- src/connection/handshake/dirsync.rs | 33 +++++------------------------ src/dnssec/record.rs | 7 +++--- src/enc/asym.rs | 26 +++++++++++------------ src/inner/mod.rs | 2 +- 5 files changed, 26 insertions(+), 48 deletions(-) diff --git a/src/config/mod.rs b/src/config/mod.rs index 3e489b9..09773c3 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,7 +3,7 @@ use crate::{ connection::handshake::HandshakeID, - enc::{asym::KeyExchange, hkdf::HkdfKind, sym::CipherKind}, + enc::{asym::KeyExchangeKind, hkdf::HkdfKind, sym::CipherKind}, }; use ::std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -25,7 +25,7 @@ pub struct Config { /// Supported handshakes pub handshakes: Vec, /// Supported key exchanges - pub key_exchanges: Vec, + pub key_exchanges: Vec, /// Supported Hkdfs pub hkdfs: Vec, /// Supported Ciphers @@ -47,7 +47,7 @@ impl Default for Config { ], resolvers: Vec::new(), handshakes: [HandshakeID::DirectorySynchronized].to_vec(), - key_exchanges: [KeyExchange::X25519DiffieHellman].to_vec(), + key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(), hkdfs: [HkdfKind::Sha3].to_vec(), ciphers: [CipherKind::XChaCha20Poly1305].to_vec(), } diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 02bb097..c0c7961 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -13,7 +13,7 @@ use crate::{ auth, connection::{ProtocolVersion, ID}, enc::{ - asym::{ExchangePubKey, KeyExchange, KeyID}, + asym::{ExchangePubKey, KeyExchangeKind, KeyID}, hkdf::HkdfKind, sym::{CipherKind, HeadLen, Secret, TagLen}, Random, @@ -88,7 +88,7 @@ pub struct Req { /// Id of the server key used for the key exchange pub key_id: KeyID, /// Selected key exchange - pub exchange: KeyExchange, + pub exchange: KeyExchangeKind, /// Selected hkdf pub hkdf: HkdfKind, /// Selected cipher @@ -106,7 +106,7 @@ impl Req { pub fn encrypted_offset(&self) -> usize { ProtocolVersion::len() + KeyID::len() - + KeyExchange::len() + + KeyExchangeKind::len() + HkdfKind::len() + CipherKind::len() + self.exchange_key.kind().pub_len() @@ -121,7 +121,7 @@ impl Req { /// actual length of the directory synchronized request pub fn len(&self) -> usize { KeyID::len() - + KeyExchange::len() + + KeyExchangeKind::len() + HkdfKind::len() + CipherKind::len() + self.exchange_key.kind().pub_len() @@ -149,7 +149,7 @@ impl super::HandshakeParsing for Req { let key_id: KeyID = KeyID(u16::from_le_bytes(raw[0..1].try_into().unwrap())); use ::num_traits::FromPrimitive; - let exchange: KeyExchange = match KeyExchange::from_u8(raw[2]) { + let exchange: KeyExchangeKind = match KeyExchangeKind::from_u8(raw[2]) { Some(exchange) => exchange, None => return Err(Error::Parsing), }; @@ -343,15 +343,6 @@ impl RespInner { RespInner::ClearText(_) => RespData::len(), } } - /* - /// Get the ciptertext, or panic - pub fn ciphertext<'a>(&'a mut self) -> &'a mut VecDeque { - match self { - RespInner::CipherText(data) => data, - _ => panic!(), - } - } - */ /// parse the cleartext pub fn deserialize_as_cleartext(&mut self, raw: &[u8]) { let clear = match self { @@ -369,20 +360,6 @@ impl RespInner { }; *self = RespInner::ClearText(clear); } - /* - /// switch from ciphertext to cleartext - pub fn mark_as_cleartext(&mut self) { - let mut newdata: VecDeque; - match self { - RespInner::CipherText(data) => { - newdata = VecDeque::new(); - ::core::mem::swap(&mut newdata, data); - } - _ => return, - } - *self = RespInner::ClearText(newdata); - } - */ /// serialize, but only if ciphertext pub fn serialize(&self, out: &mut [u8]) { todo!() diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 2457c0b..1219bdd 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -46,7 +46,7 @@ use crate::{ connection::handshake::HandshakeID, enc::{ self, - asym::{KeyExchange, KeyID, PubKey}, + asym::{KeyExchangeKind, KeyID, PubKey}, hkdf::HkdfKind, sym::CipherKind, }, @@ -361,7 +361,7 @@ pub struct Record { /// Multiple ones can point to the same authentication server pub addresses: Vec
, /// List of supported key exchanges - pub key_exchanges: Vec, + pub key_exchanges: Vec, /// List of supported key exchanges pub hkdfs: Vec, /// List of supported ciphers @@ -523,7 +523,8 @@ impl Record { return Err(Error::NotEnoughData(bytes_parsed)); } while num_key_exchanges > 0 { - let key_exchange = match KeyExchange::from_u8(raw[bytes_parsed]) { + let key_exchange = match KeyExchangeKind::from_u8(raw[bytes_parsed]) + { Some(key_exchange) => key_exchange, None => { // continue parsing. This could be a new key exchange type diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 49f530c..ad8f41d 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -72,10 +72,10 @@ impl KeyKind { } } /// Returns the key exchanges supported by this key - pub fn key_exchanges(&self) -> &'static [KeyExchange] { - const EMPTY: [KeyExchange; 0] = []; - const X25519_KEY_EXCHANGES: [KeyExchange; 1] = - [KeyExchange::X25519DiffieHellman]; + pub fn key_exchanges(&self) -> &'static [KeyExchangeKind] { + const EMPTY: [KeyExchangeKind; 0] = []; + const X25519_KEY_EXCHANGES: [KeyExchangeKind; 1] = + [KeyExchangeKind::X25519DiffieHellman]; match self { KeyKind::Ed25519 => &EMPTY, KeyKind::X25519 => &X25519_KEY_EXCHANGES, @@ -88,11 +88,11 @@ impl KeyKind { #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] #[non_exhaustive] #[repr(u8)] -pub enum KeyExchange { +pub enum KeyExchangeKind { /// X25519 Public key X25519DiffieHellman = 0, } -impl KeyExchange { +impl KeyExchangeKind { /// The serialize length of the field pub fn len() -> usize { 1 @@ -103,7 +103,7 @@ impl KeyExchange { rnd: &Random, ) -> Result<(ExchangePrivKey, ExchangePubKey), Error> { match self { - KeyExchange::X25519DiffieHellman => { + KeyExchangeKind::X25519DiffieHellman => { let raw_priv = ::x25519_dalek::StaticSecret::new(rnd); let pub_key = ExchangePubKey::X25519( ::x25519_dalek::PublicKey::from(&raw_priv), @@ -217,12 +217,12 @@ impl ExchangePrivKey { /// Run the key exchange between two keys of the same kind pub fn key_exchange( &self, - exchange: KeyExchange, + exchange: KeyExchangeKind, pub_key: ExchangePubKey, ) -> Result { match self { ExchangePrivKey::X25519(priv_key) => { - if exchange != KeyExchange::X25519DiffieHellman { + if exchange != KeyExchangeKind::X25519DiffieHellman { return Err(Error::UnsupportedKeyExchange); } if let ExchangePubKey::X25519(inner_pub_key) = pub_key { @@ -298,8 +298,8 @@ impl ExchangePubKey { /// Give priority to our list pub fn server_select_key_exchange( cfg: &Config, - client_supported: &Vec, -) -> Option { + client_supported: &Vec, +) -> Option { cfg.key_exchanges .iter() .find(|k| client_supported.contains(k)) @@ -311,8 +311,8 @@ pub fn server_select_key_exchange( /// This is used only in the Directory Synchronized handshake pub fn client_select_key_exchange( cfg: &Config, - server_supported: &Vec, -) -> Option { + server_supported: &Vec, +) -> Option { server_supported .iter() .find(|k| cfg.key_exchanges.contains(k)) diff --git a/src/inner/mod.rs b/src/inner/mod.rs index b782258..ea79f47 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -70,7 +70,7 @@ pub(crate) struct ThreadTracker { /// (udp_src_sender_port % total_threads) - 1 pub(crate) struct HandshakeTracker { thread_id: ThreadTracker, - key_exchanges: Vec<(asym::KeyKind, asym::KeyExchange)>, + key_exchanges: Vec<(asym::KeyKind, asym::KeyExchangeKind)>, ciphers: Vec, /// ephemeral keys used server side in key exchange keys_srv: Vec, From 5b338c8758066e5d6ad80deb42a0d4eed31e27c6 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 1 Jun 2023 12:52:43 +0200 Subject: [PATCH 18/34] More on negotiation and dnssec record verification Signed-off-by: Luca Fulchir --- src/connection/handshake/mod.rs | 7 +++++ src/dnssec/record.rs | 12 +++++-- src/enc/asym.rs | 13 +++++++- src/inner/worker.rs | 56 ++++++++++++++++++++++++++------- src/lib.rs | 3 -- 5 files changed, 74 insertions(+), 17 deletions(-) diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 778a003..0e56395 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -11,6 +11,7 @@ use ::std::rc::Rc; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] +#[non_exhaustive] pub enum Error { /// Error while parsing the handshake packet /// TODO: more detailed parsing errors @@ -25,6 +26,12 @@ pub enum Error { /// Not enough data #[error("not enough data")] NotEnoughData, + /// Could not find common cryptography + #[error("Negotiation of keys/hkdfs/ciphers failed")] + Negotiation, + /// Could not generate Keys + #[error("Key generation failed")] + KeyGeneration, } /// List of possible handshakes diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 1219bdd..35b7489 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -46,7 +46,7 @@ use crate::{ connection::handshake::HandshakeID, enc::{ self, - asym::{KeyExchangeKind, KeyID, PubKey}, + asym::{ExchangePubKey, KeyExchangeKind, KeyID, PubKey}, hkdf::HkdfKind, sym::CipherKind, }, @@ -498,7 +498,7 @@ impl Record { let (public_key, bytes) = match PubKey::deserialize( &raw[bytes_parsed..(bytes_parsed + pubkey_length)], ) { - Ok(public_key_and_bytes) => public_key_and_bytes, + Ok((public_key, bytes)) => (public_key, bytes), Err(enc::Error::UnsupportedKey(_)) => { // continue parsing. This could be a new pubkey type // that is not supported by an older client @@ -582,6 +582,14 @@ impl Record { if idx.0 as usize >= result.public_keys.len() { return Err(Error::Max16PublicKeys); } + if !result.public_keys[idx.0 as usize] + .1 + .kind() + .capabilities() + .has_exchange() + { + return Err(Error::UnsupportedData(bytes_parsed)); + } } } if bytes_parsed != raw.len() { diff --git a/src/enc/asym.rs b/src/enc/asym.rs index ad8f41d..a35ebbd 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -41,6 +41,17 @@ pub enum KeyCapabilities { /// All: sign, encrypt, Key Exchange SignEncryptExchage, } +impl KeyCapabilities { + /// Check if this key supports eky exchage + pub fn has_exchange(&self) -> bool { + match self { + KeyCapabilities::Exchange + | KeyCapabilities::SignExchange + | KeyCapabilities::SignEncryptExchage => true, + _ => false, + } + } +} /// Kind of key used in the handshake #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] @@ -253,7 +264,7 @@ impl ExchangePubKey { } /// serialize the key into the buffer /// NOTE: Assumes there is enough space - fn serialize_into(&self, out: &mut [u8]) { + pub fn serialize_into(&self, out: &mut [u8]) { match self { ExchangePubKey::X25519(pk) => { let bytes = pk.as_bytes(); diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 5ce3201..01a1c66 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -5,6 +5,7 @@ use crate::{ connection::{ self, handshake::{ + self, dirsync::{self, DirSync}, Handshake, HandshakeData, }, @@ -13,7 +14,7 @@ use crate::{ }, dnssec, enc::{ - asym::{self, PubKey}, + asym::{self, PrivKey, PubKey}, hkdf::{self, Hkdf, HkdfKind}, sym::{self, Secret}, Random, @@ -181,6 +182,9 @@ impl Worker { return None; } + // make sure this server has a public key + // that supports one of the key exchanges that + // *we* support for idx in addr.public_key_idx.iter() { let key_supported_k_x = dnssec_record.public_keys[idx.0 as usize] @@ -198,7 +202,7 @@ impl Worker { addr, dnssec_record.public_keys [idx.0 as usize], - exchange, + exchange.clone(), )) } None => return None, @@ -217,25 +221,27 @@ impl Worker { continue 'mainloop; } }; - let hkdf = match hkdf::client_select_hkdf( + let hkdf_selected = match hkdf::client_select_hkdf( &self.cfg, &dnssec_record.hkdfs, ) { - Some(hkdf) => hkdf, + Some(hkdf_selected) => hkdf_selected, None => { - let _ = send_res - .send(Err(crate::Error::HandshakeNegotiation)); + let _ = send_res.send(Err( + handshake::Error::Negotiation.into(), + )); continue 'mainloop; } }; - let cipher = match sym::client_select_cipher( + let cipher_selected = match sym::client_select_cipher( &self.cfg, &dnssec_record.ciphers, ) { - Some(cipher) => cipher, + Some(cipher_selected) => cipher_selected, None => { - let _ = send_res - .send(Err(crate::Error::HandshakeNegotiation)); + let _ = send_res.send(Err( + handshake::Error::Negotiation.into(), + )); continue 'mainloop; } }; @@ -243,8 +249,36 @@ impl Worker { let (priv_key, pub_key) = match exchange.new_keypair(&self.rand) { Ok(pair) => pair, - Err(_) => todo!(), + Err(_) => { + ::tracing::error!("Failed to generate keys"); + let _ = send_res.send(Err( + handshake::Error::KeyGeneration.into(), + )); + continue 'mainloop; + } }; + let hkdf; + if let PubKey::Exchange(srv_pub) = key.1 { + let secret = + match priv_key.key_exchange(exchange, srv_pub) { + Ok(secret) => secret, + Err(_) => { + ::tracing::warn!( + "Could not run the key exchange" + ); + let _ = send_res.send(Err( + handshake::Error::Negotiation.into(), + )); + continue 'mainloop; + } + }; + hkdf = Hkdf::new(hkdf_selected, b"fenrir", secret); + } else { + // crate::dnssec already verifies that the keys + // listed in dnssec::Record.addresses.public_key_idx + // are PubKey::Exchange + unreachable!() + } // build request /* let req_data = dirsync::ReqData { diff --git a/src/lib.rs b/src/lib.rs index 6156f21..b604f2d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,9 +62,6 @@ pub enum Error { /// Resolution problems. wrong or incomplete DNSSEC data #[error("DNSSEC resolution: {0}")] Resolution(String), - /// No common cryptographic primitives - #[error("No common cryptographic primitives")] - HandshakeNegotiation, } /// Instance of a fenrir endpoint From 9634fbba31d19649610c716dc45875d1efc46c3d Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 1 Jun 2023 12:56:52 +0200 Subject: [PATCH 19/34] Move enc::sym::Secret to enc::Secret Signed-off-by: Luca Fulchir --- src/connection/handshake/dirsync.rs | 4 +-- src/enc/asym.rs | 2 +- src/enc/hkdf.rs | 2 +- src/enc/mod.rs | 40 +++++++++++++++++++++++++ src/enc/sym.rs | 45 +++-------------------------- src/inner/worker.rs | 4 +-- 6 files changed, 50 insertions(+), 47 deletions(-) diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index c0c7961..031f789 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -15,8 +15,8 @@ use crate::{ enc::{ asym::{ExchangePubKey, KeyExchangeKind, KeyID}, hkdf::HkdfKind, - sym::{CipherKind, HeadLen, Secret, TagLen}, - Random, + sym::{CipherKind, HeadLen, TagLen}, + Random, Secret, }, }; diff --git a/src/enc/asym.rs b/src/enc/asym.rs index a35ebbd..417d502 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -5,7 +5,7 @@ use ::num_traits::FromPrimitive; use super::Error; use crate::{ config::Config, - enc::{sym::Secret, Random}, + enc::{Random, Secret}, }; /// Public key ID diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index 872b7f0..de888ce 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -4,7 +4,7 @@ use ::sha3::Sha3_256; use ::zeroize::Zeroize; -use crate::{config::Config, enc::sym::Secret}; +use crate::{config::Config, enc::Secret}; /// Kind of HKDF #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 4da9a0c..09feb7b 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -8,6 +8,7 @@ pub mod sym; pub use errors::Error; use ::ring::rand::SecureRandom; +use ::zeroize::Zeroize; /// wrapper where we implement whatever random traint stuff each library needs pub struct Random { @@ -72,3 +73,42 @@ impl ::rand_core::RngCore for &Random { } } impl ::rand_core::CryptoRng for &Random {} + +/// Secret, used for keys. +/// Grants that on drop() we will zero out memory +#[derive(Zeroize, Clone)] +#[zeroize(drop)] +pub struct Secret([u8; 32]); +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for Secret { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden secret]", f) + } +} + +impl Secret { + /// New randomly generated secret + pub fn new_rand(rand: &Random) -> Self { + let mut ret = Self([0; 32]); + rand.fill(&mut ret.0); + ret + } + /// return a reference to the secret + pub fn as_ref(&self) -> &[u8; 32] { + &self.0 + } +} +impl From<[u8; 32]> for Secret { + fn from(shared_secret: [u8; 32]) -> Self { + Self(shared_secret) + } +} + +impl From<::x25519_dalek::SharedSecret> for Secret { + fn from(shared_secret: ::x25519_dalek::SharedSecret) -> Self { + Self(shared_secret.to_bytes()) + } +} diff --git a/src/enc/sym.rs b/src/enc/sym.rs index f8d76e3..0a204b9 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -1,49 +1,12 @@ //! Symmetric cypher stuff use super::Error; -use crate::{config::Config, enc::Random}; +use crate::{ + config::Config, + enc::{Random, Secret}, +}; use ::zeroize::Zeroize; -/// Secret, used for keys. -/// Grants that on drop() we will zero out memory -#[derive(Zeroize, Clone)] -#[zeroize(drop)] -pub struct Secret([u8; 32]); -// Fake debug implementation to avoid leaking secrets -impl ::core::fmt::Debug for Secret { - fn fmt( - &self, - f: &mut core::fmt::Formatter<'_>, - ) -> Result<(), ::std::fmt::Error> { - ::core::fmt::Debug::fmt("[hidden secret]", f) - } -} - -impl Secret { - /// New randomly generated secret - pub fn new_rand(rand: &Random) -> Self { - let mut ret = Self([0; 32]); - rand.fill(&mut ret.0); - ret - } - /// return a reference to the secret - pub fn as_ref(&self) -> &[u8; 32] { - &self.0 - } -} - -impl From<[u8; 32]> for Secret { - fn from(shared_secret: [u8; 32]) -> Self { - Self(shared_secret) - } -} - -impl From<::x25519_dalek::SharedSecret> for Secret { - fn from(shared_secret: ::x25519_dalek::SharedSecret) -> Self { - Self(shared_secret.to_bytes()) - } -} - /// List of possible Ciphers #[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] #[repr(u8)] diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 01a1c66..96a52f4 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -16,8 +16,8 @@ use crate::{ enc::{ asym::{self, PrivKey, PubKey}, hkdf::{self, Hkdf, HkdfKind}, - sym::{self, Secret}, - Random, + sym::{self}, + Random, Secret, }, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; From 289c6c318e25e695553fe338fb69500864efb17a Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Mon, 5 Jun 2023 09:18:32 +0200 Subject: [PATCH 20/34] More work on Dirsync request sending Signed-off-by: Luca Fulchir --- src/auth/mod.rs | 26 +++++- src/connection/handshake/mod.rs | 117 ++++++++++++++++++++++--- src/connection/mod.rs | 9 +- src/dnssec/mod.rs | 6 +- src/enc/asym.rs | 2 + src/enc/hkdf.rs | 6 ++ src/enc/sym.rs | 2 +- src/inner/mod.rs | 79 ++++++++--------- src/inner/worker.rs | 148 ++++++++++++++++++++++---------- src/lib.rs | 24 ++++-- 10 files changed, 302 insertions(+), 117 deletions(-) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 955cb12..464e02a 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -16,9 +16,17 @@ impl From<[u8; 16]> for UserID { impl UserID { /// New random user id pub fn new(rand: &Random) -> Self { - let mut ret = Self([0; 16]); - rand.fill(&mut ret.0); - ret + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 16]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = rand.fill(out.assume_init_mut()); + Self(out.assume_init()) + } + } + /// Anonymous user id + pub fn new_anonymous() -> Self { + UserID([0; 16]) } /// length of the User ID in bytes pub const fn len() -> usize { @@ -31,6 +39,16 @@ impl UserID { pub struct Token([u8; 32]); impl Token { + /// New random token, anonymous should not check this anyway + pub fn new_anonymous(rand: &Random) -> Self { + use ::core::mem::MaybeUninit; + let mut out: MaybeUninit<[u8; 32]> = MaybeUninit::uninit(); + #[allow(unsafe_code)] + unsafe { + let _ = rand.fill(out.assume_init_mut()); + Self(out.assume_init()) + } + } /// length of the token in bytes pub const fn len() -> usize { 32 @@ -68,7 +86,7 @@ pub type TokenChecker = /// further limit to a "safe" subset of utf8 // SECURITY: TODO: limit to a subset of utf8 #[derive(Debug, Clone, PartialEq)] -pub struct Domain(String); +pub struct Domain(pub String); impl TryFrom<&[u8]> for Domain { type Error = (); diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 0e56395..273b57e 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -3,11 +3,15 @@ pub mod dirsync; use crate::{ - connection::{self, ProtocolVersion}, - enc::sym::{HeadLen, TagLen}, + auth::ServiceID, + connection::{self, Connection, IDRecv, ProtocolVersion}, + enc::{ + asym::{KeyID, PrivKey, PubKey}, + sym::{HeadLen, TagLen}, + }, }; use ::num_traits::FromPrimitive; -use ::std::rc::Rc; +use ::std::{collections::VecDeque, rc::Rc}; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -32,6 +36,9 @@ pub enum Error { /// Could not generate Keys #[error("Key generation failed")] KeyGeneration, + /// Too many client handshakes currently running + #[error("Too many client handshakes")] + TooManyClientHandshakes, } /// List of possible handshakes @@ -66,20 +73,108 @@ impl TryFrom<&str> for HandshakeID { } pub(crate) struct HandshakeServer { - pub id: crate::enc::asym::KeyID, - pub key: crate::enc::asym::PrivKey, + pub id: KeyID, + pub key: PrivKey, } -#[derive(Clone)] pub(crate) struct HandshakeClient { - pub id: crate::enc::asym::KeyID, - pub key: crate::enc::asym::PrivKey, - pub service_id: crate::auth::ServiceID, - pub service_conn_id: connection::IDRecv, - pub connection: Rc, + pub id: KeyID, + pub key: PrivKey, + pub service_id: ServiceID, + pub service_conn_id: IDRecv, + pub connection: Connection, pub timeout: Rc, } +/// Tracks the keys used by the client and the handshake +/// they are associated with +pub(crate) struct HandshakeClientList { + used: Vec<::bitmaps::Bitmap<1024>>, // index = KeyID + keys: Vec>, + list: Vec>, +} + +impl HandshakeClientList { + pub(crate) fn new() -> Self { + Self { + used: [::bitmaps::Bitmap::<1024>::new()].to_vec(), + keys: Vec::with_capacity(16), + list: Vec::with_capacity(16), + } + } + pub(crate) fn get(&self, id: KeyID) -> Option<&HandshakeClient> { + if id.0 as usize >= self.list.len() { + return None; + } + self.list[id.0 as usize].as_ref() + } + pub(crate) fn remove(&mut self, id: KeyID) -> Option { + if id.0 as usize >= self.list.len() { + return None; + } + let used_vec_idx = id.0 as usize / 1024; + let used_bitmap_idx = id.0 as usize % 1024; + let used_iter = match self.used.get_mut(used_vec_idx) { + Some(used_iter) => used_iter, + None => return None, + }; + used_iter.set(used_bitmap_idx, false); + self.keys[id.0 as usize] = None; + let mut owned = None; + ::core::mem::swap(&mut self.list[id.0 as usize], &mut owned); + owned + } + pub(crate) fn add( + &mut self, + priv_key: PrivKey, + pub_key: PubKey, + service_id: ServiceID, + service_conn_id: IDRecv, + connection: Connection, + ) -> Result<(KeyID, &HandshakeClient), ()> { + let maybe_free_key_idx = + self.used.iter().enumerate().find_map(|(idx, bmap)| { + match bmap.first_false_index() { + Some(false_idx) => Some(((idx * 1024), false_idx)), + None => None, + } + }); + let free_key_idx = match maybe_free_key_idx { + Some((idx, false_idx)) => { + let free_key_idx = idx * 1024 + false_idx; + if free_key_idx > KeyID::MAX as usize { + return Err(()); + } + self.used[idx].set(false_idx, true); + free_key_idx + } + None => { + let mut bmap = ::bitmaps::Bitmap::<1024>::new(); + bmap.set(0, true); + self.used.push(bmap); + self.used.len() * 1024 + } + }; + if self.keys.len() >= free_key_idx { + self.keys.push(None); + self.list.push(None); + } + self.keys[free_key_idx] = Some((priv_key.clone(), pub_key)); + self.list[free_key_idx] = Some(HandshakeClient { + id: KeyID(free_key_idx as u16), + key: priv_key, + service_id, + service_conn_id, + connection, + timeout: Rc::new(0), + }); + Ok(( + KeyID(free_key_idx as u16), + self.list[free_key_idx].as_ref().unwrap(), + )) + } +} + /// Parsed handshake #[derive(Debug, Clone)] pub enum HandshakeData { diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 7c1d733..8dac30c 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -132,10 +132,7 @@ impl ConnList { } /// Only *Reserve* a connection, /// without actually tracking it in self.connections - pub(crate) fn reserve_first( - &mut self, - mut conn: Connection, - ) -> Rc { + pub(crate) fn reserve_first(&mut self) -> IDRecv { // uhm... bad things are going on here: // * id must be initialized, but only because: // * rust does not understand that after the `!found` id is always @@ -173,9 +170,7 @@ impl ConnList { let actual_id = ((id_in_thread as u64) * (self.thread_id.total as u64)) + (self.thread_id.id as u64); let new_id = IDRecv(ID::new_u64(actual_id)); - conn.id_recv = new_id; - // Return the new connection without tracking it - Rc::new(conn) + new_id } /// NOTE: does NOT check if the connection has been previously reserved! pub(crate) fn track(&mut self, conn: Rc) -> Result<(), ()> { diff --git a/src/dnssec/mod.rs b/src/dnssec/mod.rs index 912321d..143863c 100644 --- a/src/dnssec/mod.rs +++ b/src/dnssec/mod.rs @@ -7,6 +7,8 @@ use ::trust_dns_resolver::TokioAsyncResolver; pub mod record; pub use record::Record; +use crate::auth::Domain; + /// Common errors for Dnssec setup and usage #[derive(::thiserror::Error, Debug)] pub enum Error { @@ -88,10 +90,10 @@ impl Dnssec { } const TXT_RECORD_START: &str = "v=Fenrir1 "; /// Get the fenrir data for a domain - pub async fn resolv(&self, domain: &str) -> ::std::io::Result { + pub async fn resolv(&self, domain: &Domain) -> ::std::io::Result { use ::trust_dns_client::rr::Name; - let fqdn_str = "_fenrir.".to_owned() + domain; + let fqdn_str = "_fenrir.".to_owned() + &domain.0; ::tracing::debug!("Resolving: {}", fqdn_str); let fqdn = Name::from_utf8(&fqdn_str)?; let answers = self.resolver.txt_lookup(fqdn).await?; diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 417d502..83e8dd5 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -17,6 +17,8 @@ impl KeyID { pub const fn len() -> usize { 2 } + /// Maximum possible KeyID + pub const MAX: u16 = u16::MAX; /// Serialize into raw bytes pub fn serialize(&self, out: &mut [u8; KeyID::len()]) { out.copy_from_slice(&self.0.to_le_bytes()); diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index de888ce..c8a706d 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -51,6 +51,12 @@ impl Hkdf { Hkdf::Sha3(sha3) => sha3.get_secret(context), } } + /// get the kind of this Hkdf + pub fn kind(&self) -> HkdfKind { + match self { + Hkdf::Sha3(_) => HkdfKind::Sha3, + } + } } // Hack & tricks: diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 0a204b9..4d28c64 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -39,7 +39,7 @@ impl CipherKind { /// Additional Authenticated Data #[derive(Debug)] -pub struct AAD<'a>(pub &'a mut [u8]); +pub struct AAD<'a>(pub &'a [u8]); /// Cipher direction, to make sure we don't reuse the same cipher /// for both decrypting and encrypting diff --git a/src/inner/mod.rs b/src/inner/mod.rs index ea79f47..147984f 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -5,20 +5,24 @@ pub(crate) mod worker; use crate::{ - auth, + auth::ServiceID, connection::{ self, - handshake::{self, Handshake, HandshakeClient, HandshakeServer}, - Connection, + handshake::{ + self, Handshake, HandshakeClient, HandshakeClientList, + HandshakeServer, + }, + Connection, IDRecv, }, enc::{ - self, asym, + self, + asym::{self, KeyID, PrivKey, PubKey}, hkdf::{Hkdf, HkdfKind}, sym::{CipherKind, CipherRecv}, }, Error, }; -use ::std::{rc::Rc, vec::Vec}; +use ::std::vec::Vec; /// Information needed to reply after the key exchange #[derive(Debug, Clone)] @@ -35,13 +39,13 @@ pub(crate) struct AuthNeededInfo { #[derive(Debug)] pub(crate) struct ClientConnectInfo { /// The service ID that we are connecting to - pub service_id: auth::ServiceID, + pub service_id: ServiceID, /// The service ID that we are connecting to - pub service_connection_id: connection::IDRecv, + pub service_connection_id: IDRecv, /// Parsed handshake packet pub handshake: Handshake, /// Connection - pub connection: Rc, + pub connection: Connection, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] @@ -63,11 +67,11 @@ pub(crate) struct ThreadTracker { pub id: u16, } -/// Async free but thread safe tracking of handhsakes and conenctions +/// Tracking of handhsakes and conenctions /// Note that we have multiple Handshake trackers, pinned to different cores /// Each of them will handle a subset of all handshakes. -/// Each handshake is routed to a different tracker with: -/// (udp_src_sender_port % total_threads) - 1 +/// Each handshake is routed to a different tracker by checking +/// core = (udp_src_sender_port % total_threads) - 1 pub(crate) struct HandshakeTracker { thread_id: ThreadTracker, key_exchanges: Vec<(asym::KeyKind, asym::KeyExchangeKind)>, @@ -75,12 +79,8 @@ pub(crate) struct HandshakeTracker { /// ephemeral keys used server side in key exchange keys_srv: Vec, /// ephemeral keys used client side in key exchange - hshake_cli: Vec, + hshake_cli: HandshakeClientList, } -#[allow(unsafe_code)] -unsafe impl Send for HandshakeTracker {} -#[allow(unsafe_code)] -unsafe impl Sync for HandshakeTracker {} impl HandshakeTracker { pub(crate) fn new(thread_id: ThreadTracker) -> Self { @@ -89,9 +89,25 @@ impl HandshakeTracker { ciphers: Vec::new(), key_exchanges: Vec::new(), keys_srv: Vec::new(), - hshake_cli: Vec::new(), + hshake_cli: HandshakeClientList::new(), } } + pub(crate) fn new_client( + &mut self, + priv_key: PrivKey, + pub_key: PubKey, + service_id: ServiceID, + service_conn_id: IDRecv, + connection: Connection, + ) -> Result<(KeyID, &HandshakeClient), ()> { + self.hshake_cli.add( + priv_key, + pub_key, + service_id, + service_conn_id, + connection, + ) + } pub(crate) fn recv_handshake( &mut self, mut handshake: Handshake, @@ -105,7 +121,6 @@ impl HandshakeTracker { if let Some(h_k) = self.keys_srv.iter().find(|k| k.id == req.key_id) { - use enc::asym::PrivKey; // Directory synchronized can only use keys // for key exchange, not signing keys if let PrivKey::Exchange(k) = &h_k.key { @@ -175,20 +190,9 @@ impl HandshakeTracker { })); } DirSync::Resp(resp) => { - let hshake_idx = { - match self - .hshake_cli - .iter() - .position(|h| h.id == resp.client_key_id) - { - Some(h) => Some(h.clone()), - None => None, - } - }; - let hshake_idx = { - if let Some(real_idx) = hshake_idx { - real_idx - } else { + let hshake = match self.hshake_cli.get(resp.client_key_id) { + Some(hshake) => hshake, + None => { ::tracing::debug!( "No such client key id: {:?}", resp.client_key_id @@ -196,7 +200,6 @@ impl HandshakeTracker { return Err(handshake::Error::UnknownKeyID.into()); } }; - let hshake = &self.hshake_cli[hshake_idx]; let cipher_recv = &hshake.connection.cipher_recv; use crate::enc::sym::AAD; // no aad for now @@ -212,14 +215,8 @@ impl HandshakeTracker { return Err(handshake::Error::Key(e).into()); } } - // we can remove the handshake from the list - let hshake: HandshakeClient = { - let len = self.hshake_cli.len(); - if (hshake_idx + 1) != len { - self.hshake_cli.swap(hshake_idx, len - 1); - } - self.hshake_cli.pop().unwrap() - }; + let hshake = + self.hshake_cli.remove(resp.client_key_id).unwrap(); return Ok(HandshakeAction::ClientConnect( ClientConnectInfo { service_id: hshake.service_id, diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 96a52f4..3210229 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -1,6 +1,6 @@ //! Worker thread implementation use crate::{ - auth::{ServiceID, TokenChecker}, + auth::{Domain, ServiceID, Token, TokenChecker, UserID}, config::Config, connection::{ self, @@ -14,10 +14,9 @@ use crate::{ }, dnssec, enc::{ - asym::{self, PrivKey, PubKey}, + asym::{PrivKey, PubKey}, hkdf::{self, Hkdf, HkdfKind}, - sym::{self}, - Random, Secret, + sym, Random, Secret, }, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; @@ -36,17 +35,19 @@ pub(crate) struct RawUdp { pub packet: Packet, } +pub(crate) struct ConnectInfo { + pub answer: oneshot::Sender>, + pub resolved: dnssec::Record, + pub service_id: ServiceID, + pub domain: Domain, + // TODO: UserID, Token information +} + pub(crate) enum Work { /// ask the thread to report to the main thread the total number of /// connections present CountConnections(oneshot::Sender), - Connect( - ( - oneshot::Sender>, - dnssec::Record, - ServiceID, - ), - ), + Connect(ConnectInfo), Recv(RawUdp), } pub(crate) enum WorkAnswer { @@ -162,13 +163,13 @@ impl Worker { let conn_num = self.connections.len(); let _ = sender.send(conn_num); } - Work::Connect((send_res, dnssec_record, _service_id)) => { + Work::Connect(conn_info) => { // PERF: geolocation // Find the first destination with a coherent // pubkey/key exchange let destination = - dnssec_record.addresses.iter().find_map(|addr| { + conn_info.resolved.addresses.iter().find_map(|addr| { if addr .handshake_ids .iter() @@ -187,7 +188,8 @@ impl Worker { // *we* support for idx in addr.public_key_idx.iter() { let key_supported_k_x = - dnssec_record.public_keys[idx.0 as usize] + conn_info.resolved.public_keys + [idx.0 as usize] .1 .kind() .key_exchanges(); @@ -200,7 +202,7 @@ impl Worker { Some(exchange) => { return Some(( addr, - dnssec_record.public_keys + conn_info.resolved.public_keys [idx.0 as usize], exchange.clone(), )) @@ -214,7 +216,9 @@ impl Worker { Some((addr, key, exchange)) => (addr, key, exchange), None => { let _ = - send_res.send(Err(crate::Error::Resolution( + conn_info + .answer + .send(Err(crate::Error::Resolution( "No selectable address and key combination" .to_owned(), ))); @@ -223,11 +227,11 @@ impl Worker { }; let hkdf_selected = match hkdf::client_select_hkdf( &self.cfg, - &dnssec_record.hkdfs, + &conn_info.resolved.hkdfs, ) { Some(hkdf_selected) => hkdf_selected, None => { - let _ = send_res.send(Err( + let _ = conn_info.answer.send(Err( handshake::Error::Negotiation.into(), )); continue 'mainloop; @@ -235,23 +239,24 @@ impl Worker { }; let cipher_selected = match sym::client_select_cipher( &self.cfg, - &dnssec_record.ciphers, + &conn_info.resolved.ciphers, ) { Some(cipher_selected) => cipher_selected, None => { - let _ = send_res.send(Err( + let _ = conn_info.answer.send(Err( handshake::Error::Negotiation.into(), )); continue 'mainloop; } }; + // FIXME: save KeyID let (priv_key, pub_key) = match exchange.new_keypair(&self.rand) { Ok(pair) => pair, Err(_) => { ::tracing::error!("Failed to generate keys"); - let _ = send_res.send(Err( + let _ = conn_info.answer.send(Err( handshake::Error::KeyGeneration.into(), )); continue 'mainloop; @@ -266,7 +271,7 @@ impl Worker { ::tracing::warn!( "Could not run the key exchange" ); - let _ = send_res.send(Err( + let _ = conn_info.answer.send(Err( handshake::Error::Negotiation.into(), )); continue 'mainloop; @@ -279,25 +284,79 @@ impl Worker { // are PubKey::Exchange unreachable!() } + let mut conn = Connection::new( + hkdf, + cipher_selected, + connection::Role::Client, + &self.rand, + ); + + let auth_recv_id = self.connections.reserve_first(); + let service_conn_id = self.connections.reserve_first(); + conn.id_recv = auth_recv_id; + let (client_key_id, hshake) = match self + .handshakes + .new_client( + PrivKey::Exchange(priv_key), + PubKey::Exchange(pub_key), + conn_info.service_id, + service_conn_id, + conn, + ) { + Ok((client_key_id, hshake)) => (client_key_id, hshake), + Err(_) => { + ::tracing::warn!("Too many client handshakes"); + let _ = conn_info.answer.send(Err( + handshake::Error::TooManyClientHandshakes + .into(), + )); + continue 'mainloop; + } + }; + // build request - /* + let auth_info = dirsync::AuthInfo { + user: UserID::new_anonymous(), + token: Token::new_anonymous(&self.rand), + service_id: conn_info.service_id, + domain: conn_info.domain, + }; let req_data = dirsync::ReqData { nonce: dirsync::Nonce::new(&self.rand), - client_key_id: + client_key_id, + id: auth_recv_id.0, + auth: auth_info, }; let req = dirsync::Req { key_id: key.0, exchange, - hkdf, - cipher, - exchange_key: client_pub_key, - data: 42, + hkdf: hkdf_selected, + cipher: cipher_selected, + exchange_key: pub_key, + data: dirsync::ReqInner::ClearText(req_data), }; - */ + let mut raw = Vec::::with_capacity(req.len()); + req.serialize( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + &mut raw[..], + ); + // encrypt + let encrypt_start = req.encrypted_offset(); + let encrypt_end = encrypt_start + req.encrypted_length(); + if let Err(e) = hshake.connection.cipher_send.encrypt( + sym::AAD(&[]), + &mut raw[encrypt_start..encrypt_end], + ) { + ::tracing::error!("Can't encrypt DirSync Request"); + let _ = conn_info.answer.send(Err(e.into())); + continue 'mainloop; + } // start timeout - // send packet + // send packeti + //self.send_packet(raw, todo!() } @@ -395,15 +454,16 @@ impl Worker { let head_len = req.cipher.nonce_len(); let tag_len = req.cipher.tag_len(); - let mut raw_conn = Connection::new( + let mut auth_conn = Connection::new( authinfo.hkdf, req.cipher, connection::Role::Server, &self.rand, ); - raw_conn.id_send = IDSend(req_data.id); + auth_conn.id_send = IDSend(req_data.id); // track connection - let auth_conn = self.connections.reserve_first(raw_conn); + let auth_id_recv = self.connections.reserve_first(); + auth_conn.id_recv = auth_id_recv; let resp_data = dirsync::RespData { client_nonce: req_data.nonce, @@ -444,7 +504,7 @@ impl Worker { self.send_packet(raw_out, udp.src, udp.dst).await; return; } - HandshakeAction::ClientConnect(mut cci) => { + HandshakeAction::ClientConnect(cci) => { let ds_resp; if let HandshakeData::DirSync(DirSync::Resp(resp)) = cci.handshake.data @@ -465,17 +525,17 @@ impl Worker { ); return; } - { - let conn = Rc::get_mut(&mut cci.connection).unwrap(); - conn.id_send = IDSend(resp_data.id); - } + let mut conn = cci.connection; + conn.id_send = IDSend(resp_data.id); + let id_recv = conn.id_recv; + let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server - if self.connections.track(cci.connection.clone()).is_err() { - self.connections.delete(cci.connection.id_recv); + if self.connections.track(Rc::new(conn)).is_err() { + ::tracing::error!("Could not track new connection"); + self.connections.delete(id_recv); + return; } - if cci.connection.id_recv.0 - == resp_data.service_connection_id - { + if id_recv.0 == resp_data.service_connection_id { // the user asked a single connection // to the authentication server, without any additional // service. No more connections to setup @@ -492,7 +552,7 @@ impl Worker { ); let mut service_connection = Connection::new( hkdf, - cci.connection.cipher_recv.kind(), + cipher, connection::Role::Client, &self.rand, ); diff --git a/src/lib.rs b/src/lib.rs index b604f2d..508c29f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,17 +22,16 @@ mod inner; use ::std::{sync::Arc, vec::Vec}; use ::tokio::{net::UdpSocket, sync::Mutex}; -use auth::ServiceID; use crate::{ - auth::TokenChecker, + auth::{Domain, ServiceID, TokenChecker}, connection::{ handshake, socket::{SocketList, UdpClient, UdpServer}, AuthServerConnections, Packet, }, inner::{ - worker::{RawUdp, Work, Worker}, + worker::{ConnectInfo, RawUdp, Work, Worker}, ThreadTracker, }, }; @@ -62,6 +61,9 @@ pub enum Error { /// Resolution problems. wrong or incomplete DNSSEC data #[error("DNSSEC resolution: {0}")] Resolution(String), + /// Wrapper on encryption errors + #[error("Encrypt: {0}")] + Encrypt(enc::Error), } /// Instance of a fenrir endpoint @@ -232,7 +234,7 @@ impl Fenrir { Ok(()) } /// Get the raw TXT record of a Fenrir domain - pub async fn resolv_txt(&self, domain: &str) -> Result { + pub async fn resolv_txt(&self, domain: &Domain) -> Result { match &self.dnssec { Some(dnssec) => Ok(dnssec.resolv(domain).await?), None => Err(Error::NotInitialized), @@ -240,7 +242,10 @@ impl Fenrir { } /// Get the raw TXT record of a Fenrir domain - pub async fn resolv(&self, domain: &str) -> Result { + pub async fn resolv( + &self, + domain: &Domain, + ) -> Result { let record_str = self.resolv_txt(domain).await?; Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } @@ -248,7 +253,7 @@ impl Fenrir { /// Connect to a service pub async fn connect( &self, - domain: &str, + domain: &Domain, service: ServiceID, ) -> Result<(), Error> { let resolved = self.resolv(domain).await?; @@ -310,7 +315,12 @@ impl Fenrir { // and tell that thread to connect somewhere let (send, recv) = ::tokio::sync::oneshot::channel(); let _ = self._thread_work[thread_idx] - .send(Work::Connect((send, resolved.clone(), service))) + .send(Work::Connect(ConnectInfo { + answer: send, + resolved: resolved.clone(), + service_id: service, + domain: domain.clone(), + })) .await; match recv.await { From 3e09b9cee0417ea6525612d57b2f6e5b95ab63ce Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Mon, 5 Jun 2023 10:33:25 +0200 Subject: [PATCH 21/34] Send initial dirsync packet and handshake timeout Signed-off-by: Luca Fulchir --- src/connection/handshake/mod.rs | 14 +++--- src/dnssec/record.rs | 11 +++++ src/inner/mod.rs | 12 ++++- src/inner/worker.rs | 77 +++++++++++++++++++++++++-------- 4 files changed, 87 insertions(+), 27 deletions(-) diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 273b57e..427bff4 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -78,12 +78,10 @@ pub(crate) struct HandshakeServer { } pub(crate) struct HandshakeClient { - pub id: KeyID, - pub key: PrivKey, pub service_id: ServiceID, pub service_conn_id: IDRecv, pub connection: Connection, - pub timeout: Rc, + pub timeout: Option<::tokio::task::JoinHandle<()>>, } /// Tracks the keys used by the client and the handshake @@ -131,7 +129,7 @@ impl HandshakeClientList { service_id: ServiceID, service_conn_id: IDRecv, connection: Connection, - ) -> Result<(KeyID, &HandshakeClient), ()> { + ) -> Result<(KeyID, &mut HandshakeClient), ()> { let maybe_free_key_idx = self.used.iter().enumerate().find_map(|(idx, bmap)| { match bmap.first_false_index() { @@ -159,18 +157,16 @@ impl HandshakeClientList { self.keys.push(None); self.list.push(None); } - self.keys[free_key_idx] = Some((priv_key.clone(), pub_key)); + self.keys[free_key_idx] = Some((priv_key, pub_key)); self.list[free_key_idx] = Some(HandshakeClient { - id: KeyID(free_key_idx as u16), - key: priv_key, service_id, service_conn_id, connection, - timeout: Rc::new(0), + timeout: None, }); Ok(( KeyID(free_key_idx as u16), - self.list[free_key_idx].as_ref().unwrap(), + self.list[free_key_idx].as_mut().unwrap(), )) } } diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 35b7489..ab5065b 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -186,6 +186,17 @@ pub struct Address { } impl Address { + /// return this Address as a socket address + /// Note that since Fenrir can work on top of IP, without ports, + /// this is not guaranteed to return a SocketAddr + pub fn as_sockaddr(&self) -> Option<::std::net::SocketAddr> { + match self.port { + Some(port) => { + Some(::std::net::SocketAddr::new(self.ip, port.get())) + } + None => None, + } + } fn raw_len(&self) -> usize { // UDP port + Priority + Weight + pubkey_len + handshake_len let mut size = 6; diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 147984f..67add2a 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -99,7 +99,7 @@ impl HandshakeTracker { service_id: ServiceID, service_conn_id: IDRecv, connection: Connection, - ) -> Result<(KeyID, &HandshakeClient), ()> { + ) -> Result<(KeyID, &mut HandshakeClient), ()> { self.hshake_cli.add( priv_key, pub_key, @@ -108,6 +108,16 @@ impl HandshakeTracker { connection, ) } + pub(crate) fn timeout_client( + &mut self, + key_id: KeyID, + ) -> Option<[IDRecv; 2]> { + if let Some(hshake) = self.hshake_cli.remove(key_id) { + Some([hshake.connection.id_recv, hshake.service_conn_id]) + } else { + None + } + } pub(crate) fn recv_handshake( &mut self, mut handshake: Handshake, diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 3210229..642d9d2 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -14,7 +14,7 @@ use crate::{ }, dnssec, enc::{ - asym::{PrivKey, PubKey}, + asym::{KeyID, PrivKey, PubKey}, hkdf::{self, Hkdf, HkdfKind}, sym, Random, Secret, }, @@ -24,7 +24,7 @@ use ::std::{rc::Rc, sync::Arc, vec::Vec}; /// This worker must be cpu-pinned use ::tokio::{ net::UdpSocket, - sync::{oneshot, Mutex}, + sync::{mpsc, oneshot, Mutex}, }; /// Track a raw Udp packet @@ -48,6 +48,7 @@ pub(crate) enum Work { /// connections present CountConnections(oneshot::Sender), Connect(ConnectInfo), + DropHandshake(KeyID), Recv(RawUdp), } pub(crate) enum WorkAnswer { @@ -64,6 +65,8 @@ pub(crate) struct Worker { token_check: Option>>, sockets: Vec, queue: ::async_channel::Receiver, + queue_timeouts_recv: mpsc::UnboundedReceiver, + queue_timeouts_send: mpsc::UnboundedSender, thread_channels: Vec<::async_channel::Sender>, connections: ConnList, handshakes: HandshakeTracker, @@ -132,6 +135,8 @@ impl Worker { } }; + let (queue_timeouts_send, queue_timeouts_recv) = + mpsc::unbounded_channel(); Ok(Self { cfg, thread_id, @@ -140,6 +145,8 @@ impl Worker { token_check, sockets, queue, + queue_timeouts_recv, + queue_timeouts_send, thread_channels: Vec::new(), connections: ConnList::new(thread_id), handshakes: HandshakeTracker::new(thread_id), @@ -151,6 +158,12 @@ impl Worker { _done = self.stop_working.recv() => { break; } + maybe_timeout = self.queue.recv() => { + match maybe_timeout { + Ok(work) => work, + Err(_) => break, + } + } maybe_work = self.queue.recv() => { match maybe_work { Ok(work) => work, @@ -166,20 +179,22 @@ impl Worker { Work::Connect(conn_info) => { // PERF: geolocation - // Find the first destination with a coherent - // pubkey/key exchange + // Find the first destination with: + // * UDP port + // * a coherent pubkey/key exchange. let destination = conn_info.resolved.addresses.iter().find_map(|addr| { - if addr - .handshake_ids - .iter() - .find(|h_srv| { - self.cfg.handshakes.contains(h_srv) - }) - .is_none() + if addr.port.is_none() + || addr + .handshake_ids + .iter() + .find(|h_srv| { + self.cfg.handshakes.contains(h_srv) + }) + .is_none() { // skip servers with no corresponding - // handshake types + // handshake types or no udp port return None; } @@ -250,7 +265,6 @@ impl Worker { } }; - // FIXME: save KeyID let (priv_key, pub_key) = match exchange.new_keypair(&self.rand) { Ok(pair) => pair, @@ -353,12 +367,34 @@ impl Worker { continue 'mainloop; } - // start timeout + // send always from the first socket + // FIXME: select based on routing table + let sender = self.sockets[0].local_addr().unwrap(); + let dest = UdpServer(addr.as_sockaddr().unwrap()); - // send packeti - //self.send_packet(raw, + // start the timeout right before sending the packet + hshake.timeout = Some(::tokio::task::spawn_local( + Self::handshake_timeout( + self.queue_timeouts_send.clone(), + client_key_id, + ), + )); - todo!() + // send packet + self.send_packet(raw, UdpClient(sender), dest).await; + + continue 'mainloop; + } + Work::DropHandshake(key_id) => { + if let Some(connections) = + self.handshakes.timeout_client(key_id) + { + for conn_id in connections.into_iter() { + if !conn_id.0.is_handshake() { + self.connections.delete(conn_id); + } + } + }; } //TODO: reconf message to add channels Work::Recv(pkt) => { @@ -367,6 +403,13 @@ impl Worker { } } } + async fn handshake_timeout( + timeout_queue: mpsc::UnboundedSender, + key_id: KeyID, + ) { + ::tokio::time::sleep(::std::time::Duration::from_secs(10)).await; + let _ = timeout_queue.send(Work::DropHandshake(key_id)); + } /// Read and do stuff with the raw udp packet async fn recv(&mut self, mut udp: RawUdp) { if udp.packet.id.is_handshake() { From 6da5464c68083f99851f76b911c9898ccef97382 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Tue, 6 Jun 2023 22:37:34 +0200 Subject: [PATCH 22/34] Helpers for dnssec Signed-off-by: Luca Fulchir --- src/auth/mod.rs | 4 +- src/connection/mod.rs | 2 +- src/enc/asym.rs | 121 ++++++++++++++++++++++++++++++++++++++---- src/enc/hkdf.rs | 11 +++- src/enc/mod.rs | 4 ++ src/enc/sym.rs | 11 +++- src/inner/mod.rs | 5 +- src/inner/worker.rs | 21 +++----- 8 files changed, 149 insertions(+), 30 deletions(-) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 464e02a..aef9cbc 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -106,8 +106,10 @@ impl Domain { } } +/// Reserve the first service ID for the authentication service +pub const SERVICEID_AUTH: ServiceID = ServiceID([0; 16]); /// The Service ID is a UUID associated with the service. -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct ServiceID([u8; 16]); impl From<[u8; 16]> for ServiceID { diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 8dac30c..ee71c4e 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -185,7 +185,7 @@ impl ConnList { self.connections[id_in_thread] = Some(conn); Ok(()) } - pub(crate) fn delete(&mut self, id: IDRecv) { + pub(crate) fn remove(&mut self, id: IDRecv) { if let IDRecv(ID::ID(raw_id)) = id { let id_in_thread: usize = (raw_id.get() / (self.thread_id.total as u64)) as usize; diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 83e8dd5..ea67ac3 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -25,6 +25,24 @@ impl KeyID { } } +impl TryFrom<&str> for KeyID { + type Error = ::std::io::Error; + fn try_from(raw: &str) -> Result { + if let Ok(id_u16) = raw.parse::() { + return Ok(KeyID(id_u16)); + } + return Err(::std::io::Error::new( + ::std::io::ErrorKind::InvalidData, + "KeyID must be between 0 and 65535", + )); + } +} +impl ::std::fmt::Display for KeyID { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(f, "{}", self.0) + } +} + /// Capabilities of each key #[derive(Debug, Clone, Copy)] pub enum KeyCapabilities { @@ -56,13 +74,23 @@ impl KeyCapabilities { } /// Kind of key used in the handshake -#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] #[non_exhaustive] #[repr(u8)] pub enum KeyKind { /// Ed25519 Public key (sign only) + #[strum(serialize = "ed25519")] Ed25519 = 0, /// X25519 Public key (key exchange) + #[strum(serialize = "x25519")] X25519, } // FIXME: actually check this @@ -73,8 +101,7 @@ impl KeyKind { match self { // FIXME: 99% wrong size KeyKind::Ed25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, - // FIXME: 99% wrong size - KeyKind::X25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, + KeyKind::X25519 => 32, } } /// Get the capabilities of this key type @@ -94,15 +121,30 @@ impl KeyKind { KeyKind::X25519 => &X25519_KEY_EXCHANGES, } } + /// generate new keypair + pub fn new_keypair( + &self, + rnd: &Random, + ) -> Result<(PrivKey, PubKey), Error> { + PubKey::new_keypair(*self, rnd) + } } -// FIXME: rename in KeyExchangeKind /// Kind of key exchange -#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] #[non_exhaustive] #[repr(u8)] pub enum KeyExchangeKind { /// X25519 Public key + #[strum(serialize = "x25519diffiehellman")] X25519DiffieHellman = 0, } impl KeyExchangeKind { @@ -141,6 +183,13 @@ pub enum PubKey { } impl PubKey { + /// Get the serialized key length + pub fn len(&self) -> usize { + match self { + PubKey::Exchange(ex) => ex.len(), + PubKey::Signing => todo!(), + } + } /// return the kind of public key pub fn kind(&self) -> KeyKind { match self { @@ -149,6 +198,20 @@ impl PubKey { PubKey::Exchange(ex) => ex.kind(), } } + /// generate new keypair + fn new_keypair( + kind: KeyKind, + rnd: &Random, + ) -> Result<(PrivKey, PubKey), Error> { + match kind { + KeyKind::Ed25519 => todo!(), + KeyKind::X25519 => { + let (priv_key, pub_key) = + KeyExchangeKind::X25519DiffieHellman.new_keypair(rnd)?; + Ok((PrivKey::Exchange(priv_key), PubKey::Exchange(pub_key))) + } + } + } /// serialize the key into the buffer /// NOTE: Assumes there is enough space pub fn serialize_into(&self, out: &mut [u8]) { @@ -211,6 +274,24 @@ pub enum PrivKey { Signing, } +impl PrivKey { + /// Get the serialized key length + pub fn len(&self) -> usize { + match self { + PrivKey::Exchange(ex) => ex.len(), + PrivKey::Signing => todo!(), + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + pub fn serialize_into(&self, out: &mut [u8]) { + match self { + PrivKey::Exchange(ex) => ex.serialize_into(out), + PrivKey::Signing => todo!(), + } + } +} + /// Ephemeral private keys #[derive(Clone)] #[allow(missing_debug_implementations)] @@ -221,6 +302,12 @@ pub enum ExchangePrivKey { } impl ExchangePrivKey { + /// Get the serialized key length + pub fn len(&self) -> usize { + match self { + ExchangePrivKey::X25519(_) => KeyKind::X25519.pub_len(), + } + } /// Get the kind of key pub fn kind(&self) -> KeyKind { match self { @@ -238,12 +325,18 @@ impl ExchangePrivKey { if exchange != KeyExchangeKind::X25519DiffieHellman { return Err(Error::UnsupportedKeyExchange); } - if let ExchangePubKey::X25519(inner_pub_key) = pub_key { - let shared_secret = priv_key.diffie_hellman(&inner_pub_key); - Ok(shared_secret.into()) - } else { - Err(Error::UnsupportedKeyExchange) - } + let ExchangePubKey::X25519(inner_pub_key) = pub_key; + let shared_secret = priv_key.diffie_hellman(&inner_pub_key); + Ok(shared_secret.into()) + } + } + } + /// serialize the key into the buffer + /// NOTE: Assumes there is enough space + pub fn serialize_into(&self, out: &mut [u8]) { + match self { + ExchangePrivKey::X25519(key) => { + out[..32].copy_from_slice(&key.to_bytes()); } } } @@ -258,6 +351,12 @@ pub enum ExchangePubKey { } impl ExchangePubKey { + /// Get the serialized key length + pub fn len(&self) -> usize { + match self { + ExchangePubKey::X25519(_) => KeyKind::X25519.pub_len(), + } + } /// Get the kind of key pub fn kind(&self) -> KeyKind { match self { diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index c8a706d..d81276f 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -7,11 +7,20 @@ use ::zeroize::Zeroize; use crate::{config::Config, enc::Secret}; /// Kind of HKDF -#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] #[non_exhaustive] #[repr(u8)] pub enum HkdfKind { /// Sha3 + #[strum(serialize = "sha3")] Sha3 = 0, } impl HkdfKind { diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 09feb7b..06a6200 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -27,6 +27,10 @@ impl Random { pub fn fill(&self, out: &mut [u8]) { self.rnd.fill(out); } + /// return the underlying ring SystemRandom + pub fn ring_rnd(&self) -> &::ring::rand::SystemRandom { + &self.rnd + } } // Fake debug implementation to avoid leaking secrets diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 4d28c64..37b5d78 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -8,10 +8,19 @@ use crate::{ use ::zeroize::Zeroize; /// List of possible Ciphers -#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + ::num_derive::FromPrimitive, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] #[repr(u8)] pub enum CipherKind { /// XChaCha20_Poly1305 + #[strum(serialize = "xchacha20poly1305")] XChaCha20Poly1305 = 0, } diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 67add2a..24ff170 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -51,7 +51,7 @@ pub(crate) struct ClientConnectInfo { #[derive(Debug)] pub(crate) enum HandshakeAction { /// Parsing finished, all ok, nothing to do - None, + Nonthing, /// Packet parsed, now go perform authentication AuthNeeded(AuthNeededInfo), /// the client can fully establish a connection with this info @@ -227,6 +227,9 @@ impl HandshakeTracker { } let hshake = self.hshake_cli.remove(resp.client_key_id).unwrap(); + if let Some(timeout) = hshake.timeout { + timeout.abort(); + } return Ok(HandshakeAction::ClientConnect( ClientConnectInfo { service_id: hshake.service_id, diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 642d9d2..0daec59 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -1,6 +1,6 @@ //! Worker thread implementation use crate::{ - auth::{Domain, ServiceID, Token, TokenChecker, UserID}, + auth::{self, Domain, ServiceID, Token, TokenChecker, UserID}, config::Config, connection::{ self, @@ -390,13 +390,10 @@ impl Worker { self.handshakes.timeout_client(key_id) { for conn_id in connections.into_iter() { - if !conn_id.0.is_handshake() { - self.connections.delete(conn_id); - } + self.connections.remove(conn_id); } }; } - //TODO: reconf message to add channels Work::Recv(pkt) => { self.recv(pkt).await; } @@ -545,7 +542,6 @@ impl Worker { return; } self.send_packet(raw_out, udp.src, udp.dst).await; - return; } HandshakeAction::ClientConnect(cci) => { let ds_resp; @@ -573,19 +569,19 @@ impl Worker { let id_recv = conn.id_recv; let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server - if self.connections.track(Rc::new(conn)).is_err() { + if self.connections.track(conn.into()).is_err() { ::tracing::error!("Could not track new connection"); - self.connections.delete(id_recv); + self.connections.remove(id_recv); return; } - if id_recv.0 == resp_data.service_connection_id { + if cci.service_id == auth::SERVICEID_AUTH { // the user asked a single connection // to the authentication server, without any additional // service. No more connections to setup return; } // create and track the connection to the service - // SECURITY: + // SECURITY: xor with secrets //FIXME: the Secret should be XORed with the client stored // secret (if any) let hkdf = Hkdf::new( @@ -603,13 +599,10 @@ impl Worker { service_connection.id_send = IDSend(resp_data.service_connection_id); let _ = self.connections.track(service_connection.into()); - return; } - _ => {} + HandshakeAction::Nonthing => {} }; } - // copy packet, spawn - todo!(); } async fn send_packet( &self, From 787e11e8e42bdc7b9c69e9b3d7b67f532b9a9df1 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Wed, 7 Jun 2023 11:07:46 +0200 Subject: [PATCH 23/34] Fixes for Hati Signed-off-by: Luca Fulchir --- src/enc/asym.rs | 12 ++++++++++-- src/lib.rs | 1 + 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/enc/asym.rs b/src/enc/asym.rs index ea67ac3..4b6c09f 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -282,11 +282,19 @@ impl PrivKey { PrivKey::Signing => todo!(), } } + /// return the kind of public key + pub fn kind(&self) -> KeyKind { + match self { + PrivKey::Signing => todo!(), + PrivKey::Exchange(ex) => ex.kind(), + } + } /// serialize the key into the buffer /// NOTE: Assumes there is enough space pub fn serialize_into(&self, out: &mut [u8]) { + out[0] = self.kind() as u8; match self { - PrivKey::Exchange(ex) => ex.serialize_into(out), + PrivKey::Exchange(ex) => ex.serialize_into(&mut out[1..]), PrivKey::Signing => todo!(), } } @@ -336,7 +344,7 @@ impl ExchangePrivKey { pub fn serialize_into(&self, out: &mut [u8]) { match self { ExchangePrivKey::X25519(key) => { - out[..32].copy_from_slice(&key.to_bytes()); + out[0..32].copy_from_slice(&key.to_bytes()); } } } diff --git a/src/lib.rs b/src/lib.rs index 508c29f..cac07b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -354,6 +354,7 @@ impl Fenrir { } } + // TODO: start work on a LocalSet provided by the user /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. /// Work will be divided and rerouted so that there is no need to lock From 55e10a60c696583edfa758a5519e1c0b4244bb60 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 9 Jun 2023 14:55:49 +0200 Subject: [PATCH 24/34] Fix Dnssec record serializing/deserializing Signed-off-by: Luca Fulchir --- flake.nix | 2 + src/connection/handshake/mod.rs | 32 +++----- src/dnssec/mod.rs | 42 +++++++++- src/dnssec/record.rs | 136 +++++++++++++++++++------------- src/enc/asym.rs | 16 ++-- src/enc/errors.rs | 4 +- src/inner/worker.rs | 29 +++++-- 7 files changed, 172 insertions(+), 89 deletions(-) diff --git a/flake.nix b/flake.nix index fd0713c..7a74671 100644 --- a/flake.nix +++ b/flake.nix @@ -43,6 +43,8 @@ rust-bin.stable."1.69.0".default rustfmt rust-analyzer + # fenrir deps + hwloc ]; shellHook = '' # use zsh or other custom shell diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 427bff4..14ba8fd 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -42,36 +42,28 @@ pub enum Error { } /// List of possible handshakes -#[derive(::num_derive::FromPrimitive, Debug, Clone, Copy, PartialEq)] +#[derive( + ::num_derive::FromPrimitive, + Debug, + Clone, + Copy, + PartialEq, + ::strum_macros::EnumString, + ::strum_macros::IntoStaticStr, +)] #[repr(u8)] pub enum HandshakeID { /// 1-RTT Directory synchronized handshake. Fast, no forward secrecy + #[strum(serialize = "directory_synchronized")] DirectorySynchronized = 0, /// 2-RTT Stateful exchange. Little DDos protection + #[strum(serialize = "stateful")] Stateful, /// 3-RTT stateless exchange. Forward secrecy and ddos protection + #[strum(serialize = "stateless")] Stateless, } -impl TryFrom<&str> for HandshakeID { - type Error = ::std::io::Error; - // TODO: from actual names, not only numeric - fn try_from(raw: &str) -> Result { - if let Ok(handshake_u8) = raw.parse::() { - if handshake_u8 >= 1 { - if let Some(handshake) = HandshakeID::from_u8(handshake_u8 - 1) - { - return Ok(handshake); - } - } - } - return Err(::std::io::Error::new( - ::std::io::ErrorKind::InvalidData, - "Unknown handshake ID", - )); - } -} - pub(crate) struct HandshakeServer { pub id: KeyID, pub key: PrivKey, diff --git a/src/dnssec/mod.rs b/src/dnssec/mod.rs index 143863c..bb479ba 100644 --- a/src/dnssec/mod.rs +++ b/src/dnssec/mod.rs @@ -21,6 +21,9 @@ pub enum Error { /// Errors in establishing connection or connection drops #[error("nameserver connection: {0:?}")] NameserverConnection(String), + /// record is not valid base85 + #[error("invalid base85")] + InvalidBase85, } #[cfg(any( @@ -137,9 +140,46 @@ impl Dnssec { ::tracing::error!("Can't parse record: {}", e); Err(::std::io::Error::new( ::std::io::ErrorKind::InvalidData, - "Can't parse the record", + "Can't parse the record: ".to_owned() + &e.to_string(), )) } }; } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialization() { + // The record was generated with: + // f-dnssec generate dnssec \ + // -a 1 2 42 directory_synchronized 127.0.0.1 31337 \ + // -p 42 x25519 x25519.pub \ + // -x x25519diffiehellman \ + // -c xchacha20poly1305 + const TXT_RECORD: &'static str = "v=Fenrir1 \ + 5fBgo5ovk=0Dk}g0V)6>0cKP8KO-Vna846zp@MaLF|nim_XH&nQvT-I|B9HfJpcd"; + + let record = match Dnssec::parse_txt_record(TXT_RECORD) { + Ok(record) => record, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + let re_encoded = match record.encode() { + Ok(re_encoded) => re_encoded, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + assert!( + TXT_RECORD[10..] == re_encoded, + "DNSSEC record decoding->encoding failed:\n{}\n{}", + TXT_RECORD, + re_encoded + ); + } +} diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index ab5065b..f5595e2 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -167,7 +167,7 @@ impl TryFrom<&str> for AddressWeight { /// * priority /// * weight within priority /// * list of supported handshakes IDs -/// * list of public keys IDs +/// * list of public keys. Indexes in the Record.public_keys #[derive(Debug, Clone)] pub struct Address { /// Ip address of server, v4 or v6 @@ -197,18 +197,17 @@ impl Address { None => None, } } - fn raw_len(&self) -> usize { - // UDP port + Priority + Weight + pubkey_len + handshake_len - let mut size = 6; + fn len(&self) -> usize { + let mut size = 4; let num_pubkey_idx = self.public_key_idx.len(); let idx_bytes = (num_pubkey_idx / 2) + (num_pubkey_idx % 2); size = size + idx_bytes + self.handshake_ids.len(); size + match self.ip { - IpAddr::V4(_) => size + 4, - IpAddr::V6(_) => size + 16, + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 16, } } - fn encode_into(&self, raw: &mut Vec) { + fn serialize_into(&self, raw: &mut [u8]) { let mut bitfield: u8 = match self.ip { IpAddr::V4(_) => 0, IpAddr::V6(_) => 1, @@ -218,18 +217,19 @@ impl Address { bitfield <<= 3; bitfield |= self.weight as u8; - raw.push(bitfield); + raw[0] = bitfield; let len_combined: u8 = self.public_key_idx.len() as u8; let len_combined = len_combined << 4; let len_combined = len_combined | self.handshake_ids.len() as u8; - raw.push(len_combined); + raw[1] = len_combined; - raw.extend_from_slice( + raw[2..4].copy_from_slice( &(match self.port { Some(port) => port.get().to_le_bytes(), None => [0, 0], // oh noez, which zero goes first? }), ); + let mut written: usize = 4; // pair every idx, since the max is 16 for chunk in self.public_key_idx.chunks(2) { @@ -242,34 +242,51 @@ impl Address { }; let tmp = chunk[0].0 << 4; let tmp = tmp | second; - raw.push(tmp); + raw[written] = tmp; + written = written + 1; } for id in self.handshake_ids.iter() { - raw.push(*id as u8); + raw[written] = *id as u8; + written = written + 1; } + let next_written; match self.ip { IpAddr::V4(ip) => { + next_written = written + 4; let raw_ip = ip.octets(); - raw.extend_from_slice(&raw_ip); + raw[written..next_written].copy_from_slice(&raw_ip); } IpAddr::V6(ip) => { + next_written = written + 16; let raw_ip = ip.octets(); - raw.extend_from_slice(&raw_ip); + raw[written..next_written].copy_from_slice(&raw_ip); } }; + assert!( + next_written == raw.len(), + "write how much? {} - {}", + next_written, + raw.len() + ); } fn decode_raw(raw: &[u8]) -> Result<(Self, usize), Error> { - if raw.len() < 10 { + if raw.len() < 9 { return Err(Error::NotEnoughData(0)); } + // 3-byte bitfield let ip_type = raw[0] >> 6; let is_ipv6: bool; + let ip_len: usize; match ip_type { 0 => { is_ipv6 = false; + ip_len = 4; + } + 1 => { + is_ipv6 = true; + ip_len = 16; } - 1 => is_ipv6 = true, _ => return Err(Error::UnsupportedData(0)), } let raw_priority = (raw[0] << 2) >> 5; @@ -277,18 +294,19 @@ impl Address { let priority = AddressPriority::from_u8(raw_priority).unwrap(); let weight = AddressWeight::from_u8(raw_weight).unwrap(); + // Add publickey ids + let num_pubkey_idx = (raw[1] >> 4) as usize; + let num_handshake_ids = (raw[1] & 0x0F) as usize; + // UDP port - let raw_port = u16::from_le_bytes([raw[1], raw[2]]); + let raw_port = u16::from_le_bytes([raw[2], raw[3]]); let port = if raw_port == 0 { None } else { Some(NonZeroU16::new(raw_port).unwrap()) }; - // Add publickey ids - let num_pubkey_idx = (raw[3] >> 4) as usize; - let num_handshake_ids = (raw[3] & 0x0F) as usize; - if raw.len() <= 3 + num_pubkey_idx + num_handshake_ids { + if raw.len() <= 3 + num_pubkey_idx + num_handshake_ids + ip_len { return Err(Error::NotEnoughData(3)); } let mut bytes_parsed = 4; @@ -306,9 +324,9 @@ impl Address { } idx_added = idx_added + 2; } + bytes_parsed = bytes_parsed + idx_bytes; // add handshake ids - bytes_parsed = bytes_parsed + idx_bytes; let mut handshake_ids = Vec::with_capacity(num_handshake_ids); for raw_handshake_id in raw[bytes_parsed..(bytes_parsed + num_handshake_ids)].iter() @@ -407,57 +425,69 @@ impl Record { // everything else is all good let total_size: usize = 3 - + self.addresses.iter().map(|a| a.raw_len()).sum::() + + self.addresses.iter().map(|a| a.len()).sum::() + self .public_keys .iter() - .map(|(_, key)| 3 + key.kind().pub_len()) + .map(|(_, key)| 4 + key.kind().pub_len()) .sum::() + self.key_exchanges.len() + self.hkdfs.len() + self.ciphers.len(); let mut raw = Vec::with_capacity(total_size); + raw.resize(total_size, 0); // amount of data. addresses, then pubkeys. 4 bits each let len_combined: u8 = self.addresses.len() as u8; let len_combined = len_combined << 4; let len_combined = len_combined | self.public_keys.len() as u8; - raw.push(len_combined); + raw[0] = len_combined; // number of key exchanges and hkdfs let len_combined: u8 = self.key_exchanges.len() as u8; let len_combined = len_combined << 4; let len_combined = len_combined | self.hkdfs.len() as u8; - raw.push(len_combined); + raw[1] = len_combined; let num_of_ciphers: u8 = (self.ciphers.len() as u8) << 4; - raw.push(num_of_ciphers); + raw[2] = num_of_ciphers; + let mut written: usize = 3; for address in self.addresses.iter() { - address.encode_into(&mut raw); + let len = address.len(); + let written_next = written + len; + address.serialize_into(&mut raw[written..written_next]); + written = written_next; } for (public_key_id, public_key) in self.public_keys.iter() { let key_id_bytes = public_key_id.0.to_le_bytes(); - raw.extend_from_slice(&key_id_bytes); - raw.push(public_key.kind().pub_len() as u8); - raw.push(public_key.kind() as u8); - public_key.serialize_into(&mut raw); + let written_next = written + KeyID::len(); + raw[written..written_next].copy_from_slice(&key_id_bytes); + written = written_next; + raw[written] = public_key.kind().pub_len() as u8; + written = written + 1; + let written_next = written + public_key.len(); + public_key.serialize_into(&mut raw[written..written_next]); + written = written_next; } for k_x in self.key_exchanges.iter() { - raw.push(*k_x as u8); + raw[written] = *k_x as u8; + written = written + 1; } for h in self.hkdfs.iter() { - raw.push(*h as u8); + raw[written] = *h as u8; + written = written + 1; } for c in self.ciphers.iter() { - raw.push(*c as u8); + raw[written] = *c as u8; + written = written + 1; } Ok(::base85::encode(&raw)) } /// Decode from base85 to the actual object pub fn decode(raw: &[u8]) -> Result { - // bare minimum for lengths, (1 address), (1 key), cipher negotiation - const MIN_RAW_LENGTH: usize = 3 + (6 + 4) + (4 + 32) + 1 + 1 + 1; + // bare minimum for lengths, (1 address), (1 key),no cipher negotiation + const MIN_RAW_LENGTH: usize = 3 + (6 + 4) + (4 + 32); if raw.len() <= MIN_RAW_LENGTH { return Err(Error::NotEnoughData(0)); } @@ -492,6 +522,7 @@ impl Record { result.addresses.push(address); num_addresses = num_addresses - 1; } + while num_public_keys > 0 { if bytes_parsed + 3 >= raw.len() { return Err(Error::NotEnoughData(bytes_parsed)); @@ -503,24 +534,23 @@ impl Record { bytes_parsed = bytes_parsed + 2; let pubkey_length = raw[bytes_parsed] as usize; bytes_parsed = bytes_parsed + 1; - if pubkey_length + bytes_parsed >= raw.len() { + let bytes_next_key = bytes_parsed + 1 + pubkey_length; + if bytes_next_key > raw.len() { return Err(Error::NotEnoughData(bytes_parsed)); } - let (public_key, bytes) = match PubKey::deserialize( - &raw[bytes_parsed..(bytes_parsed + pubkey_length)], - ) { - Ok((public_key, bytes)) => (public_key, bytes), - Err(enc::Error::UnsupportedKey(_)) => { - // continue parsing. This could be a new pubkey type - // that is not supported by an older client - ::tracing::warn!("Unsupported public key type"); - bytes_parsed = bytes_parsed + pubkey_length; - continue; - } - Err(_) => { - return Err(Error::UnsupportedData(bytes_parsed)); - } - }; + let (public_key, bytes) = + match PubKey::deserialize(&raw[bytes_parsed..bytes_next_key]) { + Ok((public_key, bytes)) => (public_key, bytes), + Err(enc::Error::UnsupportedKey(_)) => { + // continue parsing. This could be a new pubkey type + // that is not supported by an older client + bytes_parsed = bytes_parsed + pubkey_length; + continue; + } + Err(e) => { + return Err(Error::UnsupportedData(bytes_parsed)); + } + }; if bytes != 1 + pubkey_length { return Err(Error::UnsupportedData(bytes_parsed)); } diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 4b6c09f..5d057f2 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -185,7 +185,7 @@ pub enum PubKey { impl PubKey { /// Get the serialized key length pub fn len(&self) -> usize { - match self { + 1 + match self { PubKey::Exchange(ex) => ex.len(), PubKey::Signing => todo!(), } @@ -231,7 +231,7 @@ impl PubKey { /// Try to deserialize the pubkey from raw bytes /// on success returns the public key and the number of parsed bytes pub fn deserialize(raw: &[u8]) -> Result<(Self, usize), Error> { - if raw.len() < 1 + MIN_KEY_SIZE { + if raw.len() < 1 { return Err(Error::NotEnoughData(0)); } let kind: KeyKind = match KeyKind::from_u8(raw[0]) { @@ -248,14 +248,18 @@ impl PubKey { } KeyKind::X25519 => { let pub_key: ::x25519_dalek::PublicKey = - match ::bincode::deserialize(&raw[1..(1 + kind.pub_len())]) + //match ::bincode::deserialize(&raw[1..(1 + kind.pub_len())]) + match ::bincode::deserialize(&raw[1..]) { Ok(pub_key) => pub_key, - Err(_) => return Err(Error::Parsing), + Err(e) => { + ::tracing::error!("x25519 deserialize: {}", e); + return Err(Error::Parsing); + } }; Ok(( PubKey::Exchange(ExchangePubKey::X25519(pub_key)), - kind.pub_len(), + 1 + kind.pub_len(), )) } } @@ -277,7 +281,7 @@ pub enum PrivKey { impl PrivKey { /// Get the serialized key length pub fn len(&self) -> usize { - match self { + 1 + match self { PrivKey::Exchange(ex) => ex.len(), PrivKey::Signing => todo!(), } diff --git a/src/enc/errors.rs b/src/enc/errors.rs index c0591b0..a394406 100644 --- a/src/enc/errors.rs +++ b/src/enc/errors.rs @@ -7,13 +7,13 @@ pub enum Error { #[error("can't parse key")] Parsing, /// Not enough data - #[error("not enough data")] + #[error("not enough data: {0}")] NotEnoughData(usize), /// buffer too small #[error("buffer too small")] InsufficientBuffer, /// Unsupported Key type found. - #[error("unsupported key type")] + #[error("unsupported key type: {0}")] UnsupportedKey(usize), /// Unsupported key exchange for this key #[error("unsupported key exchange")] diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 0daec59..6d77c03 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -14,7 +14,7 @@ use crate::{ }, dnssec, enc::{ - asym::{KeyID, PrivKey, PubKey}, + asym::{self, KeyID, PrivKey, PubKey}, hkdf::{self, Hkdf, HkdfKind}, sym, Random, Secret, }, @@ -202,18 +202,33 @@ impl Worker { // that supports one of the key exchanges that // *we* support for idx in addr.public_key_idx.iter() { + // for each key, + // get all the possible key exchanges let key_supported_k_x = conn_info.resolved.public_keys [idx.0 as usize] .1 .kind() .key_exchanges(); - match self - .cfg - .key_exchanges - .iter() - .find(|x| key_supported_k_x.contains(x)) - { + // consider only the key exchanges allowed + // in the dnssec record + let filtered_key_exchanges: Vec< + asym::KeyExchangeKind, + > = key_supported_k_x + .into_iter() + .filter(|k_x| { + conn_info + .resolved + .key_exchanges + .contains(k_x) + }) + .copied() + .collect(); + + // finally make sure that we support those + match self.cfg.key_exchanges.iter().find(|x| { + filtered_key_exchanges.contains(x) + }) { Some(exchange) => { return Some(( addr, From 5625bd95a4ea07abe645a0d2d7a6d02be1ecd051 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 9 Jun 2023 19:06:58 +0200 Subject: [PATCH 25/34] Test request serialization Signed-off-by: Luca Fulchir --- src/auth/mod.rs | 6 +-- src/connection/handshake/dirsync.rs | 55 +++++++++++++++++++++--- src/connection/handshake/mod.rs | 31 ++++++++++---- src/connection/handshake/tests.rs | 65 +++++++++++++++++++++++++++++ src/dnssec/mod.rs | 54 +++++++++++++++++------- src/dnssec/record.rs | 10 ++--- src/enc/asym.rs | 47 ++++++++++----------- 7 files changed, 205 insertions(+), 63 deletions(-) create mode 100644 src/connection/handshake/tests.rs diff --git a/src/auth/mod.rs b/src/auth/mod.rs index aef9cbc..20df6f7 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -5,7 +5,7 @@ use ::zeroize::Zeroize; /// User identifier. 16 bytes for easy uuid conversion #[derive(Debug, Copy, Clone)] -pub struct UserID([u8; 16]); +pub struct UserID(pub [u8; 16]); impl From<[u8; 16]> for UserID { fn from(raw: [u8; 16]) -> Self { @@ -36,7 +36,7 @@ impl UserID { /// Authentication Token, basically just 32 random bytes #[derive(Clone, Zeroize)] #[zeroize(drop)] -pub struct Token([u8; 32]); +pub struct Token(pub [u8; 32]); impl Token { /// New random token, anonymous should not check this anyway @@ -110,7 +110,7 @@ impl Domain { pub const SERVICEID_AUTH: ServiceID = ServiceID([0; 16]); /// The Service ID is a UUID associated with the service. #[derive(Debug, Copy, Clone, PartialEq)] -pub struct ServiceID([u8; 16]); +pub struct ServiceID(pub [u8; 16]); impl From<[u8; 16]> for ServiceID { fn from(raw: [u8; 16]) -> Self { diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 031f789..e387d2c 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -97,7 +97,9 @@ pub struct Req { pub exchange_key: ExchangePubKey, /// encrypted data pub data: ReqInner, - // Security: Add padding to min: 1200 bytes to avoid amplification attaks + // SECURITY: TODO: Add padding to min: 1200 bytes + // to avoid amplification attaks + // also: 1200 < 1280 to allow better vpn compatibility } impl Req { @@ -125,7 +127,9 @@ impl Req { + HkdfKind::len() + CipherKind::len() + self.exchange_key.kind().pub_len() + + self.cipher.nonce_len().0 + self.data.len() + + self.cipher.tag_len().0 } /// Serialize into raw bytes /// NOTE: assumes that there is exactly as much buffer as needed @@ -135,8 +139,21 @@ impl Req { tag_len: TagLen, out: &mut [u8], ) { - //assert!(out.len() > , ": not enough buffer to serialize"); - todo!() + out[0..2].copy_from_slice(&self.key_id.0.to_le_bytes()); + out[2] = self.exchange as u8; + out[3] = self.hkdf as u8; + out[4] = self.cipher as u8; + let key_len = self.exchange_key.len(); + let written_next = 5 + key_len; + self.exchange_key.serialize_into(&mut out[5..written_next]); + let written = written_next; + if let ReqInner::ClearText(data) = &self.data { + let from = written + head_len.0; + let to = out.len() - tag_len.0; + data.serialize(&mut out[from..to]); + } else { + unreachable!(); + } } } @@ -147,7 +164,7 @@ impl super::HandshakeParsing for Req { return Err(Error::NotEnoughData); } let key_id: KeyID = - KeyID(u16::from_le_bytes(raw[0..1].try_into().unwrap())); + KeyID(u16::from_le_bytes(raw[0..2].try_into().unwrap())); use ::num_traits::FromPrimitive; let exchange: KeyExchangeKind = match KeyExchangeKind::from_u8(raw[2]) { Some(exchange) => exchange, @@ -161,7 +178,7 @@ impl super::HandshakeParsing for Req { Some(cipher) => cipher, None => return Err(Error::Parsing), }; - let (exchange_key, len) = match ExchangePubKey::from_slice(&raw[5..]) { + let (exchange_key, len) = match ExchangePubKey::deserialize(&raw[5..]) { Ok(exchange_key) => exchange_key, Err(e) => return Err(e.into()), }; @@ -235,6 +252,21 @@ impl AuthInfo { pub fn len(&self) -> usize { Self::MIN_PKT_LEN + self.domain.len() } + /// serialize into a buffer + /// Note: assumes there is enough space + pub fn serialize(&self, out: &mut [u8]) { + out[..auth::UserID::len()].copy_from_slice(&self.user.0); + const WRITTEN_TOKEN: usize = auth::UserID::len() + auth::Token::len(); + out[auth::UserID::len()..WRITTEN_TOKEN].copy_from_slice(&self.token.0); + const WRITTEN_SERVICE_ID: usize = + WRITTEN_TOKEN + auth::ServiceID::len(); + out[WRITTEN_TOKEN..WRITTEN_SERVICE_ID] + .copy_from_slice(&self.service_id.0); + let domain_len = self.domain.0.as_bytes().len() as u8; + out[WRITTEN_SERVICE_ID] = domain_len; + const WRITTEN_DOMAIN_LEN: usize = WRITTEN_SERVICE_ID + 1; + out[WRITTEN_DOMAIN_LEN..].copy_from_slice(&self.domain.0.as_bytes()); + } /// deserialize from raw bytes pub fn deserialize(raw: &[u8]) -> Result { if raw.len() < Self::MIN_PKT_LEN { @@ -295,6 +327,19 @@ impl ReqData { /// Minimum byte length of the request data pub const MIN_PKT_LEN: usize = 16 + KeyID::len() + ID::len() + AuthInfo::MIN_PKT_LEN; + /// serialize into a buffer + /// Note: assumes there is enough space + pub fn serialize(&self, out: &mut [u8]) { + out[..Nonce::len()].copy_from_slice(&self.nonce.0); + const WRITTEN_KEY: usize = Nonce::len() + KeyID::len(); + out[Nonce::len()..WRITTEN_KEY] + .copy_from_slice(&self.client_key_id.0.to_le_bytes()); + const WRITTEN: usize = WRITTEN_KEY; + const WRITTEN_ID: usize = WRITTEN + 8; + out[WRITTEN..WRITTEN_ID] + .copy_from_slice(&self.id.as_u64().to_le_bytes()); + self.auth.serialize(&mut out[WRITTEN_ID..]); + } /// Parse the cleartext raw data pub fn deserialize(raw: &[u8]) -> Result { if raw.len() < Self::MIN_PKT_LEN { diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 14ba8fd..bd0b501 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -1,6 +1,8 @@ //! Handhsake handling pub mod dirsync; +#[cfg(test)] +mod tests; use crate::{ auth::ServiceID, @@ -194,7 +196,7 @@ impl HandshakeData { /// Kind of handshake #[derive(::num_derive::FromPrimitive, Debug, Clone, Copy)] #[repr(u8)] -pub enum Kind { +pub enum HandshakeKind { /// 1-RTT, Directory synchronized handshake /// Request DirSyncReq = 0, @@ -210,6 +212,12 @@ pub enum Kind { .... */ } +impl HandshakeKind { + /// Length of the serialized field + pub const fn len() -> usize { + 1 + } +} /// Parsed handshake #[derive(Debug, Clone)] @@ -230,7 +238,7 @@ impl Handshake { } /// return the total length of the handshake pub fn len(&self) -> usize { - ProtocolVersion::len() + self.data.len() + ProtocolVersion::len() + HandshakeKind::len() + self.data.len() } const MIN_PKT_LEN: usize = 8; /// Parse the packet and return the parsed handshake @@ -242,13 +250,15 @@ impl Handshake { Some(fenrir_version) => fenrir_version, None => return Err(Error::Parsing), }; - let handshake_kind = match Kind::from_u8(raw[1]) { + let handshake_kind = match HandshakeKind::from_u8(raw[1]) { Some(handshake_kind) => handshake_kind, None => return Err(Error::Parsing), }; let data = match handshake_kind { - Kind::DirSyncReq => dirsync::Req::deserialize(&raw[2..])?, - Kind::DirSyncResp => dirsync::Resp::deserialize(&raw[2..])?, + HandshakeKind::DirSyncReq => dirsync::Req::deserialize(&raw[2..])?, + HandshakeKind::DirSyncResp => { + dirsync::Resp::deserialize(&raw[2..])? + } }; Ok(Self { fenrir_version, @@ -263,9 +273,14 @@ impl Handshake { tag_len: TagLen, out: &mut [u8], ) { - assert!(out.len() > 1, "Handshake: not enough buffer to serialize"); - self.fenrir_version.serialize(&mut out[0]); - self.data.serialize(head_len, tag_len, &mut out[1..]); + out[0] = self.fenrir_version as u8; + out[1] = match &self.data { + HandshakeData::DirSync(d) => match d { + dirsync::DirSync::Req(_) => HandshakeKind::DirSyncReq, + dirsync::DirSync::Resp(_) => HandshakeKind::DirSyncResp, + }, + } as u8; + self.data.serialize(head_len, tag_len, &mut out[2..]); } } diff --git a/src/connection/handshake/tests.rs b/src/connection/handshake/tests.rs new file mode 100644 index 0000000..df485d7 --- /dev/null +++ b/src/connection/handshake/tests.rs @@ -0,0 +1,65 @@ +use crate::{ + auth, + connection::{handshake::*, ID}, + enc, +}; + +#[test] +fn test_handshake_dirsync_req() { + let rand = enc::Random::new(); + let secret = enc::Secret::new_rand(&rand); + let cipher_send = enc::sym::CipherSend::new( + enc::sym::CipherKind::XChaCha20Poly1305, + secret, + &rand, + ); + + let (_, exchange_key) = + match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand) + { + Ok(pair) => pair, + Err(_) => { + assert!(false, "Can't generate random keypair"); + return; + } + }; + + let data = dirsync::ReqInner::ClearText(dirsync::ReqData { + nonce: dirsync::Nonce::new(&rand), + client_key_id: KeyID(2424), + id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()), + auth: dirsync::AuthInfo { + user: auth::UserID::new(&rand), + token: auth::Token::new_anonymous(&rand), + service_id: auth::SERVICEID_AUTH, + domain: auth::Domain("example.com".to_owned()), + }, + }); + + let h_req = Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Req( + dirsync::Req { + key_id: KeyID(4224), + exchange: enc::asym::KeyExchangeKind::X25519DiffieHellman, + hkdf: enc::hkdf::HkdfKind::Sha3, + cipher: enc::sym::CipherKind::XChaCha20Poly1305, + exchange_key, + data, + }, + ))); + + let mut bytes = Vec::::with_capacity(h_req.len()); + bytes.resize(h_req.len(), 0); + h_req.serialize( + cipher_send.kind().nonce_len(), + cipher_send.kind().tag_len(), + &mut bytes, + ); + + let deserialized = match Handshake::deserialize(&bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; +} diff --git a/src/dnssec/mod.rs b/src/dnssec/mod.rs index bb479ba..bc0f661 100644 --- a/src/dnssec/mod.rs +++ b/src/dnssec/mod.rs @@ -152,34 +152,56 @@ mod tests { #[test] fn test_serialization() { - // The record was generated with: - // f-dnssec generate dnssec \ - // -a 1 2 42 directory_synchronized 127.0.0.1 31337 \ - // -p 42 x25519 x25519.pub \ - // -x x25519diffiehellman \ - // -c xchacha20poly1305 - const TXT_RECORD: &'static str = "v=Fenrir1 \ - 5fBgo5ovk=0Dk}g0V)6>0cKP8KO-Vna846zp@MaLF|nim_XH&nQvT-I|B9HfJpcd"; + let rand = enc::Random::new(); + let (_, exchange_key) = + match enc::asym::KeyExchangeKind::X25519DiffieHellman + .new_keypair(&rand) + { + Ok(pair) => pair, + Err(_) => { + assert!(false, "Can't generate random keypair"); + return; + } + }; + use crate::enc; + let record = Record { + public_keys : [(enc::asym::KeyID(42), + enc::asym::PubKey::Exchange(exchange_key))].to_vec(), + addresses: [record::Address { + ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127,0,0,1)), + port: Some(::core::num::NonZeroU16::new(31337).unwrap()), + priority: record::AddressPriority::P1, + weight: record::AddressWeight::W1, + handshake_ids: [crate::connection::handshake::HandshakeID::DirectorySynchronized].to_vec(), + public_key_idx : [record::PubKeyIdx(0)].to_vec(), - let record = match Dnssec::parse_txt_record(TXT_RECORD) { + }].to_vec(), + key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman].to_vec(), + hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), + ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), + + }; + let encoded = match record.encode() { + Ok(encoded) => encoded, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + let full_record = "v=Fenrir1 ".to_string() + &encoded; + let record = match Dnssec::parse_txt_record(&full_record) { Ok(record) => record, Err(e) => { assert!(false, "{}", e.to_string()); return; } }; - let re_encoded = match record.encode() { + let _re_encoded = match record.encode() { Ok(re_encoded) => re_encoded, Err(e) => { assert!(false, "{}", e.to_string()); return; } }; - assert!( - TXT_RECORD[10..] == re_encoded, - "DNSSEC record decoding->encoding failed:\n{}\n{}", - TXT_RECORD, - re_encoded - ); } } diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index f5595e2..b3fdbec 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -429,7 +429,7 @@ impl Record { + self .public_keys .iter() - .map(|(_, key)| 4 + key.kind().pub_len()) + .map(|(_, key)| 3 + key.kind().pub_len()) .sum::() + self.key_exchanges.len() + self.hkdfs.len() @@ -463,7 +463,7 @@ impl Record { let written_next = written + KeyID::len(); raw[written..written_next].copy_from_slice(&key_id_bytes); written = written_next; - raw[written] = public_key.kind().pub_len() as u8; + raw[written] = public_key.len() as u8; written = written + 1; let written_next = written + public_key.len(); public_key.serialize_into(&mut raw[written..written_next]); @@ -531,10 +531,10 @@ impl Record { let raw_key_id = u16::from_le_bytes([raw[bytes_parsed], raw[bytes_parsed + 1]]); let id = KeyID(raw_key_id); - bytes_parsed = bytes_parsed + 2; + bytes_parsed = bytes_parsed + KeyID::len(); let pubkey_length = raw[bytes_parsed] as usize; bytes_parsed = bytes_parsed + 1; - let bytes_next_key = bytes_parsed + 1 + pubkey_length; + let bytes_next_key = bytes_parsed + pubkey_length; if bytes_next_key > raw.len() { return Err(Error::NotEnoughData(bytes_parsed)); } @@ -551,7 +551,7 @@ impl Record { return Err(Error::UnsupportedData(bytes_parsed)); } }; - if bytes != 1 + pubkey_length { + if bytes != pubkey_length { return Err(Error::UnsupportedData(bytes_parsed)); } bytes_parsed = bytes_parsed + bytes; diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 5d057f2..16201cf 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -93,16 +93,19 @@ pub enum KeyKind { #[strum(serialize = "x25519")] X25519, } -// FIXME: actually check this -const MIN_KEY_SIZE: usize = 32; impl KeyKind { + /// Length of the serialized field + pub const fn len() -> usize { + 1 + } /// return the expected length of the public key pub fn pub_len(&self) -> usize { - match self { - // FIXME: 99% wrong size - KeyKind::Ed25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, - KeyKind::X25519 => 32, - } + KeyKind::len() + + match self { + // FIXME: 99% wrong size + KeyKind::Ed25519 => ::ring::signature::ED25519_PUBLIC_KEY_LEN, + KeyKind::X25519 => 32, + } } /// Get the capabilities of this key type pub fn capabilities(&self) -> KeyCapabilities { @@ -185,7 +188,7 @@ pub enum PubKey { impl PubKey { /// Get the serialized key length pub fn len(&self) -> usize { - 1 + match self { + match self { PubKey::Exchange(ex) => ex.len(), PubKey::Signing => todo!(), } @@ -215,17 +218,12 @@ impl PubKey { /// serialize the key into the buffer /// NOTE: Assumes there is enough space pub fn serialize_into(&self, out: &mut [u8]) { - assert!( - out.len() >= 1 + self.kind().pub_len(), - "Not enough out buffer", - ); - out[0] = self.kind() as u8; match self { PubKey::Signing => { ::tracing::error!("serializing ed25519 not supported"); return; } - PubKey::Exchange(ex) => ex.serialize_into(&mut out[1..]), + PubKey::Exchange(ex) => ex.serialize_into(out), } } /// Try to deserialize the pubkey from raw bytes @@ -238,7 +236,7 @@ impl PubKey { Some(kind) => kind, None => return Err(Error::UnsupportedKey(1)), }; - if raw.len() < 1 + kind.pub_len() { + if raw.len() < kind.pub_len() { return Err(Error::NotEnoughData(1)); } match kind { @@ -259,7 +257,7 @@ impl PubKey { }; Ok(( PubKey::Exchange(ExchangePubKey::X25519(pub_key)), - 1 + kind.pub_len(), + kind.pub_len(), )) } } @@ -281,7 +279,7 @@ pub enum PrivKey { impl PrivKey { /// Get the serialized key length pub fn len(&self) -> usize { - 1 + match self { + match self { PrivKey::Exchange(ex) => ex.len(), PrivKey::Signing => todo!(), } @@ -296,9 +294,8 @@ impl PrivKey { /// serialize the key into the buffer /// NOTE: Assumes there is enough space pub fn serialize_into(&self, out: &mut [u8]) { - out[0] = self.kind() as u8; match self { - PrivKey::Exchange(ex) => ex.serialize_into(&mut out[1..]), + PrivKey::Exchange(ex) => ex.serialize_into(out), PrivKey::Signing => todo!(), } } @@ -346,9 +343,10 @@ impl ExchangePrivKey { /// serialize the key into the buffer /// NOTE: Assumes there is enough space pub fn serialize_into(&self, out: &mut [u8]) { + out[0] = self.kind() as u8; match self { ExchangePrivKey::X25519(key) => { - out[0..32].copy_from_slice(&key.to_bytes()); + out[1..33].copy_from_slice(&key.to_bytes()); } } } @@ -378,21 +376,18 @@ impl ExchangePubKey { /// serialize the key into the buffer /// NOTE: Assumes there is enough space pub fn serialize_into(&self, out: &mut [u8]) { + out[0] = self.kind() as u8; match self { ExchangePubKey::X25519(pk) => { let bytes = pk.as_bytes(); - assert!(bytes.len() == 32, "x25519 should have been 32 bytes"); - out[..32].copy_from_slice(bytes); + out[1..33].copy_from_slice(bytes); } } } /// Load public key used for key exchange from it raw bytes /// The riesult is "unparsed" since we don't verify /// the actual key - pub fn from_slice(raw: &[u8]) -> Result<(Self, usize), Error> { - if raw.len() < 1 + MIN_KEY_SIZE { - return Err(Error::NotEnoughData(0)); - } + pub fn deserialize(raw: &[u8]) -> Result<(Self, usize), Error> { match KeyKind::from_u8(raw[0]) { Some(kind) => match kind { KeyKind::Ed25519 => { From a6fda8180d2f63cb13690ff8a49b50e6fa7b6ed9 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 9 Jun 2023 19:09:41 +0200 Subject: [PATCH 26/34] DNSSEC: move keys before addresses it was kinda stupid to keep the keys *after* the addresses but have the addresses keep an index to the array of pubkeys anyway Signed-off-by: Luca Fulchir --- src/dnssec/record.rs | 87 ++++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 44 deletions(-) diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index b3fdbec..6a49a3a 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -13,6 +13,12 @@ //! * 1 byte: divided in half: //! * half: number of ciphers //! * half: nothing +//! [ # list of pubkeys (max: 16) +//! * 2 byte: pubkey id +//! * 1 byte: pubkey length +//! * 1 byte: pubkey type +//! * Y bytes: pubkey +//! ] //! [ # list of addresses //! * 1 byte: bitfield //! * 0..1 ipv4/ipv6 @@ -26,12 +32,6 @@ //! * [ 1 byte per handshake id ] //! * X bytes: IP //! ] -//! [ # list of pubkeys (max: 16) -//! * 2 byte: pubkey id -//! * 1 byte: pubkey length -//! * 1 byte: pubkey type -//! * Y bytes: pubkey -//! ] //! [ # list of supported key exchanges //! * 1 byte for each cipher //! ] @@ -452,12 +452,6 @@ impl Record { raw[2] = num_of_ciphers; let mut written: usize = 3; - for address in self.addresses.iter() { - let len = address.len(); - let written_next = written + len; - address.serialize_into(&mut raw[written..written_next]); - written = written_next; - } for (public_key_id, public_key) in self.public_keys.iter() { let key_id_bytes = public_key_id.0.to_le_bytes(); let written_next = written + KeyID::len(); @@ -469,6 +463,12 @@ impl Record { public_key.serialize_into(&mut raw[written..written_next]); written = written_next; } + for address in self.addresses.iter() { + let len = address.len(); + let written_next = written + len; + address.serialize_into(&mut raw[written..written_next]); + written = written_next; + } for k_x in self.key_exchanges.iter() { raw[written] = *k_x as u8; written = written + 1; @@ -506,23 +506,6 @@ impl Record { ciphers: Vec::with_capacity(num_ciphers), }; - while num_addresses > 0 { - let (address, bytes) = - match Address::decode_raw(&raw[bytes_parsed..]) { - Ok(address) => address, - Err(Error::UnsupportedData(b)) => { - return Err(Error::UnsupportedData(bytes_parsed + b)) - } - Err(Error::NotEnoughData(b)) => { - return Err(Error::NotEnoughData(bytes_parsed + b)) - } - Err(e) => return Err(e), - }; - bytes_parsed = bytes_parsed + bytes; - result.addresses.push(address); - num_addresses = num_addresses - 1; - } - while num_public_keys > 0 { if bytes_parsed + 3 >= raw.len() { return Err(Error::NotEnoughData(bytes_parsed)); @@ -558,6 +541,37 @@ impl Record { result.public_keys.push((id, public_key)); num_public_keys = num_public_keys - 1; } + while num_addresses > 0 { + let (address, bytes) = + match Address::decode_raw(&raw[bytes_parsed..]) { + Ok(address) => address, + Err(Error::UnsupportedData(b)) => { + return Err(Error::UnsupportedData(bytes_parsed + b)) + } + Err(Error::NotEnoughData(b)) => { + return Err(Error::NotEnoughData(bytes_parsed + b)) + } + Err(e) => return Err(e), + }; + bytes_parsed = bytes_parsed + bytes; + result.addresses.push(address); + num_addresses = num_addresses - 1; + } + for addr in result.addresses.iter() { + for idx in addr.public_key_idx.iter() { + if idx.0 as usize >= result.public_keys.len() { + return Err(Error::Max16PublicKeys); + } + if !result.public_keys[idx.0 as usize] + .1 + .kind() + .capabilities() + .has_exchange() + { + return Err(Error::UnsupportedData(bytes_parsed)); + } + } + } if bytes_parsed + num_key_exchanges + num_hkdfs + num_ciphers != raw.len() { @@ -618,21 +632,6 @@ impl Record { result.ciphers.push(cipher); num_ciphers = num_ciphers - 1; } - for addr in result.addresses.iter() { - for idx in addr.public_key_idx.iter() { - if idx.0 as usize >= result.public_keys.len() { - return Err(Error::Max16PublicKeys); - } - if !result.public_keys[idx.0 as usize] - .1 - .kind() - .capabilities() - .has_exchange() - { - return Err(Error::UnsupportedData(bytes_parsed)); - } - } - } if bytes_parsed != raw.len() { Err(Error::UnknownData(bytes_parsed)) } else { From 4df73b658ac7c3993cca4858e8f19ac96a982dfc Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 9 Jun 2023 20:01:18 +0200 Subject: [PATCH 27/34] Correctly test for equality the DirSync::Req Signed-off-by: Luca Fulchir --- src/auth/mod.rs | 4 ++-- src/connection/handshake/dirsync.rs | 26 ++++++++++++++------------ src/connection/handshake/mod.rs | 11 ++++++++--- src/connection/handshake/tests.rs | 16 +++++++++++++++- src/connection/mod.rs | 2 +- src/enc/mod.rs | 2 +- 6 files changed, 41 insertions(+), 20 deletions(-) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 20df6f7..84be8cb 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -4,7 +4,7 @@ use crate::enc::Random; use ::zeroize::Zeroize; /// User identifier. 16 bytes for easy uuid conversion -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq)] pub struct UserID(pub [u8; 16]); impl From<[u8; 16]> for UserID { @@ -34,7 +34,7 @@ impl UserID { } } /// Authentication Token, basically just 32 random bytes -#[derive(Clone, Zeroize)] +#[derive(Clone, Zeroize, PartialEq)] #[zeroize(drop)] pub struct Token(pub [u8; 32]); diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index e387d2c..68d58da 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -24,8 +24,8 @@ use ::arrayref::array_mut_ref; // TODO: merge with crate::enc::sym::Nonce /// random nonce -#[derive(Debug, Clone, Copy)] -pub struct Nonce([u8; 16]); +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct Nonce(pub(crate) [u8; 16]); impl Nonce { /// Create a new random Nonce @@ -51,7 +51,7 @@ impl From<&[u8; 16]> for Nonce { } /// Parsed handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum DirSync { /// Directory synchronized handshake: client request Req(Req), @@ -83,7 +83,7 @@ impl DirSync { } /// Client request of a directory synchronized handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Req { /// Id of the server key used for the key exchange pub key_id: KeyID, @@ -107,6 +107,7 @@ impl Req { /// NOTE: starts from the beginning of the fenrir packet pub fn encrypted_offset(&self) -> usize { ProtocolVersion::len() + + crate::handshake::HandshakeID::len() + KeyID::len() + KeyExchangeKind::len() + HkdfKind::len() @@ -195,7 +196,7 @@ impl super::HandshakeParsing for Req { } /// Quick way to avoid mixing cipher and clear text -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum ReqInner { /// Data is still encrytped, we only keep the length CipherText(usize), @@ -211,16 +212,17 @@ impl ReqInner { } } /// parse the cleartext + // FIXME: return Result<> pub fn deserialize_as_cleartext(&mut self, raw: &[u8]) { let clear = match self { ReqInner::CipherText(len) => { assert!( - *len == raw.len(), + *len > raw.len(), "DirSync::ReqInner::CipherText length mismatch" ); match ReqData::deserialize(raw) { Ok(clear) => clear, - Err(_) => return, + Err(_e) => return, } } _ => return, @@ -230,7 +232,7 @@ impl ReqInner { } /// Informations needed for authentication -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct AuthInfo { /// User of the domain pub user: auth::UserID, @@ -308,7 +310,7 @@ impl AuthInfo { } /// Decrypted request data -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct ReqData { /// Random nonce, the client can use this to track multiple key exchanges pub nonce: Nonce, @@ -373,7 +375,7 @@ impl ReqData { } /// Quick way to avoid mixing cipher and clear text -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum RespInner { /// Server data, still in ciphertext CipherText(usize), @@ -412,7 +414,7 @@ impl RespInner { } /// Server response in a directory synchronized handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Resp { /// Tells the client with which key the exchange was done pub client_key_id: KeyID, @@ -476,7 +478,7 @@ impl Resp { } /// Decrypted response data -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct RespData { /// Client nonce, copied from the request pub client_nonce: Nonce, diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index bd0b501..8796566 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -13,7 +13,6 @@ use crate::{ }, }; use ::num_traits::FromPrimitive; -use ::std::{collections::VecDeque, rc::Rc}; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -65,6 +64,12 @@ pub enum HandshakeID { #[strum(serialize = "stateless")] Stateless, } +impl HandshakeID { + /// The length of the serialized field + pub const fn len() -> usize { + 1 + } +} pub(crate) struct HandshakeServer { pub id: KeyID, @@ -166,7 +171,7 @@ impl HandshakeClientList { } /// Parsed handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum HandshakeData { /// Directory synchronized handhsake DirSync(dirsync::DirSync), @@ -220,7 +225,7 @@ impl HandshakeKind { } /// Parsed handshake -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub struct Handshake { /// Fenrir Protocol version pub fenrir_version: ProtocolVersion, diff --git a/src/connection/handshake/tests.rs b/src/connection/handshake/tests.rs index df485d7..d37012b 100644 --- a/src/connection/handshake/tests.rs +++ b/src/connection/handshake/tests.rs @@ -55,11 +55,25 @@ fn test_handshake_dirsync_req() { &mut bytes, ); - let deserialized = match Handshake::deserialize(&bytes) { + let mut deserialized = match Handshake::deserialize(&bytes) { Ok(deserialized) => deserialized, Err(e) => { assert!(false, "{}", e.to_string()); return; } }; + if let HandshakeData::DirSync(dirsync::DirSync::Req(r_a)) = + &mut deserialized.data + { + let enc_start = + r_a.encrypted_offset() + cipher_send.kind().nonce_len().0; + r_a.data.deserialize_as_cleartext( + &bytes[enc_start..(bytes.len() - cipher_send.kind().tag_len().0)], + ); + }; + + assert!( + deserialized == h_req, + "DirSync Req (de)serialization not working", + ); } diff --git a/src/connection/mod.rs b/src/connection/mod.rs index ee71c4e..07b7c18 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -30,7 +30,7 @@ pub struct IDRecv(pub ID); pub struct IDSend(pub ID); /// Version of the fenrir protocol in use -#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)] +#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone, PartialEq)] #[repr(u8)] pub enum ProtocolVersion { /// First Fenrir Protocol Version diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 06a6200..ff1bb9f 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -80,7 +80,7 @@ impl ::rand_core::CryptoRng for &Random {} /// Secret, used for keys. /// Grants that on drop() we will zero out memory -#[derive(Zeroize, Clone)] +#[derive(Zeroize, Clone, PartialEq)] #[zeroize(drop)] pub struct Secret([u8; 32]); // Fake debug implementation to avoid leaking secrets From e2874451d16ac2c1261e0dfa427b9667596c77fe Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 9 Jun 2023 20:05:59 +0200 Subject: [PATCH 28/34] Return error from parsing the encrypted ReqData Signed-off-by: Luca Fulchir --- src/connection/handshake/dirsync.rs | 11 +++++++---- src/connection/handshake/tests.rs | 6 ++++-- src/inner/mod.rs | 6 +++++- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 68d58da..34b2819 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -212,8 +212,10 @@ impl ReqInner { } } /// parse the cleartext - // FIXME: return Result<> - pub fn deserialize_as_cleartext(&mut self, raw: &[u8]) { + pub fn deserialize_as_cleartext( + &mut self, + raw: &[u8], + ) -> Result<(), Error> { let clear = match self { ReqInner::CipherText(len) => { assert!( @@ -222,12 +224,13 @@ impl ReqInner { ); match ReqData::deserialize(raw) { Ok(clear) => clear, - Err(_e) => return, + Err(e) => return Err(e), } } - _ => return, + _ => return Err(Error::Parsing), }; *self = ReqInner::ClearText(clear); + Ok(()) } } diff --git a/src/connection/handshake/tests.rs b/src/connection/handshake/tests.rs index d37012b..9530407 100644 --- a/src/connection/handshake/tests.rs +++ b/src/connection/handshake/tests.rs @@ -67,9 +67,11 @@ fn test_handshake_dirsync_req() { { let enc_start = r_a.encrypted_offset() + cipher_send.kind().nonce_len().0; - r_a.data.deserialize_as_cleartext( + if let Err(e) = r_a.data.deserialize_as_cleartext( &bytes[enc_start..(bytes.len() - cipher_send.kind().tag_len().0)], - ); + ) { + assert!(false, "DirSync Req Inner serialize: {}", e.to_string()); + } }; assert!( diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 24ff170..faa7609 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -184,7 +184,11 @@ impl HandshakeTracker { &mut handshake_raw[req.encrypted_offset()..], ) { Ok(cleartext) => { - req.data.deserialize_as_cleartext(cleartext) + if let Err(e) = + req.data.deserialize_as_cleartext(cleartext) + { + return Err(e.into()); + } } Err(e) => { return Err(handshake::Error::Key(e).into()); From a32dfe098f53fc23550a95bbfc8a1fee90e2ff59 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 9 Jun 2023 21:02:42 +0200 Subject: [PATCH 29/34] Add the git pre-commit hook Signed-off-by: Luca Fulchir --- Readme.md | 24 ++++++++++++++++++++++++ src/connection/handshake/tests.rs | 17 ++++------------- var/git-pre-commit | 4 ++++ 3 files changed, 32 insertions(+), 13 deletions(-) create mode 100755 var/git-pre-commit diff --git a/Readme.md b/Readme.md index 7a2f74b..d67fe77 100644 --- a/Readme.md +++ b/Readme.md @@ -15,3 +15,27 @@ you will find the result in `./target/release` If you want to build the `Hati` server, you don't need to build this library separately. Just build the server and it will automatically include this lib +# Developing + +we recommend to use the nix environment, so that you will have +exactly the same environment as the developers. + +just enter the repository directory and run + +``` +nix develop +``` + +and everything should be done for you. + +## Git + +Please configure a pre-commit hook like the one in `var/git-pre-commit` + +``` +cp var/git-pre-commit .git/hooks/pre-commit +``` + +This will run `cargo test --offline` right before your commit, +to make sure that everything compiles and that the test pass + diff --git a/src/connection/handshake/tests.rs b/src/connection/handshake/tests.rs index 9530407..c5f84dd 100644 --- a/src/connection/handshake/tests.rs +++ b/src/connection/handshake/tests.rs @@ -8,11 +8,7 @@ use crate::{ fn test_handshake_dirsync_req() { let rand = enc::Random::new(); let secret = enc::Secret::new_rand(&rand); - let cipher_send = enc::sym::CipherSend::new( - enc::sym::CipherKind::XChaCha20Poly1305, - secret, - &rand, - ); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; let (_, exchange_key) = match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand) @@ -49,11 +45,7 @@ fn test_handshake_dirsync_req() { let mut bytes = Vec::::with_capacity(h_req.len()); bytes.resize(h_req.len(), 0); - h_req.serialize( - cipher_send.kind().nonce_len(), - cipher_send.kind().tag_len(), - &mut bytes, - ); + h_req.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); let mut deserialized = match Handshake::deserialize(&bytes) { Ok(deserialized) => deserialized, @@ -65,10 +57,9 @@ fn test_handshake_dirsync_req() { if let HandshakeData::DirSync(dirsync::DirSync::Req(r_a)) = &mut deserialized.data { - let enc_start = - r_a.encrypted_offset() + cipher_send.kind().nonce_len().0; + let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; if let Err(e) = r_a.data.deserialize_as_cleartext( - &bytes[enc_start..(bytes.len() - cipher_send.kind().tag_len().0)], + &bytes[enc_start..(bytes.len() - cipher.tag_len().0)], ) { assert!(false, "DirSync Req Inner serialize: {}", e.to_string()); } diff --git a/var/git-pre-commit b/var/git-pre-commit new file mode 100755 index 0000000..81ac843 --- /dev/null +++ b/var/git-pre-commit @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +RUSTFLAGS=-Awarnings exec cargo test --offline --profile dev + From faaf8762c7c012ec6969d7caf1d1a9f8fd026991 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Fri, 9 Jun 2023 21:58:33 +0200 Subject: [PATCH 30/34] Test (de)serialization of DirSync::Resp Signed-off-by: Luca Fulchir --- src/connection/handshake/dirsync.rs | 92 ++++++++++++++++++----------- src/connection/handshake/mod.rs | 10 ++-- src/connection/handshake/tests.rs | 59 +++++++++++++++++- src/connection/packet.rs | 18 +++--- src/enc/mod.rs | 9 +++ src/inner/mod.rs | 8 +-- src/inner/worker.rs | 3 +- 7 files changed, 142 insertions(+), 57 deletions(-) diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 34b2819..8023ed6 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -20,8 +20,6 @@ use crate::{ }, }; -use ::arrayref::array_mut_ref; - // TODO: merge with crate::enc::sym::Nonce /// random nonce #[derive(Debug, Clone, Copy, PartialEq)] @@ -61,10 +59,10 @@ pub enum DirSync { impl DirSync { /// actual length of the dirsync handshake data - pub fn len(&self) -> usize { + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { match self { DirSync::Req(req) => req.len(), - DirSync::Resp(resp) => resp.len(), + DirSync::Resp(resp) => resp.len(head_len, tag_len), } } /// Serialize into raw bytes @@ -394,25 +392,31 @@ impl RespInner { } } /// parse the cleartext - pub fn deserialize_as_cleartext(&mut self, raw: &[u8]) { + pub fn deserialize_as_cleartext( + &mut self, + raw: &[u8], + ) -> Result<(), Error> { let clear = match self { RespInner::CipherText(len) => { assert!( - *len == raw.len(), + *len > raw.len(), "DirSync::RespInner::CipherText length mismatch" ); match RespData::deserialize(raw) { Ok(clear) => clear, - Err(_) => return, + Err(e) => return Err(e), } } - _ => return, + _ => return Err(Error::Parsing), }; *self = RespInner::ClearText(clear); + Ok(()) } - /// serialize, but only if ciphertext + /// Serialize the still cleartext data pub fn serialize(&self, out: &mut [u8]) { - todo!() + if let RespInner::ClearText(clear) = &self { + clear.serialize(out); + } } } @@ -432,7 +436,7 @@ impl super::HandshakeParsing for Resp { return Err(Error::NotEnoughData); } let client_key_id: KeyID = - KeyID(u16::from_le_bytes(raw[0..1].try_into().unwrap())); + KeyID(u16::from_le_bytes(raw[0..2].try_into().unwrap())); Ok(HandshakeData::DirSync(DirSync::Resp(Self { client_key_id, data: RespInner::CipherText(raw[KeyID::len()..].len()), @@ -444,7 +448,9 @@ impl Resp { /// return the offset of the encrypted data /// NOTE: starts from the beginning of the fenrir packet pub fn encrypted_offset(&self) -> usize { - ProtocolVersion::len() + KeyID::len() + ProtocolVersion::len() + + crate::connection::handshake::HandshakeID::len() + + KeyID::len() } /// return the total length of the cleartext data pub fn encrypted_length(&self) -> usize { @@ -454,29 +460,21 @@ impl Resp { } } /// Total length of the response handshake - pub fn len(&self) -> usize { - KeyID::len() + self.data.len() + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { + KeyID::len() + head_len.0 + self.data.len() + tag_len.0 } /// Serialize into raw bytes /// NOTE: assumes that there is exactly as much buffer as needed - /// NOTE: assumes that the data is *ClearText* pub fn serialize( &self, head_len: HeadLen, - tag_len: TagLen, + _tag_len: TagLen, out: &mut [u8], ) { - assert!( - out.len() == KeyID::len() + self.data.len(), - "DirSync Resp: not enough buffer to serialize" - ); - self.client_key_id.serialize(array_mut_ref![out, 0, 2]); - let end_data = (2 + self.data.len()) - tag_len.0; - self.data.serialize(&mut out[(2 + head_len.0)..end_data]); - } - /// Set the cleartext data after it was parsed - pub fn set_data(&mut self, data: RespData) { - self.data = RespInner::ClearText(data); + out[0..2].copy_from_slice(&self.client_key_id.0.to_le_bytes()); + let start_data = 2 + head_len.0; + let end_data = start_data + self.data.len(); + self.data.serialize(&mut out[start_data..end_data]); } } @@ -494,30 +492,54 @@ pub struct RespData { } impl RespData { - const NONCE_LEN: usize = ::core::mem::size_of::(); /// Return the expected length for buffer allocation pub fn len() -> usize { - Self::NONCE_LEN + ID::len() + ID::len() + 32 + Nonce::len() + ID::len() + ID::len() + Secret::len() } /// Serialize the data into a buffer /// NOTE: assumes that there is exactly asa much buffer as needed pub fn serialize(&self, out: &mut [u8]) { - assert!(out.len() == Self::len(), "wrong buffer size"); let mut start = 0; - let mut end = Self::NONCE_LEN; + let mut end = Nonce::len(); out[start..end].copy_from_slice(&self.client_nonce.0); start = end; - end = end + Self::NONCE_LEN; + end = end + ID::len(); self.id.serialize(&mut out[start..end]); start = end; - end = end + Self::NONCE_LEN; + end = end + ID::len(); self.service_connection_id.serialize(&mut out[start..end]); start = end; - end = end + Self::NONCE_LEN; + end = end + Secret::len(); out[start..end].copy_from_slice(self.service_key.as_ref()); } /// Parse the cleartext raw data pub fn deserialize(raw: &[u8]) -> Result { - todo!(); + let raw_sized: &[u8; 16] = raw[..Nonce::len()].try_into().unwrap(); + let client_nonce: Nonce = raw_sized.into(); + let end = Nonce::len() + ID::len(); + let id: ID = + u64::from_le_bytes(raw[Nonce::len()..end].try_into().unwrap()) + .into(); + if id.is_handshake() { + return Err(Error::Parsing); + } + let parsed = end; + let end = parsed + ID::len(); + let service_connection_id: ID = + u64::from_le_bytes(raw[parsed..end].try_into().unwrap()).into(); + if service_connection_id.is_handshake() { + return Err(Error::Parsing); + } + let parsed = end; + let end = parsed + Secret::len(); + let raw_secret: &[u8; 32] = raw[parsed..end].try_into().unwrap(); + let service_key = raw_secret.into(); + + Ok(Self { + client_nonce, + id, + service_connection_id, + service_key, + }) } } diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 8796566..6322e79 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -179,9 +179,9 @@ pub enum HandshakeData { impl HandshakeData { /// actual length of the handshake data - pub fn len(&self) -> usize { + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { match self { - HandshakeData::DirSync(d) => d.len(), + HandshakeData::DirSync(d) => d.len(head_len, tag_len), } } /// Serialize into raw bytes @@ -242,8 +242,10 @@ impl Handshake { } } /// return the total length of the handshake - pub fn len(&self) -> usize { - ProtocolVersion::len() + HandshakeKind::len() + self.data.len() + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { + ProtocolVersion::len() + + HandshakeKind::len() + + self.data.len(head_len, tag_len) } const MIN_PKT_LEN: usize = 8; /// Parse the packet and return the parsed handshake diff --git a/src/connection/handshake/tests.rs b/src/connection/handshake/tests.rs index c5f84dd..fc8b2b3 100644 --- a/src/connection/handshake/tests.rs +++ b/src/connection/handshake/tests.rs @@ -7,7 +7,6 @@ use crate::{ #[test] fn test_handshake_dirsync_req() { let rand = enc::Random::new(); - let secret = enc::Secret::new_rand(&rand); let cipher = enc::sym::CipherKind::XChaCha20Poly1305; let (_, exchange_key) = @@ -43,8 +42,10 @@ fn test_handshake_dirsync_req() { }, ))); - let mut bytes = Vec::::with_capacity(h_req.len()); - bytes.resize(h_req.len(), 0); + let mut bytes = Vec::::with_capacity( + h_req.len(cipher.nonce_len(), cipher.tag_len()), + ); + bytes.resize(h_req.len(cipher.nonce_len(), cipher.tag_len()), 0); h_req.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); let mut deserialized = match Handshake::deserialize(&bytes) { @@ -70,3 +71,55 @@ fn test_handshake_dirsync_req() { "DirSync Req (de)serialization not working", ); } +#[test] +fn test_handshake_dirsync_reqsp() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + + let service_key = enc::Secret::new_rand(&rand); + + let data = dirsync::RespInner::ClearText(dirsync::RespData { + client_nonce: dirsync::Nonce::new(&rand), + id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()), + service_connection_id: ID::ID( + ::core::num::NonZeroU64::new(434343).unwrap(), + ), + service_key, + }); + + let h_resp = Handshake::new(HandshakeData::DirSync( + dirsync::DirSync::Resp(dirsync::Resp { + client_key_id: KeyID(4444), + data, + }), + )); + + let mut bytes = Vec::::with_capacity( + h_resp.len(cipher.nonce_len(), cipher.tag_len()), + ); + bytes.resize(h_resp.len(cipher.nonce_len(), cipher.tag_len()), 0); + h_resp.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); + + let mut deserialized = match Handshake::deserialize(&bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) = + &mut deserialized.data + { + let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; + if let Err(e) = r_a.data.deserialize_as_cleartext( + &bytes[enc_start..(bytes.len() - cipher.tag_len().0)], + ) { + assert!(false, "DirSync Resp Inner serialize: {}", e.to_string()); + } + }; + + assert!( + deserialized == h_resp, + "DirSync Resp (de)serialization not working", + ); +} diff --git a/src/connection/packet.rs b/src/connection/packet.rs index b051594..925460c 100644 --- a/src/connection/packet.rs +++ b/src/connection/packet.rs @@ -58,11 +58,10 @@ impl ConnectionID { } /// write the ID to a buffer pub fn serialize(&self, out: &mut [u8]) { - assert!(out.len() == 8, "out buffer must be 8 bytes"); match self { - ConnectionID::Handshake => out[..].copy_from_slice(&[0; 8]), + ConnectionID::Handshake => out[..8].copy_from_slice(&[0; 8]), ConnectionID::ID(id) => { - out[..].copy_from_slice(&id.get().to_le_bytes()) + out[..8].copy_from_slice(&id.get().to_le_bytes()) } } } @@ -99,9 +98,9 @@ pub enum PacketData { impl PacketData { /// total length of the data in bytes - pub fn len(&self) -> usize { + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { match self { - PacketData::Handshake(h) => h.len(), + PacketData::Handshake(h) => h.len(head_len, tag_len), PacketData::Raw(len) => *len, } } @@ -113,7 +112,10 @@ impl PacketData { tag_len: TagLen, out: &mut [u8], ) { - assert!(self.len() == out.len(), "PacketData: wrong buffer length"); + assert!( + self.len(head_len, tag_len) == out.len(), + "PacketData: wrong buffer length" + ); match self { PacketData::Handshake(h) => h.serialize(head_len, tag_len, out), PacketData::Raw(_) => { @@ -148,8 +150,8 @@ impl Packet { }) } /// get the total length of the packet - pub fn len(&self) -> usize { - ConnectionID::len() + self.data.len() + pub fn len(&self, head_len: HeadLen, tag_len: TagLen) -> usize { + ConnectionID::len() + self.data.len(head_len, tag_len) } /// serialize packet into buffer /// NOTE: assumes that there is exactly asa much buffer as needed diff --git a/src/enc/mod.rs b/src/enc/mod.rs index ff1bb9f..9a01b98 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -94,6 +94,10 @@ impl ::core::fmt::Debug for Secret { } impl Secret { + /// return the length of the serialized secret + pub const fn len() -> usize { + 32 + } /// New randomly generated secret pub fn new_rand(rand: &Random) -> Self { let mut ret = Self([0; 32]); @@ -110,6 +114,11 @@ impl From<[u8; 32]> for Secret { Self(shared_secret) } } +impl From<&[u8; 32]> for Secret { + fn from(shared_secret: &[u8; 32]) -> Self { + Self(*shared_secret) + } +} impl From<::x25519_dalek::SharedSecret> for Secret { fn from(shared_secret: ::x25519_dalek::SharedSecret) -> Self { diff --git a/src/inner/mod.rs b/src/inner/mod.rs index faa7609..882e8c1 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -184,11 +184,7 @@ impl HandshakeTracker { &mut handshake_raw[req.encrypted_offset()..], ) { Ok(cleartext) => { - if let Err(e) = - req.data.deserialize_as_cleartext(cleartext) - { - return Err(e.into()); - } + req.data.deserialize_as_cleartext(cleartext)?; } Err(e) => { return Err(handshake::Error::Key(e).into()); @@ -223,7 +219,7 @@ impl HandshakeTracker { ..(resp.encrypted_offset() + resp.encrypted_length())]; match cipher_recv.decrypt(aad, &mut raw_data) { Ok(cleartext) => { - resp.data.deserialize_as_cleartext(&cleartext) + resp.data.deserialize_as_cleartext(&cleartext)?; } Err(e) => { return Err(handshake::Error::Key(e).into()); diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 6d77c03..4f0459f 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -546,7 +546,8 @@ impl Worker { id: ID::new_handshake(), data: PacketData::Handshake(resp_handshake), }; - let mut raw_out = Vec::::with_capacity(packet.len()); + let mut raw_out = + Vec::::with_capacity(packet.len(head_len, tag_len)); packet.serialize(head_len, tag_len, &mut raw_out); if let Err(e) = auth_conn.cipher_send.encrypt( From aff1c313f5d4f98d9e48818645c9a7191a0e53c6 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sat, 10 Jun 2023 14:42:24 +0200 Subject: [PATCH 31/34] Cleanup & incomplete tests Signed-off-by: Luca Fulchir --- flake.nix | 3 +- src/config/mod.rs | 9 +- src/connection/handshake/mod.rs | 109 +--------- src/connection/handshake/tests.rs | 2 +- src/connection/handshake/tracker.rs | 326 ++++++++++++++++++++++++++++ src/dnssec/mod.rs | 64 +----- src/dnssec/record.rs | 2 +- src/dnssec/tests.rs | 59 +++++ src/enc/asym.rs | 10 +- src/enc/sym.rs | 1 - src/inner/mod.rs | 231 -------------------- src/inner/worker.rs | 5 +- src/lib.rs | 85 +++++++- src/tests.rs | 126 +++++++++++ 14 files changed, 618 insertions(+), 414 deletions(-) create mode 100644 src/connection/handshake/tracker.rs create mode 100644 src/dnssec/tests.rs create mode 100644 src/tests.rs diff --git a/flake.nix b/flake.nix index 7a74671..21e2442 100644 --- a/flake.nix +++ b/flake.nix @@ -18,6 +18,7 @@ pkgs-unstable = import nixpkgs-unstable { inherit system overlays; }; + RUST_VERSION="1.69.0"; in { devShells.default = pkgs.mkShell { @@ -40,7 +41,7 @@ cargo-flamegraph cargo-license lld - rust-bin.stable."1.69.0".default + rust-bin.stable.${RUST_VERSION}.default rustfmt rust-analyzer # fenrir deps diff --git a/src/config/mod.rs b/src/config/mod.rs index 09773c3..c0fe949 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,7 +3,11 @@ use crate::{ connection::handshake::HandshakeID, - enc::{asym::KeyExchangeKind, hkdf::HkdfKind, sym::CipherKind}, + enc::{ + asym::{KeyExchangeKind, KeyID, PrivKey, PubKey}, + hkdf::HkdfKind, + sym::CipherKind, + }, }; use ::std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, @@ -30,6 +34,8 @@ pub struct Config { pub hkdfs: Vec, /// Supported Ciphers pub ciphers: Vec, + /// list of public/private keys + pub keys: Vec<(KeyID, PrivKey, PubKey)>, } impl Default for Config { @@ -50,6 +56,7 @@ impl Default for Config { key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(), hkdfs: [HkdfKind::Sha3].to_vec(), ciphers: [CipherKind::XChaCha20Poly1305].to_vec(), + keys: Vec::new(), } } } diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 6322e79..6dd248e 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -3,14 +3,11 @@ pub mod dirsync; #[cfg(test)] mod tests; +pub(crate) mod tracker; use crate::{ - auth::ServiceID, - connection::{self, Connection, IDRecv, ProtocolVersion}, - enc::{ - asym::{KeyID, PrivKey, PubKey}, - sym::{HeadLen, TagLen}, - }, + connection::ProtocolVersion, + enc::sym::{HeadLen, TagLen}, }; use ::num_traits::FromPrimitive; @@ -70,106 +67,6 @@ impl HandshakeID { 1 } } - -pub(crate) struct HandshakeServer { - pub id: KeyID, - pub key: PrivKey, -} - -pub(crate) struct HandshakeClient { - pub service_id: ServiceID, - pub service_conn_id: IDRecv, - pub connection: Connection, - pub timeout: Option<::tokio::task::JoinHandle<()>>, -} - -/// Tracks the keys used by the client and the handshake -/// they are associated with -pub(crate) struct HandshakeClientList { - used: Vec<::bitmaps::Bitmap<1024>>, // index = KeyID - keys: Vec>, - list: Vec>, -} - -impl HandshakeClientList { - pub(crate) fn new() -> Self { - Self { - used: [::bitmaps::Bitmap::<1024>::new()].to_vec(), - keys: Vec::with_capacity(16), - list: Vec::with_capacity(16), - } - } - pub(crate) fn get(&self, id: KeyID) -> Option<&HandshakeClient> { - if id.0 as usize >= self.list.len() { - return None; - } - self.list[id.0 as usize].as_ref() - } - pub(crate) fn remove(&mut self, id: KeyID) -> Option { - if id.0 as usize >= self.list.len() { - return None; - } - let used_vec_idx = id.0 as usize / 1024; - let used_bitmap_idx = id.0 as usize % 1024; - let used_iter = match self.used.get_mut(used_vec_idx) { - Some(used_iter) => used_iter, - None => return None, - }; - used_iter.set(used_bitmap_idx, false); - self.keys[id.0 as usize] = None; - let mut owned = None; - ::core::mem::swap(&mut self.list[id.0 as usize], &mut owned); - owned - } - pub(crate) fn add( - &mut self, - priv_key: PrivKey, - pub_key: PubKey, - service_id: ServiceID, - service_conn_id: IDRecv, - connection: Connection, - ) -> Result<(KeyID, &mut HandshakeClient), ()> { - let maybe_free_key_idx = - self.used.iter().enumerate().find_map(|(idx, bmap)| { - match bmap.first_false_index() { - Some(false_idx) => Some(((idx * 1024), false_idx)), - None => None, - } - }); - let free_key_idx = match maybe_free_key_idx { - Some((idx, false_idx)) => { - let free_key_idx = idx * 1024 + false_idx; - if free_key_idx > KeyID::MAX as usize { - return Err(()); - } - self.used[idx].set(false_idx, true); - free_key_idx - } - None => { - let mut bmap = ::bitmaps::Bitmap::<1024>::new(); - bmap.set(0, true); - self.used.push(bmap); - self.used.len() * 1024 - } - }; - if self.keys.len() >= free_key_idx { - self.keys.push(None); - self.list.push(None); - } - self.keys[free_key_idx] = Some((priv_key, pub_key)); - self.list[free_key_idx] = Some(HandshakeClient { - service_id, - service_conn_id, - connection, - timeout: None, - }); - Ok(( - KeyID(free_key_idx as u16), - self.list[free_key_idx].as_mut().unwrap(), - )) - } -} - /// Parsed handshake #[derive(Debug, Clone, PartialEq)] pub enum HandshakeData { diff --git a/src/connection/handshake/tests.rs b/src/connection/handshake/tests.rs index fc8b2b3..83cacfd 100644 --- a/src/connection/handshake/tests.rs +++ b/src/connection/handshake/tests.rs @@ -1,7 +1,7 @@ use crate::{ auth, connection::{handshake::*, ID}, - enc, + enc::{self, asym::KeyID}, }; #[test] diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs new file mode 100644 index 0000000..a907f3b --- /dev/null +++ b/src/connection/handshake/tracker.rs @@ -0,0 +1,326 @@ +//! Handhsake handling + +use crate::{ + auth::ServiceID, + connection::{ + self, + handshake::{self, Error, Handshake}, + Connection, IDRecv, + }, + enc::{ + self, + asym::{self, KeyID, PrivKey, PubKey}, + hkdf::{Hkdf, HkdfKind}, + sym::{CipherKind, CipherRecv}, + }, + inner::ThreadTracker, +}; + +pub(crate) struct HandshakeServer { + pub id: KeyID, + pub key: PrivKey, +} + +pub(crate) struct HandshakeClient { + pub service_id: ServiceID, + pub service_conn_id: IDRecv, + pub connection: Connection, + pub timeout: Option<::tokio::task::JoinHandle<()>>, +} + +/// Tracks the keys used by the client and the handshake +/// they are associated with +pub(crate) struct HandshakeClientList { + used: Vec<::bitmaps::Bitmap<1024>>, // index = KeyID + keys: Vec>, + list: Vec>, +} + +impl HandshakeClientList { + pub(crate) fn new() -> Self { + Self { + used: [::bitmaps::Bitmap::<1024>::new()].to_vec(), + keys: Vec::with_capacity(16), + list: Vec::with_capacity(16), + } + } + pub(crate) fn get(&self, id: KeyID) -> Option<&HandshakeClient> { + if id.0 as usize >= self.list.len() { + return None; + } + self.list[id.0 as usize].as_ref() + } + pub(crate) fn remove(&mut self, id: KeyID) -> Option { + if id.0 as usize >= self.list.len() { + return None; + } + let used_vec_idx = id.0 as usize / 1024; + let used_bitmap_idx = id.0 as usize % 1024; + let used_iter = match self.used.get_mut(used_vec_idx) { + Some(used_iter) => used_iter, + None => return None, + }; + used_iter.set(used_bitmap_idx, false); + self.keys[id.0 as usize] = None; + let mut owned = None; + ::core::mem::swap(&mut self.list[id.0 as usize], &mut owned); + owned + } + pub(crate) fn add( + &mut self, + priv_key: PrivKey, + pub_key: PubKey, + service_id: ServiceID, + service_conn_id: IDRecv, + connection: Connection, + ) -> Result<(KeyID, &mut HandshakeClient), ()> { + let maybe_free_key_idx = + self.used.iter().enumerate().find_map(|(idx, bmap)| { + match bmap.first_false_index() { + Some(false_idx) => Some(((idx * 1024), false_idx)), + None => None, + } + }); + let free_key_idx = match maybe_free_key_idx { + Some((idx, false_idx)) => { + let free_key_idx = idx * 1024 + false_idx; + if free_key_idx > KeyID::MAX as usize { + return Err(()); + } + self.used[idx].set(false_idx, true); + free_key_idx + } + None => { + let mut bmap = ::bitmaps::Bitmap::<1024>::new(); + bmap.set(0, true); + self.used.push(bmap); + self.used.len() * 1024 + } + }; + if self.keys.len() >= free_key_idx { + self.keys.push(None); + self.list.push(None); + } + self.keys[free_key_idx] = Some((priv_key, pub_key)); + self.list[free_key_idx] = Some(HandshakeClient { + service_id, + service_conn_id, + connection, + timeout: None, + }); + Ok(( + KeyID(free_key_idx as u16), + self.list[free_key_idx].as_mut().unwrap(), + )) + } +} +/// Information needed to reply after the key exchange +#[derive(Debug, Clone)] +pub(crate) struct AuthNeededInfo { + /// Parsed handshake packet + pub handshake: Handshake, + /// hkdf generated from the handshake + pub hkdf: Hkdf, + /// cipher to be used in both directions + pub cipher: CipherKind, +} + +/// Client information needed to fully establish the conenction +#[derive(Debug)] +pub(crate) struct ClientConnectInfo { + /// The service ID that we are connecting to + pub service_id: ServiceID, + /// The service ID that we are connecting to + pub service_connection_id: IDRecv, + /// Parsed handshake packet + pub handshake: Handshake, + /// Connection + pub connection: Connection, +} +/// Intermediate actions to be taken while parsing the handshake +#[derive(Debug)] +pub(crate) enum HandshakeAction { + /// Parsing finished, all ok, nothing to do + Nonthing, + /// Packet parsed, now go perform authentication + AuthNeeded(AuthNeededInfo), + /// the client can fully establish a connection with this info + ClientConnect(ClientConnectInfo), +} + +/// Tracking of handhsakes and conenctions +/// Note that we have multiple Handshake trackers, pinned to different cores +/// Each of them will handle a subset of all handshakes. +/// Each handshake is routed to a different tracker by checking +/// core = (udp_src_sender_port % total_threads) - 1 +pub(crate) struct HandshakeTracker { + thread_id: ThreadTracker, + key_exchanges: Vec<(asym::KeyKind, asym::KeyExchangeKind)>, + ciphers: Vec, + /// ephemeral keys used server side in key exchange + keys_srv: Vec, + /// ephemeral keys used client side in key exchange + hshake_cli: HandshakeClientList, +} + +impl HandshakeTracker { + pub(crate) fn new(thread_id: ThreadTracker) -> Self { + Self { + thread_id, + ciphers: Vec::new(), + key_exchanges: Vec::new(), + keys_srv: Vec::new(), + hshake_cli: HandshakeClientList::new(), + } + } + pub(crate) fn new_client( + &mut self, + priv_key: PrivKey, + pub_key: PubKey, + service_id: ServiceID, + service_conn_id: IDRecv, + connection: Connection, + ) -> Result<(KeyID, &mut HandshakeClient), ()> { + self.hshake_cli.add( + priv_key, + pub_key, + service_id, + service_conn_id, + connection, + ) + } + pub(crate) fn timeout_client( + &mut self, + key_id: KeyID, + ) -> Option<[IDRecv; 2]> { + if let Some(hshake) = self.hshake_cli.remove(key_id) { + Some([hshake.connection.id_recv, hshake.service_conn_id]) + } else { + None + } + } + pub(crate) fn recv_handshake( + &mut self, + mut handshake: Handshake, + handshake_raw: &mut [u8], + ) -> Result { + use connection::handshake::{dirsync::DirSync, HandshakeData}; + match handshake.data { + HandshakeData::DirSync(ref mut ds) => match ds { + DirSync::Req(ref mut req) => { + let ephemeral_key = { + if let Some(h_k) = + self.keys_srv.iter().find(|k| k.id == req.key_id) + { + // Directory synchronized can only use keys + // for key exchange, not signing keys + if let PrivKey::Exchange(k) = &h_k.key { + Some(k.clone()) + } else { + None + } + } else { + None + } + }; + if ephemeral_key.is_none() { + ::tracing::debug!( + "No such server key id: {:?}", + req.key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + let ephemeral_key = ephemeral_key.unwrap(); + { + if None + == self.key_exchanges.iter().find(|&x| { + *x == (ephemeral_key.kind(), req.exchange) + }) + { + return Err( + enc::Error::UnsupportedKeyExchange.into() + ); + } + } + { + if None + == self.ciphers.iter().find(|&x| *x == req.cipher) + { + return Err(enc::Error::UnsupportedCipher.into()); + } + } + let shared_key = match ephemeral_key + .key_exchange(req.exchange, req.exchange_key) + { + Ok(shared_key) => shared_key, + Err(e) => return Err(handshake::Error::Key(e).into()), + }; + let hkdf = Hkdf::new(HkdfKind::Sha3, b"fenrir", shared_key); + let secret_recv = hkdf.get_secret(b"to_server"); + let cipher_recv = CipherRecv::new(req.cipher, secret_recv); + use crate::enc::sym::AAD; + let aad = AAD(&mut []); // no aad for now + match cipher_recv.decrypt( + aad, + &mut handshake_raw[req.encrypted_offset()..], + ) { + Ok(cleartext) => { + req.data.deserialize_as_cleartext(cleartext)?; + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + + let cipher = req.cipher; + + return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { + handshake, + hkdf, + cipher, + })); + } + DirSync::Resp(resp) => { + let hshake = match self.hshake_cli.get(resp.client_key_id) { + Some(hshake) => hshake, + None => { + ::tracing::debug!( + "No such client key id: {:?}", + resp.client_key_id + ); + return Err(handshake::Error::UnknownKeyID.into()); + } + }; + let cipher_recv = &hshake.connection.cipher_recv; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + let mut raw_data = &mut handshake_raw[resp + .encrypted_offset() + ..(resp.encrypted_offset() + resp.encrypted_length())]; + match cipher_recv.decrypt(aad, &mut raw_data) { + Ok(cleartext) => { + resp.data.deserialize_as_cleartext(&cleartext)?; + } + Err(e) => { + return Err(handshake::Error::Key(e).into()); + } + } + let hshake = + self.hshake_cli.remove(resp.client_key_id).unwrap(); + if let Some(timeout) = hshake.timeout { + timeout.abort(); + } + return Ok(HandshakeAction::ClientConnect( + ClientConnectInfo { + service_id: hshake.service_id, + service_connection_id: hshake.service_conn_id, + handshake, + connection: hshake.connection, + }, + )); + } + }, + } + } +} diff --git a/src/dnssec/mod.rs b/src/dnssec/mod.rs index bc0f661..d1128c1 100644 --- a/src/dnssec/mod.rs +++ b/src/dnssec/mod.rs @@ -9,6 +9,9 @@ pub use record::Record; use crate::auth::Domain; +#[cfg(test)] +mod tests; + /// Common errors for Dnssec setup and usage #[derive(::thiserror::Error, Debug)] pub enum Error { @@ -44,7 +47,7 @@ pub struct Dnssec { impl Dnssec { /// Spawn connections to DNS via TCP - pub async fn new(resolvers: &Vec) -> Result { + pub fn new(resolvers: &Vec) -> Result { // use a TCP connection to the DNS. // the records we need are big, will not fit in a UDP packet let resolv_conf_resolvers: Vec; @@ -146,62 +149,3 @@ impl Dnssec { }; } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_serialization() { - let rand = enc::Random::new(); - let (_, exchange_key) = - match enc::asym::KeyExchangeKind::X25519DiffieHellman - .new_keypair(&rand) - { - Ok(pair) => pair, - Err(_) => { - assert!(false, "Can't generate random keypair"); - return; - } - }; - use crate::enc; - let record = Record { - public_keys : [(enc::asym::KeyID(42), - enc::asym::PubKey::Exchange(exchange_key))].to_vec(), - addresses: [record::Address { - ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127,0,0,1)), - port: Some(::core::num::NonZeroU16::new(31337).unwrap()), - priority: record::AddressPriority::P1, - weight: record::AddressWeight::W1, - handshake_ids: [crate::connection::handshake::HandshakeID::DirectorySynchronized].to_vec(), - public_key_idx : [record::PubKeyIdx(0)].to_vec(), - - }].to_vec(), - key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman].to_vec(), - hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), - ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), - - }; - let encoded = match record.encode() { - Ok(encoded) => encoded, - Err(e) => { - assert!(false, "{}", e.to_string()); - return; - } - }; - let full_record = "v=Fenrir1 ".to_string() + &encoded; - let record = match Dnssec::parse_txt_record(&full_record) { - Ok(record) => record, - Err(e) => { - assert!(false, "{}", e.to_string()); - return; - } - }; - let _re_encoded = match record.encode() { - Ok(re_encoded) => re_encoded, - Err(e) => { - assert!(false, "{}", e.to_string()); - return; - } - }; - } -} diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 6a49a3a..9426b3e 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -46,7 +46,7 @@ use crate::{ connection::handshake::HandshakeID, enc::{ self, - asym::{ExchangePubKey, KeyExchangeKind, KeyID, PubKey}, + asym::{KeyExchangeKind, KeyID, PubKey}, hkdf::HkdfKind, sym::CipherKind, }, diff --git a/src/dnssec/tests.rs b/src/dnssec/tests.rs new file mode 100644 index 0000000..ff450ae --- /dev/null +++ b/src/dnssec/tests.rs @@ -0,0 +1,59 @@ +use super::*; + +#[test] +fn test_dnssec_serialization() { + let rand = enc::Random::new(); + let (_, exchange_key) = + match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand) + { + Ok(pair) => pair, + Err(_) => { + assert!(false, "Can't generate random keypair"); + return; + } + }; + use crate::{connection::handshake::HandshakeID, enc}; + + let record = Record { + public_keys: [( + enc::asym::KeyID(42), + enc::asym::PubKey::Exchange(exchange_key), + )] + .to_vec(), + addresses: [record::Address { + ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)), + port: Some(::core::num::NonZeroU16::new(31337).unwrap()), + priority: record::AddressPriority::P1, + weight: record::AddressWeight::W1, + handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(), + public_key_idx: [record::PubKeyIdx(0)].to_vec(), + }] + .to_vec(), + key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman] + .to_vec(), + hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), + ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), + }; + let encoded = match record.encode() { + Ok(encoded) => encoded, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + let full_record = "v=Fenrir1 ".to_string() + &encoded; + let record = match Dnssec::parse_txt_record(&full_record) { + Ok(record) => record, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + let _re_encoded = match record.encode() { + Ok(re_encoded) => re_encoded, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; +} diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 16201cf..32c3aa2 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -169,7 +169,6 @@ impl KeyExchangeKind { let priv_key = ExchangePrivKey::X25519(raw_priv); Ok((priv_key, pub_key)) } - _ => Err(Error::UnsupportedKeyExchange), } } } @@ -300,6 +299,15 @@ impl PrivKey { } } } +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for PrivKey { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden privkey]", f) + } +} /// Ephemeral private keys #[derive(Clone)] diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 37b5d78..5728808 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -5,7 +5,6 @@ use crate::{ config::Config, enc::{Random, Secret}, }; -use ::zeroize::Zeroize; /// List of possible Ciphers #[derive( diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 882e8c1..001ca16 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -4,60 +4,6 @@ pub(crate) mod worker; -use crate::{ - auth::ServiceID, - connection::{ - self, - handshake::{ - self, Handshake, HandshakeClient, HandshakeClientList, - HandshakeServer, - }, - Connection, IDRecv, - }, - enc::{ - self, - asym::{self, KeyID, PrivKey, PubKey}, - hkdf::{Hkdf, HkdfKind}, - sym::{CipherKind, CipherRecv}, - }, - Error, -}; -use ::std::vec::Vec; - -/// Information needed to reply after the key exchange -#[derive(Debug, Clone)] -pub(crate) struct AuthNeededInfo { - /// Parsed handshake packet - pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: Hkdf, - /// cipher to be used in both directions - pub cipher: CipherKind, -} - -/// Client information needed to fully establish the conenction -#[derive(Debug)] -pub(crate) struct ClientConnectInfo { - /// The service ID that we are connecting to - pub service_id: ServiceID, - /// The service ID that we are connecting to - pub service_connection_id: IDRecv, - /// Parsed handshake packet - pub handshake: Handshake, - /// Connection - pub connection: Connection, -} -/// Intermediate actions to be taken while parsing the handshake -#[derive(Debug)] -pub(crate) enum HandshakeAction { - /// Parsing finished, all ok, nothing to do - Nonthing, - /// Packet parsed, now go perform authentication - AuthNeeded(AuthNeededInfo), - /// the client can fully establish a connection with this info - ClientConnect(ClientConnectInfo), -} - /// Track the total number of threads and our index /// 65K cpus should be enough for anybody #[derive(Debug, Clone, Copy)] @@ -66,180 +12,3 @@ pub(crate) struct ThreadTracker { /// Note: starts from 1 pub id: u16, } - -/// Tracking of handhsakes and conenctions -/// Note that we have multiple Handshake trackers, pinned to different cores -/// Each of them will handle a subset of all handshakes. -/// Each handshake is routed to a different tracker by checking -/// core = (udp_src_sender_port % total_threads) - 1 -pub(crate) struct HandshakeTracker { - thread_id: ThreadTracker, - key_exchanges: Vec<(asym::KeyKind, asym::KeyExchangeKind)>, - ciphers: Vec, - /// ephemeral keys used server side in key exchange - keys_srv: Vec, - /// ephemeral keys used client side in key exchange - hshake_cli: HandshakeClientList, -} - -impl HandshakeTracker { - pub(crate) fn new(thread_id: ThreadTracker) -> Self { - Self { - thread_id, - ciphers: Vec::new(), - key_exchanges: Vec::new(), - keys_srv: Vec::new(), - hshake_cli: HandshakeClientList::new(), - } - } - pub(crate) fn new_client( - &mut self, - priv_key: PrivKey, - pub_key: PubKey, - service_id: ServiceID, - service_conn_id: IDRecv, - connection: Connection, - ) -> Result<(KeyID, &mut HandshakeClient), ()> { - self.hshake_cli.add( - priv_key, - pub_key, - service_id, - service_conn_id, - connection, - ) - } - pub(crate) fn timeout_client( - &mut self, - key_id: KeyID, - ) -> Option<[IDRecv; 2]> { - if let Some(hshake) = self.hshake_cli.remove(key_id) { - Some([hshake.connection.id_recv, hshake.service_conn_id]) - } else { - None - } - } - pub(crate) fn recv_handshake( - &mut self, - mut handshake: Handshake, - handshake_raw: &mut [u8], - ) -> Result { - use connection::handshake::{dirsync::DirSync, HandshakeData}; - match handshake.data { - HandshakeData::DirSync(ref mut ds) => match ds { - DirSync::Req(ref mut req) => { - let ephemeral_key = { - if let Some(h_k) = - self.keys_srv.iter().find(|k| k.id == req.key_id) - { - // Directory synchronized can only use keys - // for key exchange, not signing keys - if let PrivKey::Exchange(k) = &h_k.key { - Some(k.clone()) - } else { - None - } - } else { - None - } - }; - if ephemeral_key.is_none() { - ::tracing::debug!( - "No such server key id: {:?}", - req.key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - let ephemeral_key = ephemeral_key.unwrap(); - { - if None - == self.key_exchanges.iter().find(|&x| { - *x == (ephemeral_key.kind(), req.exchange) - }) - { - return Err( - enc::Error::UnsupportedKeyExchange.into() - ); - } - } - { - if None - == self.ciphers.iter().find(|&x| *x == req.cipher) - { - return Err(enc::Error::UnsupportedCipher.into()); - } - } - let shared_key = match ephemeral_key - .key_exchange(req.exchange, req.exchange_key) - { - Ok(shared_key) => shared_key, - Err(e) => return Err(handshake::Error::Key(e).into()), - }; - let hkdf = Hkdf::new(HkdfKind::Sha3, b"fenrir", shared_key); - let secret_recv = hkdf.get_secret(b"to_server"); - let cipher_recv = CipherRecv::new(req.cipher, secret_recv); - use crate::enc::sym::AAD; - let aad = AAD(&mut []); // no aad for now - match cipher_recv.decrypt( - aad, - &mut handshake_raw[req.encrypted_offset()..], - ) { - Ok(cleartext) => { - req.data.deserialize_as_cleartext(cleartext)?; - } - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - - let cipher = req.cipher; - - return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { - handshake, - hkdf, - cipher, - })); - } - DirSync::Resp(resp) => { - let hshake = match self.hshake_cli.get(resp.client_key_id) { - Some(hshake) => hshake, - None => { - ::tracing::debug!( - "No such client key id: {:?}", - resp.client_key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - }; - let cipher_recv = &hshake.connection.cipher_recv; - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - let mut raw_data = &mut handshake_raw[resp - .encrypted_offset() - ..(resp.encrypted_offset() + resp.encrypted_length())]; - match cipher_recv.decrypt(aad, &mut raw_data) { - Ok(cleartext) => { - resp.data.deserialize_as_cleartext(&cleartext)?; - } - Err(e) => { - return Err(handshake::Error::Key(e).into()); - } - } - let hshake = - self.hshake_cli.remove(resp.client_key_id).unwrap(); - if let Some(timeout) = hshake.timeout { - timeout.abort(); - } - return Ok(HandshakeAction::ClientConnect( - ClientConnectInfo { - service_id: hshake.service_id, - service_connection_id: hshake.service_conn_id, - handshake, - connection: hshake.connection, - }, - )); - } - }, - } - } -} diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 4f0459f..e8f6c5a 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -7,6 +7,7 @@ use crate::{ handshake::{ self, dirsync::{self, DirSync}, + tracker::{HandshakeAction, HandshakeTracker}, Handshake, HandshakeData, }, socket::{UdpClient, UdpServer}, @@ -18,9 +19,9 @@ use crate::{ hkdf::{self, Hkdf, HkdfKind}, sym, Random, Secret, }, - inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, + inner::ThreadTracker, }; -use ::std::{rc::Rc, sync::Arc, vec::Vec}; +use ::std::{sync::Arc, vec::Vec}; /// This worker must be cpu-pinned use ::tokio::{ net::UdpSocket, diff --git a/src/lib.rs b/src/lib.rs index cac07b0..0c01c8b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,9 @@ pub mod dnssec; pub mod enc; mod inner; +#[cfg(test)] +mod tests; + use ::std::{sync::Arc, vec::Vec}; use ::tokio::{net::UdpSocket, sync::Mutex}; @@ -74,7 +77,7 @@ pub struct Fenrir { /// listening udp sockets sockets: SocketList, /// DNSSEC resolver, with failovers - dnssec: Option, + dnssec: dnssec::Dnssec, /// Broadcast channel to tell workers to stop working stop_working: ::tokio::sync::broadcast::Sender, /// where to ask for token check @@ -100,10 +103,11 @@ impl Fenrir { /// Create a new Fenrir endpoint pub fn new(config: &Config) -> Result { let (sender, _) = ::tokio::sync::broadcast::channel(1); + let dnssec = dnssec::Dnssec::new(&config.resolvers)?; let endpoint = Fenrir { cfg: config.clone(), sockets: SocketList::new(), - dnssec: None, + dnssec, stop_working: sender, token_check: None, conn_auth_srv: Mutex::new(AuthServerConnections::new()), @@ -113,6 +117,7 @@ impl Fenrir { Ok(endpoint) } + ///FIXME: remove this, move into new() /// Start all workers, listeners pub async fn start( &mut self, @@ -123,7 +128,14 @@ impl Fenrir { self.stop().await; return Err(e.into()); } - self.dnssec = Some(dnssec::Dnssec::new(&self.cfg.resolvers).await?); + Ok(()) + } + ///FIXME: remove this, move into new() + pub async fn setup_no_workers(&mut self) -> Result<(), Error> { + if let Err(e) = self.add_sockets().await { + self.stop().await; + return Err(e.into()); + } Ok(()) } @@ -137,7 +149,6 @@ impl Fenrir { let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); let _ = old_thread_pool.into_iter().map(|th| th.join()); - self.dnssec = None; } /// Stop all workers, listeners @@ -148,7 +159,6 @@ impl Fenrir { let mut old_thread_pool = Vec::new(); ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); let _ = old_thread_pool.into_iter().map(|th| th.join()); - self.dnssec = None; } /// Add all UDP sockets found in config /// and start listening for packets @@ -235,9 +245,9 @@ impl Fenrir { } /// Get the raw TXT record of a Fenrir domain pub async fn resolv_txt(&self, domain: &Domain) -> Result { - match &self.dnssec { - Some(dnssec) => Ok(dnssec.resolv(domain).await?), - None => Err(Error::NotInitialized), + match self.dnssec.resolv(domain).await { + Ok(res) => Ok(res), + Err(e) => Err(e.into()), } } @@ -250,13 +260,22 @@ impl Fenrir { Ok(dnssec::Dnssec::parse_txt_record(&record_str)?) } - /// Connect to a service + /// Connect to a service, doing the dnssec resolution ourselves pub async fn connect( &self, domain: &Domain, service: ServiceID, ) -> Result<(), Error> { let resolved = self.resolv(domain).await?; + self.connect_resolved(resolved, domain, service).await + } + /// Connect to a service, with the user provided details + pub async fn connect_resolved( + &self, + resolved: dnssec::Record, + domain: &Domain, + service: ServiceID, + ) -> Result<(), Error> { loop { // check if we already have a connection to that auth. srv let is_reserved = { @@ -354,6 +373,54 @@ impl Fenrir { } } + async fn start_single_worker( + &mut self, + ) -> ::std::result::Result< + impl futures::Future>, + Error, + > { + let thread_idx = self._thread_work.len() as u16; + let max_threads = self.cfg.threads.unwrap().get() as u16; + if thread_idx >= max_threads { + ::tracing::error!( + "thread id higher than number of threads in config" + ); + assert!( + false, + "thread_idx is an index that can't reach cfg.threads" + ); + return Err(Error::Setup("Thread id > threads_max".to_owned())); + } + let thread_id = ThreadTracker { + id: thread_idx, + total: max_threads, + }; + let (work_send, work_recv) = ::async_channel::unbounded::(); + let worker = Worker::new_and_loop( + self.cfg.clone(), + thread_id, + self.stop_working.subscribe(), + self.token_check.clone(), + self.cfg.listen.clone(), + work_recv, + ); + loop { + let queues_lock = match Arc::get_mut(&mut self._thread_work) { + Some(queues_lock) => queues_lock, + None => { + // should not even ever happen + ::tokio::time::sleep(::std::time::Duration::from_millis( + 50, + )) + .await; + continue; + } + }; + queues_lock.push(work_send); + break; + } + Ok(worker) + } // TODO: start work on a LocalSet provided by the user /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..1032012 --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,126 @@ +use crate::*; + +#[::tokio::test] +async fn test_connection_dirsync() { + return; + use enc::asym::{KeyID, PrivKey, PubKey}; + let rand = enc::Random::new(); + let (priv_exchange_key, pub_exchange_key) = + match enc::asym::KeyExchangeKind::X25519DiffieHellman.new_keypair(&rand) + { + Ok((privkey, pubkey)) => { + (PrivKey::Exchange(privkey), PubKey::Exchange(pubkey)) + } + Err(_) => { + assert!(false, "Can't generate random keypair"); + return; + } + }; + let dnssec_record = Record { + public_keys: [(KeyID(42), pub_exchange_key)].to_vec(), + addresses: [record::Address { + ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)), + port: Some(::core::num::NonZeroU16::new(31337).unwrap()), + priority: record::AddressPriority::P1, + weight: record::AddressWeight::W1, + handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(), + public_key_idx: [record::PubKeyIdx(0)].to_vec(), + }] + .to_vec(), + key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman] + .to_vec(), + hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), + ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), + }; + let cfg_client = { + let mut cfg = config::Config::default(); + cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap()); + cfg + }; + let cfg_server = { + let mut cfg = cfg_client.clone(); + cfg.keys = [(KeyID(42), priv_exchange_key, pub_exchange_key)].to_vec(); + cfg + }; + + let mut server = Fenrir::new(&cfg_server).unwrap(); + let _ = server.setup_no_workers().await; + let srv_worker = server.start_single_worker().await; + + ::tokio::task::spawn_local(async move { srv_worker }); + let mut client = Fenrir::new(&cfg_client).unwrap(); + let _ = client.setup_no_workers().await; + let cli_worker = server.start_single_worker().await; + ::tokio::task::spawn_local(async move { cli_worker }); + + use crate::{ + connection::handshake::HandshakeID, + dnssec::{record, Record}, + }; + + let _ = client + .connect_resolved( + dnssec_record, + &Domain("example.com".to_owned()), + auth::SERVICEID_AUTH, + ) + .await; + + /* + let thread_id = ThreadTracker { total: 1, id: 0 }; + + let (stop_sender, _) = ::tokio::sync::broadcast::channel::(1); + + use ::std::net; + let cli_socket_addr = [net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 0, + )] + .to_vec(); + let srv_socket_addr = [net::SocketAddr::new( + net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), + 0, + )] + .to_vec(); + + let srv_sock = Arc::new(connection::socket::bind_udp(srv_socket_addr[0]) + .await + .unwrap()); + let cli_sock = Arc::new(connection::socket::bind_udp(cli_socket_addr[0]) + .await + .unwrap()); + + use crate::inner::worker::Work; + let (srv_work_send, srv_work_recv) = ::async_channel::unbounded::(); + let (cli_work_send, cli_work_recv) = ::async_channel::unbounded::(); + + let srv_queue = Arc::new([srv_work_recv.clone()].to_vec()); + let cli_queue = Arc::new([cli_work_recv.clone()].to_vec()); + + let listen_work_srv = + + + ::tokio::spawn(Fenrir::listen_udp( + stop_sender.subscribe(), + + + let _server = crate::inner::worker::Worker::new( + cfg.clone(), + thread_id, + stop_sender.subscribe(), + None, + srv_socket_addr, + srv_work_recv, + ); + let _client = crate::inner::worker::Worker::new( + cfg, + thread_id, + stop_sender.subscribe(), + None, + cli_socket_addr, + cli_work_recv, + ); + + todo!() + */ +} From b682068dcae52abf9db1cb42f94815df99eb627a Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sun, 11 Jun 2023 22:45:40 +0200 Subject: [PATCH 32/34] Test and fix shutdowns we have a Quick but partial shutdown, which lets the async "threads" work in the background and shutdown after a bit more time and the graceful/full shutdown, which waits for everything. Unfortunately `Drop` can't manage async and blocks everything, no way to yeld either, so if we only have a thread we would deadlock if we tried to stop things gracefully Signed-off-by: Luca Fulchir --- Cargo.toml | 12 ++ src/connection/handshake/tracker.rs | 77 +++++---- src/connection/socket.rs | 70 +------- src/inner/worker.rs | 69 ++++---- src/lib.rs | 249 ++++++++++++++++++++-------- src/tests.rs | 119 ++++--------- 6 files changed, 311 insertions(+), 285 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 12ed78d..0544e30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ tokio = { version = "1", features = ["full"] } # PERF: todo linux-only, behind "iouring" feature #tokio-uring = { version = "0.4" } tracing = { version = "0.1" } +tracing-test = { version = "0.2" } trust-dns-resolver = { version = "0.22", features = [ "dnssec-ring" ] } trust-dns-client = { version = "0.22", features = [ "dnssec" ] } trust-dns-proto = { version = "0.22" } @@ -72,3 +73,14 @@ incremental = true codegen-units = 256 rpath = false +[profile.test] +opt-level = 0 +debug = true +debug-assertions = true +overflow-checks = true +lto = false +panic = 'unwind' +incremental = true +codegen-units = 256 +rpath = false + diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index a907f3b..f304337 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -141,7 +141,7 @@ pub(crate) struct ClientConnectInfo { #[derive(Debug)] pub(crate) enum HandshakeAction { /// Parsing finished, all ok, nothing to do - Nonthing, + Nothing, /// Packet parsed, now go perform authentication AuthNeeded(AuthNeededInfo), /// the client can fully establish a connection with this info @@ -155,7 +155,7 @@ pub(crate) enum HandshakeAction { /// core = (udp_src_sender_port % total_threads) - 1 pub(crate) struct HandshakeTracker { thread_id: ThreadTracker, - key_exchanges: Vec<(asym::KeyKind, asym::KeyExchangeKind)>, + key_exchanges: Vec, ciphers: Vec, /// ephemeral keys used server side in key exchange keys_srv: Vec, @@ -164,16 +164,24 @@ pub(crate) struct HandshakeTracker { } impl HandshakeTracker { - pub(crate) fn new(thread_id: ThreadTracker) -> Self { + pub(crate) fn new( + thread_id: ThreadTracker, + ciphers: Vec, + key_exchanges: Vec, + ) -> Self { Self { thread_id, - ciphers: Vec::new(), - key_exchanges: Vec::new(), + ciphers, + key_exchanges, keys_srv: Vec::new(), hshake_cli: HandshakeClientList::new(), } } - pub(crate) fn new_client( + pub(crate) fn add_server(&mut self, id: KeyID, key: PrivKey) { + self.keys_srv.push(HandshakeServer { id, key }); + self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0)); + } + pub(crate) fn add_client( &mut self, priv_key: PrivKey, pub_key: PubKey, @@ -208,45 +216,34 @@ impl HandshakeTracker { match handshake.data { HandshakeData::DirSync(ref mut ds) => match ds { DirSync::Req(ref mut req) => { - let ephemeral_key = { - if let Some(h_k) = - self.keys_srv.iter().find(|k| k.id == req.key_id) - { + if !self.key_exchanges.contains(&req.exchange) { + return Err(enc::Error::UnsupportedKeyExchange.into()); + } + if !self.ciphers.contains(&req.cipher) { + return Err(enc::Error::UnsupportedCipher.into()); + } + let has_key = self.keys_srv.iter().find(|k| { + if k.id == req.key_id { // Directory synchronized can only use keys // for key exchange, not signing keys - if let PrivKey::Exchange(k) = &h_k.key { - Some(k.clone()) - } else { - None + if let PrivKey::Exchange(_) = k.key { + return true; } - } else { - None } - }; - if ephemeral_key.is_none() { - ::tracing::debug!( - "No such server key id: {:?}", - req.key_id - ); - return Err(handshake::Error::UnknownKeyID.into()); - } - let ephemeral_key = ephemeral_key.unwrap(); - { - if None - == self.key_exchanges.iter().find(|&x| { - *x == (ephemeral_key.kind(), req.exchange) - }) - { - return Err( - enc::Error::UnsupportedKeyExchange.into() - ); + false + }); + + let ephemeral_key; + match has_key { + Some(s_k) => { + if let PrivKey::Exchange(ref k) = &s_k.key { + ephemeral_key = k; + } else { + unreachable!(); + } } - } - { - if None - == self.ciphers.iter().find(|&x| *x == req.cipher) - { - return Err(enc::Error::UnsupportedCipher.into()); + None => { + return Err(handshake::Error::UnknownKeyID.into()) } } let shared_key = match ephemeral_key diff --git a/src/connection/socket.rs b/src/connection/socket.rs index 945dac6..717ecb3 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -1,6 +1,5 @@ //! Socket related types and functions -use ::arc_swap::ArcSwap; use ::std::{net::SocketAddr, sync::Arc, vec::Vec}; use ::tokio::{net::UdpSocket, task::JoinHandle}; @@ -10,82 +9,31 @@ pub type SocketTracker = /// async free socket list pub(crate) struct SocketList { - pub list: ArcSwap>, + pub list: Vec, } impl SocketList { pub(crate) fn new() -> Self { - Self { - list: ArcSwap::new(Arc::new(Vec::new())), - } + Self { list: Vec::new() } } - // TODO: fn rm_socket() - pub(crate) fn rm_all(&self) -> Self { - let new_list = Arc::new(Vec::new()); - let old_list = self.list.swap(new_list); - Self { - list: old_list.into(), - } + pub(crate) fn rm_all(&mut self) -> Self { + let mut old_list = Vec::new(); + ::core::mem::swap(&mut self.list, &mut old_list); + Self { list: old_list } } pub(crate) async fn add_socket( - &self, + &mut self, socket: Arc, handle: JoinHandle<::std::io::Result<()>>, ) { - // we could simplify this into just a `.swap` instead of `.rcu` but - // it is not yet guaranteed that only one thread will call this fn - // ...we don't need performance here anyway let arc_handle = Arc::new(handle); - self.list.rcu(|old_list| { - let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1)); - new_list = old_list.to_vec().into(); - Arc::get_mut(&mut new_list) - .unwrap() - .push((socket.clone(), arc_handle.clone())); - new_list - }); + self.list.push((socket, arc_handle)); } /// This method assumes no other `add_sockets` are being run pub(crate) async fn stop_all(self) { - let mut arc_list = self.list.into_inner(); - let list = loop { - match Arc::try_unwrap(arc_list) { - Ok(list) => break list, - Err(arc_retry) => { - arc_list = arc_retry; - ::tokio::time::sleep(::core::time::Duration::from_millis( - 50, - )) - .await; - } - } - }; - for (_socket, mut handle) in list.into_iter() { + for (_socket, mut handle) in self.list.into_iter() { let _ = Arc::get_mut(&mut handle).unwrap().await; } } - pub(crate) fn lock(&self) -> SocketListRef { - SocketListRef { - list: self.list.load_full(), - } - } -} - -/// Reference to a locked SocketList -// TODO: impl Drop for SocketList -pub(crate) struct SocketListRef { - list: Arc>, -} -impl SocketListRef { - pub(crate) fn find(&self, sock: UdpServer) -> Option> { - match self - .list - .iter() - .find(|&(s, _)| s.local_addr().unwrap() == sock.0) - { - Some((sock_srv, _)) => Some(sock_srv.clone()), - None => None, - } - } } /// Strong typedef for a client socket address diff --git a/src/inner/worker.rs b/src/inner/worker.rs index e8f6c5a..360b004 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -62,7 +62,7 @@ pub(crate) struct Worker { thread_id: ThreadTracker, // PERF: rand uses syscalls. how to do that async? rand: Random, - stop_working: ::tokio::sync::broadcast::Receiver, + stop_working: crate::StopWorkingRecvCh, token_check: Option>>, sockets: Vec, queue: ::async_channel::Receiver, @@ -77,7 +77,7 @@ impl Worker { pub(crate) async fn new_and_loop( cfg: Config, thread_id: ThreadTracker, - stop_working: ::tokio::sync::broadcast::Receiver, + stop_working: crate::StopWorkingRecvCh, token_check: Option>>, socket_addrs: Vec<::std::net::SocketAddr>, queue: ::async_channel::Receiver, @@ -96,9 +96,9 @@ impl Worker { Ok(()) } pub(crate) async fn new( - cfg: Config, + mut cfg: Config, thread_id: ThreadTracker, - stop_working: ::tokio::sync::broadcast::Receiver, + stop_working: crate::StopWorkingRecvCh, token_check: Option>>, socket_addrs: Vec<::std::net::SocketAddr>, queue: ::async_channel::Receiver, @@ -108,36 +108,43 @@ impl Worker { // in the future we will want to have a thread-local listener too, // but before that we need ebpf to pin a connection to a thread // directly from the kernel - let socket_binding = - socket_addrs.into_iter().map(|s_addr| async move { - let socket = ::tokio::spawn(connection::socket::bind_udp( - s_addr.clone(), - )) - .await??; + let mut sock_set = ::tokio::task::JoinSet::new(); + socket_addrs.into_iter().for_each(|s_addr| { + sock_set.spawn(async move { + let socket = + connection::socket::bind_udp(s_addr.clone()).await?; Ok(socket) }); - let sockets_bind_res = - ::futures::future::join_all(socket_binding).await; - let sockets: Result, ::std::io::Error> = - sockets_bind_res - .into_iter() - .map(|s_res| match s_res { - Ok(s) => Ok(s), + }); + // make sure we either add all of them, or none + let mut sockets = Vec::with_capacity(cfg.listen.len()); + while let Some(join_res) = sock_set.join_next().await { + match join_res { + Ok(s_res) => match s_res { + Ok(sock) => sockets.push(sock), Err(e) => { - ::tracing::error!("Worker can't bind on socket: {}", e); - Err(e) + ::tracing::error!("Can't rebind socket"); + return Err(e); } - }) - .collect(); - let sockets = match sockets { - Ok(sockets) => sockets, - Err(e) => { - return Err(e); + }, + Err(e) => return Err(e.into()), } - }; + } let (queue_timeouts_send, queue_timeouts_recv) = mpsc::unbounded_channel(); + let mut handshakes = HandshakeTracker::new( + thread_id, + cfg.ciphers.clone(), + cfg.key_exchanges.clone(), + ); + let mut keys = Vec::new(); + // make sure the keys are no longer in the config + ::core::mem::swap(&mut keys, &mut cfg.keys); + for k in keys.into_iter() { + handshakes.add_server(k.0, k.1); + } + Ok(Self { cfg, thread_id, @@ -150,13 +157,15 @@ impl Worker { queue_timeouts_send, thread_channels: Vec::new(), connections: ConnList::new(thread_id), - handshakes: HandshakeTracker::new(thread_id), + handshakes, }) } pub(crate) async fn work_loop(&mut self) { 'mainloop: loop { let work = ::tokio::select! { - _done = self.stop_working.recv() => { + tell_stopped = self.stop_working.recv() => { + let _ = tell_stopped.unwrap().send( + crate::StopWorking::WorkerStopped).await; break; } maybe_timeout = self.queue.recv() => { @@ -326,7 +335,7 @@ impl Worker { conn.id_recv = auth_recv_id; let (client_key_id, hshake) = match self .handshakes - .new_client( + .add_client( PrivKey::Exchange(priv_key), PubKey::Exchange(pub_key), conn_info.service_id, @@ -617,7 +626,7 @@ impl Worker { IDSend(resp_data.service_connection_id); let _ = self.connections.track(service_connection.into()); } - HandshakeAction::Nonthing => {} + HandshakeAction::Nothing => {} }; } } diff --git a/src/lib.rs b/src/lib.rs index 0c01c8b..d08ca2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,6 +69,17 @@ pub enum Error { Encrypt(enc::Error), } +pub(crate) enum StopWorking { + WorkerStopped, + ListenerStopped, +} + +pub(crate) type StopWorkingSendCh = + ::tokio::sync::broadcast::Sender<::tokio::sync::mpsc::Sender>; +pub(crate) type StopWorkingRecvCh = ::tokio::sync::broadcast::Receiver< + ::tokio::sync::mpsc::Sender, +>; + /// Instance of a fenrir endpoint #[allow(missing_copy_implementations, missing_debug_implementations)] pub struct Fenrir { @@ -79,7 +90,7 @@ pub struct Fenrir { /// DNSSEC resolver, with failovers dnssec: dnssec::Dnssec, /// Broadcast channel to tell workers to stop working - stop_working: ::tokio::sync::broadcast::Sender, + stop_working: StopWorkingSendCh, /// where to ask for token check token_check: Option>>, /// tracks the connections to authentication servers @@ -89,22 +100,74 @@ pub struct Fenrir { // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, _thread_work: Arc>>, + // This can be different from cfg.listen since using port 0 will result + // in a random port assigned by the operative system + _listen_addrs: Vec<::std::net::SocketAddr>, } - // TODO: graceful vs immediate stop impl Drop for Fenrir { fn drop(&mut self) { - self.stop_sync() + ::tracing::debug!( + "Fenrir fast shutdown.\ + Some threads might remain a bit longer" + ); + let _ = self.stop_sync(); } } impl Fenrir { + /// Gracefully stop all listeners and workers + /// only return when all resources have been deallocated + pub async fn graceful_stop(mut self) { + ::tracing::debug!("Fenrir full shut down"); + if let Some((ch, listeners, workers)) = self.stop_sync() { + self.stop_wait(ch, listeners, workers).await; + } + } + fn stop_sync( + &mut self, + ) -> Option<(::tokio::sync::mpsc::Receiver, usize, usize)> + { + let listeners_num = self.sockets.list.len(); + let workers_num = self._thread_work.len(); + if self.sockets.list.len() > 0 || self._thread_work.len() > 0 { + let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4); + let _ = self.stop_working.send(ch_send); + let _ = self.sockets.rm_all(); + self._thread_pool.clear(); + Some((ch_recv, listeners_num, workers_num)) + } else { + None + } + } + async fn stop_wait( + &mut self, + mut ch: ::tokio::sync::mpsc::Receiver, + mut listeners_num: usize, + mut workers_num: usize, + ) { + while listeners_num > 0 && workers_num > 0 { + match ch.recv().await { + Some(stopped) => match stopped { + StopWorking::WorkerStopped => workers_num = workers_num - 1, + StopWorking::ListenerStopped => { + listeners_num = listeners_num - 1 + } + }, + _ => break, + } + } + } /// Create a new Fenrir endpoint - pub fn new(config: &Config) -> Result { + /// spawn threads pinned to cpus in our own way with tokio's runtime + pub async fn with_threads( + config: &Config, + tokio_rt: Arc<::tokio::runtime::Runtime>, + ) -> Result { let (sender, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; - let endpoint = Fenrir { + let mut endpoint = Self { cfg: config.clone(), sockets: SocketList::new(), dnssec, @@ -113,86 +176,120 @@ impl Fenrir { conn_auth_srv: Mutex::new(AuthServerConnections::new()), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), + _listen_addrs: Vec::with_capacity(config.listen.len()), }; + endpoint.start_work_threads_pinned(tokio_rt).await?; + match endpoint.add_sockets().await { + Ok(addrs) => endpoint._listen_addrs = addrs, + Err(e) => return Err(e.into()), + } Ok(endpoint) } - - ///FIXME: remove this, move into new() - /// Start all workers, listeners - pub async fn start( - &mut self, - tokio_rt: Arc<::tokio::runtime::Runtime>, - ) -> Result<(), Error> { - self.start_work_threads_pinned(tokio_rt).await?; - if let Err(e) = self.add_sockets().await { - self.stop().await; - return Err(e.into()); + /// Create a new Fenrir endpoint + /// Get the workers that you can use in a tokio LocalSet + /// You should: + /// * move these workers each in its own thread + /// * make sure that the threads are pinned on the cpu + pub async fn with_workers( + config: &Config, + ) -> Result< + ( + Self, + Vec>>, + ), + Error, + > { + let (stop_working, _) = ::tokio::sync::broadcast::channel(1); + let dnssec = dnssec::Dnssec::new(&config.resolvers)?; + let cfg = config.clone(); + let sockets = SocketList::new(); + let conn_auth_srv = Mutex::new(AuthServerConnections::new()); + let thread_pool = Vec::new(); + let thread_work = Arc::new(Vec::new()); + let listen_addrs = Vec::with_capacity(config.listen.len()); + let mut endpoint = Self { + cfg, + sockets, + dnssec, + stop_working: stop_working.clone(), + token_check: None, + conn_auth_srv, + _thread_pool: thread_pool, + _thread_work: thread_work, + _listen_addrs: listen_addrs, + }; + let worker_num = config.threads.unwrap().get(); + let mut workers = Vec::with_capacity(worker_num); + for _ in 0..worker_num { + workers.push(endpoint.start_single_worker().await?); } - Ok(()) - } - ///FIXME: remove this, move into new() - pub async fn setup_no_workers(&mut self) -> Result<(), Error> { - if let Err(e) = self.add_sockets().await { - self.stop().await; - return Err(e.into()); + match endpoint.add_sockets().await { + Ok(addrs) => endpoint._listen_addrs = addrs, + Err(e) => return Err(e.into()), } - Ok(()) + Ok((endpoint, workers)) + } + /// Returns the list of the actual addresses we are listening on + /// Note that this can be different from what was configured: + /// if you specified UDP port 0 a random one has been assigned to you + /// by the operating system. + pub fn addresses(&self) -> Vec<::std::net::SocketAddr> { + self._listen_addrs.clone() } - /// Stop all workers, listeners - /// asyncronous version for Drop - fn stop_sync(&mut self) { - let _ = self.stop_working.send(true); - let toempty_sockets = self.sockets.rm_all(); - let task = ::tokio::task::spawn(toempty_sockets.stop_all()); - let _ = ::futures::executor::block_on(task); - let mut old_thread_pool = Vec::new(); - ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); - let _ = old_thread_pool.into_iter().map(|th| th.join()); - } - - /// Stop all workers, listeners - pub async fn stop(&mut self) { - let _ = self.stop_working.send(true); - let toempty_sockets = self.sockets.rm_all(); - toempty_sockets.stop_all().await; - let mut old_thread_pool = Vec::new(); - ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool); - let _ = old_thread_pool.into_iter().map(|th| th.join()); - } + // only call **after** starting all threads /// Add all UDP sockets found in config /// and start listening for packets - async fn add_sockets(&self) -> ::std::io::Result<()> { - let sockets = self.cfg.listen.iter().map(|s_addr| async { - let socket = - ::tokio::spawn(connection::socket::bind_udp(s_addr.clone())) - .await??; - Ok(socket) + async fn add_sockets( + &mut self, + ) -> ::std::io::Result> { + // try to bind multiple sockets in parallel + let mut sock_set = ::tokio::task::JoinSet::new(); + self.cfg.listen.iter().for_each(|s_addr| { + let socket_address = s_addr.clone(); + let stop_working = self.stop_working.subscribe(); + let th_work = self._thread_work.clone(); + sock_set.spawn(async move { + let s = connection::socket::bind_udp(socket_address).await?; + let arc_s = Arc::new(s); + let join = ::tokio::spawn(Self::listen_udp( + stop_working, + th_work, + arc_s.clone(), + )); + Ok((arc_s, join)) + }); }); - let sockets = ::futures::future::join_all(sockets).await; - for s_res in sockets.into_iter() { - match s_res { - Ok(s) => { - let stop_working = self.stop_working.subscribe(); - let arc_s = Arc::new(s); - let join = ::tokio::spawn(Self::listen_udp( - stop_working, - self._thread_work.clone(), - arc_s.clone(), - )); - self.sockets.add_socket(arc_s, join).await; - } + + // make sure we either add all of them, or none + let mut all_socks = Vec::with_capacity(self.cfg.listen.len()); + while let Some(join_res) = sock_set.join_next().await { + match join_res { + Ok(s_res) => match s_res { + Ok(s) => { + all_socks.push(s); + } + Err(e) => { + return Err(e); + } + }, Err(e) => { - return Err(e); + return Err(e.into()); } } } - Ok(()) + + let mut ret = Vec::with_capacity(self.cfg.listen.len()); + for (arc_s, join) in all_socks.into_iter() { + ret.push(arc_s.local_addr().unwrap()); + self.sockets.add_socket(arc_s, join).await; + } + Ok(ret) } /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( - mut stop_working: ::tokio::sync::broadcast::Receiver, + mut stop_working: StopWorkingRecvCh, work_queues: Arc>>, socket: Arc, ) -> ::std::io::Result<()> { @@ -202,8 +299,11 @@ impl Fenrir { let queues_num = work_queues.len() as u64; loop { let (bytes, sock_sender) = ::tokio::select! { - _done = stop_working.recv() => { - break; + tell_stopped = stop_working.recv() => { + drop(socket); + let _ = tell_stopped.unwrap() + .send(StopWorking::ListenerStopped).await; + return Ok(()); } result = socket.recv_from(&mut buffer) => { result? @@ -241,7 +341,6 @@ impl Fenrir { })) .await; } - Ok(()) } /// Get the raw TXT record of a Fenrir domain pub async fn resolv_txt(&self, domain: &Domain) -> Result { @@ -373,6 +472,7 @@ impl Fenrir { } } + // needs to be called before add_sockets async fn start_single_worker( &mut self, ) -> ::std::result::Result< @@ -404,6 +504,10 @@ impl Fenrir { self.cfg.listen.clone(), work_recv, ); + // don't keep around private keys too much + if (thread_idx + 1) == max_threads { + self.cfg.keys.clear(); + } loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { Some(queues_lock) => queues_lock, @@ -421,7 +525,8 @@ impl Fenrir { } Ok(worker) } - // TODO: start work on a LocalSet provided by the user + + // needs to be called before add_sockets /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. /// Work will be divided and rerouted so that there is no need to lock @@ -521,6 +626,8 @@ impl Fenrir { } self._thread_pool.push(join_handle); } + // don't keep around private keys too much + self.cfg.keys.clear(); Ok(()) } } diff --git a/src/tests.rs b/src/tests.rs index 1032012..7d19b4f 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,8 +1,8 @@ use crate::*; +#[::tracing_test::traced_test] #[::tokio::test] async fn test_connection_dirsync() { - return; use enc::asym::{KeyID, PrivKey, PubKey}; let rand = enc::Random::new(); let (priv_exchange_key, pub_exchange_key) = @@ -16,22 +16,6 @@ async fn test_connection_dirsync() { return; } }; - let dnssec_record = Record { - public_keys: [(KeyID(42), pub_exchange_key)].to_vec(), - addresses: [record::Address { - ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)), - port: Some(::core::num::NonZeroU16::new(31337).unwrap()), - priority: record::AddressPriority::P1, - weight: record::AddressWeight::W1, - handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(), - public_key_idx: [record::PubKeyIdx(0)].to_vec(), - }] - .to_vec(), - key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman] - .to_vec(), - hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), - ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), - }; let cfg_client = { let mut cfg = config::Config::default(); cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap()); @@ -43,21 +27,46 @@ async fn test_connection_dirsync() { cfg }; - let mut server = Fenrir::new(&cfg_server).unwrap(); - let _ = server.setup_no_workers().await; - let srv_worker = server.start_single_worker().await; + let (server, mut srv_workers) = + Fenrir::with_workers(&cfg_server).await.unwrap(); - ::tokio::task::spawn_local(async move { srv_worker }); - let mut client = Fenrir::new(&cfg_client).unwrap(); - let _ = client.setup_no_workers().await; - let cli_worker = server.start_single_worker().await; - ::tokio::task::spawn_local(async move { cli_worker }); + let srv_worker = srv_workers.pop().unwrap(); + let local_thread = ::tokio::task::LocalSet::new(); + local_thread.spawn_local(async move { srv_worker.await }); + + let (client, mut cli_workers) = + Fenrir::with_workers(&cfg_client).await.unwrap(); + let cli_worker = cli_workers.pop().unwrap(); + local_thread.spawn_local(async move { cli_worker.await }); use crate::{ connection::handshake::HandshakeID, dnssec::{record, Record}, }; + let port: u16 = server.addresses()[0].port(); + + let dnssec_record = Record { + public_keys: [(KeyID(42), pub_exchange_key)].to_vec(), + addresses: [record::Address { + ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)), + port: Some(::core::num::NonZeroU16::new(port).unwrap()), + priority: record::AddressPriority::P1, + weight: record::AddressWeight::W1, + handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(), + public_key_idx: [record::PubKeyIdx(0)].to_vec(), + }] + .to_vec(), + key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman] + .to_vec(), + hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(), + ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), + }; + + server.graceful_stop().await; + client.graceful_stop().await; + return; + let _ = client .connect_resolved( dnssec_record, @@ -65,62 +74,6 @@ async fn test_connection_dirsync() { auth::SERVICEID_AUTH, ) .await; - - /* - let thread_id = ThreadTracker { total: 1, id: 0 }; - - let (stop_sender, _) = ::tokio::sync::broadcast::channel::(1); - - use ::std::net; - let cli_socket_addr = [net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 0, - )] - .to_vec(); - let srv_socket_addr = [net::SocketAddr::new( - net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)), - 0, - )] - .to_vec(); - - let srv_sock = Arc::new(connection::socket::bind_udp(srv_socket_addr[0]) - .await - .unwrap()); - let cli_sock = Arc::new(connection::socket::bind_udp(cli_socket_addr[0]) - .await - .unwrap()); - - use crate::inner::worker::Work; - let (srv_work_send, srv_work_recv) = ::async_channel::unbounded::(); - let (cli_work_send, cli_work_recv) = ::async_channel::unbounded::(); - - let srv_queue = Arc::new([srv_work_recv.clone()].to_vec()); - let cli_queue = Arc::new([cli_work_recv.clone()].to_vec()); - - let listen_work_srv = - - - ::tokio::spawn(Fenrir::listen_udp( - stop_sender.subscribe(), - - - let _server = crate::inner::worker::Worker::new( - cfg.clone(), - thread_id, - stop_sender.subscribe(), - None, - srv_socket_addr, - srv_work_recv, - ); - let _client = crate::inner::worker::Worker::new( - cfg, - thread_id, - stop_sender.subscribe(), - None, - cli_socket_addr, - cli_work_recv, - ); - - todo!() - */ + server.graceful_stop().await; + client.graceful_stop().await; } From 866edc2d7d012d0c2d8c16799537056562aff48f Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sat, 17 Jun 2023 11:33:47 +0200 Subject: [PATCH 33/34] TONS of bugfixing. Add tests. Client now connects Signed-off-by: Luca Fulchir --- TODO | 1 + src/auth/mod.rs | 16 +- src/config/mod.rs | 26 ++- src/connection/handshake/dirsync.rs | 44 +++-- src/connection/handshake/mod.rs | 6 + src/connection/handshake/tracker.rs | 95 ++++++++-- src/connection/mod.rs | 23 ++- src/connection/socket.rs | 142 ++++++++++----- src/enc/asym.rs | 2 +- src/enc/mod.rs | 2 + src/enc/sym.rs | 39 ++-- src/enc/tests.rs | 135 ++++++++++++++ src/inner/worker.rs | 272 ++++++++++++++-------------- src/lib.rs | 232 +++++++++++++----------- src/tests.rs | 59 ++++-- 15 files changed, 739 insertions(+), 355 deletions(-) create mode 100644 TODO create mode 100644 src/enc/tests.rs diff --git a/TODO b/TODO new file mode 100644 index 0000000..9531367 --- /dev/null +++ b/TODO @@ -0,0 +1 @@ +* Wrapping for everything that wraps (sigh) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 84be8cb..085816b 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -3,6 +3,8 @@ use crate::enc::Random; use ::zeroize::Zeroize; +/// Anonymous user id +pub const USERID_ANONYMOUS: UserID = UserID([0; UserID::len()]); /// User identifier. 16 bytes for easy uuid conversion #[derive(Debug, Copy, Clone, PartialEq)] pub struct UserID(pub [u8; 16]); @@ -25,8 +27,8 @@ impl UserID { } } /// Anonymous user id - pub fn new_anonymous() -> Self { - UserID([0; 16]) + pub const fn new_anonymous() -> Self { + USERID_ANONYMOUS } /// length of the User ID in bytes pub const fn len() -> usize { @@ -98,6 +100,16 @@ impl TryFrom<&[u8]> for Domain { Ok(Domain(domain_string)) } } +impl From for Domain { + fn from(raw: String) -> Self { + Self(raw) + } +} +impl From<&str> for Domain { + fn from(raw: &str) -> Self { + Self(raw.to_owned()) + } +} impl Domain { /// length of the User ID in bytes diff --git a/src/config/mod.rs b/src/config/mod.rs index c0fe949..f196ac9 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -16,6 +16,23 @@ use ::std::{ vec, }; +/// Key used by a server during the handshake +#[derive(Clone, Debug)] +pub struct ServerKey { + pub id: KeyID, + pub priv_key: PrivKey, + pub pub_key: PubKey, +} + +/// Authentication Server information and keys +#[derive(Clone, Debug)] +pub struct AuthServer { + /// fqdn of the authentication server + pub fqdn: crate::auth::Domain, + /// list of key ids enabled for this domain + pub keys: Vec, +} + /// Main config for libFenrir #[derive(Clone, Debug)] pub struct Config { @@ -34,8 +51,12 @@ pub struct Config { pub hkdfs: Vec, /// Supported Ciphers pub ciphers: Vec, + /// list of authentication servers + /// clients will have this empty + pub servers: Vec, /// list of public/private keys - pub keys: Vec<(KeyID, PrivKey, PubKey)>, + /// clients should have this empty + pub server_keys: Vec, } impl Default for Config { @@ -56,7 +77,8 @@ impl Default for Config { key_exchanges: [KeyExchangeKind::X25519DiffieHellman].to_vec(), hkdfs: [HkdfKind::Sha3].to_vec(), ciphers: [CipherKind::XChaCha20Poly1305].to_vec(), - keys: Vec::new(), + servers: Vec::new(), + server_keys: Vec::new(), } } } diff --git a/src/connection/handshake/dirsync.rs b/src/connection/handshake/dirsync.rs index 8023ed6..0be2dd4 100644 --- a/src/connection/handshake/dirsync.rs +++ b/src/connection/handshake/dirsync.rs @@ -113,10 +113,14 @@ impl Req { + self.exchange_key.kind().pub_len() } /// return the total length of the cleartext data - pub fn encrypted_length(&self) -> usize { + pub fn encrypted_length( + &self, + head_len: HeadLen, + tag_len: TagLen, + ) -> usize { match &self.data { - ReqInner::ClearText(data) => data.len(), - _ => 0, + ReqInner::ClearText(data) => data.len() + head_len.0 + tag_len.0, + ReqInner::CipherText(length) => *length, } } /// actual length of the directory synchronized request @@ -177,11 +181,16 @@ impl super::HandshakeParsing for Req { Some(cipher) => cipher, None => return Err(Error::Parsing), }; - let (exchange_key, len) = match ExchangePubKey::deserialize(&raw[5..]) { - Ok(exchange_key) => exchange_key, - Err(e) => return Err(e.into()), - }; - let data = ReqInner::CipherText(raw.len() - (5 + len)); + const CURR_SIZE: usize = KeyID::len() + + KeyExchangeKind::len() + + HkdfKind::len() + + CipherKind::len(); + let (exchange_key, len) = + match ExchangePubKey::deserialize(&raw[CURR_SIZE..]) { + Ok(exchange_key) => exchange_key, + Err(e) => return Err(e.into()), + }; + let data = ReqInner::CipherText(raw.len() - (CURR_SIZE + len)); Ok(HandshakeData::DirSync(DirSync::Req(Self { key_id, exchange, @@ -436,7 +445,7 @@ impl super::HandshakeParsing for Resp { return Err(Error::NotEnoughData); } let client_key_id: KeyID = - KeyID(u16::from_le_bytes(raw[0..2].try_into().unwrap())); + KeyID(u16::from_le_bytes(raw[0..KeyID::len()].try_into().unwrap())); Ok(HandshakeData::DirSync(DirSync::Resp(Self { client_key_id, data: RespInner::CipherText(raw[KeyID::len()..].len()), @@ -453,10 +462,16 @@ impl Resp { + KeyID::len() } /// return the total length of the cleartext data - pub fn encrypted_length(&self) -> usize { + pub fn encrypted_length( + &self, + head_len: HeadLen, + tag_len: TagLen, + ) -> usize { match &self.data { - RespInner::ClearText(_data) => RespData::len(), - _ => 0, + RespInner::ClearText(_data) => { + RespData::len() + head_len.0 + tag_len.0 + } + RespInner::CipherText(len) => *len, } } /// Total length of the response handshake @@ -471,8 +486,9 @@ impl Resp { _tag_len: TagLen, out: &mut [u8], ) { - out[0..2].copy_from_slice(&self.client_key_id.0.to_le_bytes()); - let start_data = 2 + head_len.0; + out[0..KeyID::len()] + .copy_from_slice(&self.client_key_id.0.to_le_bytes()); + let start_data = KeyID::len() + head_len.0; let end_data = start_data + self.data.len(); self.data.serialize(&mut out[start_data..end_data]); } diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 6dd248e..b5204a1 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -37,6 +37,12 @@ pub enum Error { /// Too many client handshakes currently running #[error("Too many client handshakes")] TooManyClientHandshakes, + /// generic internal error + #[error("Internal tracking error")] + InternalTracking, + /// Handshake Timeout + #[error("Handshake timeout")] + Timeout, } /// List of possible handshakes diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index f304337..63ee31d 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -1,11 +1,11 @@ //! Handhsake handling use crate::{ - auth::ServiceID, + auth::{Domain, ServiceID}, connection::{ self, handshake::{self, Error, Handshake}, - Connection, IDRecv, + Connection, IDRecv, IDSend, }, enc::{ self, @@ -16,16 +16,23 @@ use crate::{ inner::ThreadTracker, }; +use ::tokio::sync::oneshot; + pub(crate) struct HandshakeServer { pub id: KeyID, pub key: PrivKey, + pub domains: Vec, } +pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>; + pub(crate) struct HandshakeClient { pub service_id: ServiceID, pub service_conn_id: IDRecv, pub connection: Connection, pub timeout: Option<::tokio::task::JoinHandle<()>>, + pub answer: oneshot::Sender, + pub srv_key_id: KeyID, } /// Tracks the keys used by the client and the handshake @@ -73,7 +80,10 @@ impl HandshakeClientList { service_id: ServiceID, service_conn_id: IDRecv, connection: Connection, - ) -> Result<(KeyID, &mut HandshakeClient), ()> { + answer: oneshot::Sender, + srv_key_id: KeyID, + ) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender> + { let maybe_free_key_idx = self.used.iter().enumerate().find_map(|(idx, bmap)| { match bmap.first_false_index() { @@ -85,7 +95,7 @@ impl HandshakeClientList { Some((idx, false_idx)) => { let free_key_idx = idx * 1024 + false_idx; if free_key_idx > KeyID::MAX as usize { - return Err(()); + return Err(answer); } self.used[idx].set(false_idx, true); free_key_idx @@ -107,6 +117,8 @@ impl HandshakeClientList { service_conn_id, connection, timeout: None, + answer, + srv_key_id, }); Ok(( KeyID(free_key_idx as u16), @@ -136,6 +148,10 @@ pub(crate) struct ClientConnectInfo { pub handshake: Handshake, /// Connection pub connection: Connection, + /// where to wake up the waiting client + pub answer: oneshot::Sender, + /// server public key id that we used on the handshake + pub srv_key_id: KeyID, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] @@ -177,10 +193,42 @@ impl HandshakeTracker { hshake_cli: HandshakeClientList::new(), } } - pub(crate) fn add_server(&mut self, id: KeyID, key: PrivKey) { - self.keys_srv.push(HandshakeServer { id, key }); + pub(crate) fn add_server_key( + &mut self, + id: KeyID, + key: PrivKey, + ) -> Result<(), ()> { + if self.keys_srv.iter().find(|&k| k.id == id).is_some() { + return Err(()); + } + self.keys_srv.push(HandshakeServer { + id, + key, + domains: Vec::new(), + }); self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0)); + Ok(()) } + pub(crate) fn add_server_domain( + &mut self, + domain: &Domain, + key_ids: &[KeyID], + ) -> Result<(), ()> { + // check that all the key ids are present + for id in key_ids.iter() { + if self.keys_srv.iter().find(|k| k.id == *id).is_none() { + return Err(()); + } + } + // add the domain to those keys + for id in key_ids.iter() { + if let Some(srv) = self.keys_srv.iter_mut().find(|k| k.id == *id) { + srv.domains.push(domain.clone()); + } + } + Ok(()) + } + pub(crate) fn add_client( &mut self, priv_key: PrivKey, @@ -188,20 +236,32 @@ impl HandshakeTracker { service_id: ServiceID, service_conn_id: IDRecv, connection: Connection, - ) -> Result<(KeyID, &mut HandshakeClient), ()> { + answer: oneshot::Sender, + srv_key_id: KeyID, + ) -> Result<(KeyID, &mut HandshakeClient), oneshot::Sender> + { self.hshake_cli.add( priv_key, pub_key, service_id, service_conn_id, connection, + answer, + srv_key_id, ) } + pub(crate) fn remove_client( + &mut self, + key_id: KeyID, + ) -> Option { + self.hshake_cli.remove(key_id) + } pub(crate) fn timeout_client( &mut self, key_id: KeyID, ) -> Option<[IDRecv; 2]> { if let Some(hshake) = self.hshake_cli.remove(key_id) { + let _ = hshake.answer.send(Err(Error::Timeout.into())); Some([hshake.connection.id_recv, hshake.service_conn_id]) } else { None @@ -257,9 +317,16 @@ impl HandshakeTracker { let cipher_recv = CipherRecv::new(req.cipher, secret_recv); use crate::enc::sym::AAD; let aad = AAD(&mut []); // no aad for now + + let encrypt_from = req.encrypted_offset(); + let encrypt_to = encrypt_from + + req.encrypted_length( + cipher_recv.nonce_len(), + cipher_recv.tag_len(), + ); match cipher_recv.decrypt( aad, - &mut handshake_raw[req.encrypted_offset()..], + &mut handshake_raw[encrypt_from..encrypt_to], ) { Ok(cleartext) => { req.data.deserialize_as_cleartext(cleartext)?; @@ -292,9 +359,13 @@ impl HandshakeTracker { use crate::enc::sym::AAD; // no aad for now let aad = AAD(&mut []); - let mut raw_data = &mut handshake_raw[resp - .encrypted_offset() - ..(resp.encrypted_offset() + resp.encrypted_length())]; + let data_from = resp.encrypted_offset(); + let data_to = data_from + + resp.encrypted_length( + cipher_recv.nonce_len(), + cipher_recv.tag_len(), + ); + let mut raw_data = &mut handshake_raw[data_from..data_to]; match cipher_recv.decrypt(aad, &mut raw_data) { Ok(cleartext) => { resp.data.deserialize_as_cleartext(&cleartext)?; @@ -314,6 +385,8 @@ impl HandshakeTracker { service_connection_id: hshake.service_conn_id, handshake, connection: hshake.connection, + answer: hshake.answer, + srv_key_id: hshake.srv_key_id, }, )); } diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 07b7c18..45a50e9 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -113,7 +113,11 @@ pub(crate) struct ConnList { impl ConnList { pub(crate) fn new(thread_id: ThreadTracker) -> Self { - let bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + let mut bitmap_id = ::bitmaps::Bitmap::<1024>::new(); + if thread_id.id == 0 { + // make sure we don't count the Handshake ID + bitmap_id.set(0, true); + } const INITIAL_CAP: usize = 128; let mut ret = Self { thread_id, @@ -199,13 +203,6 @@ impl ConnList { } } -use ::std::collections::HashMap; - -enum MapEntry { - Present(IDSend), - Reserved, -} - /// return wether we already have a connection, we are waiting for one, or you /// can start one #[derive(Debug, Clone, Copy)] @@ -218,6 +215,12 @@ pub(crate) enum Reservation { Reserved, } +enum MapEntry { + Present(IDSend), + Reserved, +} +use ::std::collections::HashMap; + /// Link the public key of the authentication server to a connection id /// so that we can reuse that connection to ask for more authentications /// @@ -229,16 +232,16 @@ pub(crate) enum Reservation { /// * wait for the connection to finish /// * remove all those reservations, exept the one key that actually succeded /// While searching, we return a connection ID if just one key is a match +// TODO: can we shard this per-core by hashing the pubkey? or domain? or...??? +// This needs a mutex and it will be our goeal to avoid any synchronization pub(crate) struct AuthServerConnections { conn_map: HashMap, - next_reservation: u64, } impl AuthServerConnections { pub(crate) fn new() -> Self { Self { conn_map: HashMap::with_capacity(32), - next_reservation: 0, } } /// add an ID to the reserved spot, diff --git a/src/connection/socket.rs b/src/connection/socket.rs index 717ecb3..abfc106 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -1,40 +1,10 @@ //! Socket related types and functions -use ::std::{net::SocketAddr, sync::Arc, vec::Vec}; +use ::std::net::SocketAddr; use ::tokio::{net::UdpSocket, task::JoinHandle}; /// Pair to easily track the socket and its async listening handle -pub type SocketTracker = - (Arc, Arc>>); - -/// async free socket list -pub(crate) struct SocketList { - pub list: Vec, -} -impl SocketList { - pub(crate) fn new() -> Self { - Self { list: Vec::new() } - } - pub(crate) fn rm_all(&mut self) -> Self { - let mut old_list = Vec::new(); - ::core::mem::swap(&mut self.list, &mut old_list); - Self { list: old_list } - } - pub(crate) async fn add_socket( - &mut self, - socket: Arc, - handle: JoinHandle<::std::io::Result<()>>, - ) { - let arc_handle = Arc::new(handle); - self.list.push((socket, arc_handle)); - } - /// This method assumes no other `add_sockets` are being run - pub(crate) async fn stop_all(self) { - for (_socket, mut handle) in self.list.into_iter() { - let _ = Arc::get_mut(&mut handle).unwrap().await; - } - } -} +pub type SocketTracker = (SocketAddr, JoinHandle<::std::io::Result<()>>); /// Strong typedef for a client socket address #[derive(Debug, Copy, Clone)] @@ -53,7 +23,7 @@ fn enable_sock_opt( unsafe { #[allow(trivial_casts)] let val = &value as *const _ as *const ::libc::c_void; - let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; + let size = ::core::mem::size_of_val(&value) as ::libc::socklen_t; // always clear the error bit before doing a new syscall let _ = ::std::io::Error::last_os_error(); let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); @@ -64,23 +34,107 @@ fn enable_sock_opt( Ok(()) } /// Add an async udp listener -pub async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { - let socket = UdpSocket::bind(sock).await?; +pub async fn bind_udp(addr: SocketAddr) -> ::std::io::Result { + // I know, kind of a mess. but I really wanted SO_REUSE{ADDR,PORT} and + // no-fragmenting stuff. + // I also did not want to load another library for this. + // feel free to simplify, + // especially if we can avoid libc and other libraries + // we currently use libc because it's a dependency of many other deps - use ::std::os::fd::AsRawFd; - let fd = socket.as_raw_fd(); - // can be useful later on for reloads - enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; - enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; + let fd: ::std::os::fd::RawFd = { + let domain = if addr.is_ipv6() { + ::libc::AF_INET6 + } else { + ::libc::AF_INET + }; + #[allow(unsafe_code)] + let tmp = unsafe { ::libc::socket(domain, ::libc::SOCK_DGRAM, 0) }; + let lasterr = ::std::io::Error::last_os_error(); + if tmp == -1 { + return Err(lasterr); + } + tmp.into() + }; + + if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1) { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } + if let Err(e) = enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1) { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } // We will do path MTU discovery by ourselves, // always set the "don't fragment" bit - if sock.is_ipv6() { - enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; + let res = if addr.is_ipv6() { + enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1) } else { // FIXME: linux only - enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)?; + enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO) + }; + if let Err(e) = res { + #[allow(unsafe_code)] + unsafe { + ::libc::close(fd); + } + return Err(e); + } + // manually convert rust SockAddr to C sockaddr + #[allow(unsafe_code, trivial_casts, trivial_numeric_casts)] + { + let bind_ret = match addr { + SocketAddr::V4(s4) => { + let ip4: u32 = (*s4.ip()).into(); + let bind_addr = ::libc::sockaddr_in { + sin_family: ::libc::AF_INET as u16, + sin_port: s4.port().to_be(), + sin_addr: ::libc::in_addr { s_addr: ip4 }, + sin_zero: [0; 8], + }; + unsafe { + let c_addr = + &bind_addr as *const _ as *const ::libc::sockaddr; + ::libc::bind(fd, c_addr, 16) + } + } + SocketAddr::V6(s6) => { + let ip6: [u8; 16] = (*s6.ip()).octets(); + let bind_addr = ::libc::sockaddr_in6 { + sin6_family: ::libc::AF_INET6 as u16, + sin6_port: s6.port().to_be(), + sin6_flowinfo: 0, + sin6_addr: ::libc::in6_addr { s6_addr: ip6 }, + sin6_scope_id: 0, + }; + unsafe { + let c_addr = + &bind_addr as *const _ as *const ::libc::sockaddr; + ::libc::bind(fd, c_addr, 24) + } + } + }; + let lasterr = ::std::io::Error::last_os_error(); + if bind_ret != 0 { + unsafe { + ::libc::close(fd); + } + return Err(lasterr); + } } - Ok(socket) + use ::std::os::fd::FromRawFd; + #[allow(unsafe_code)] + let std_sock = unsafe { ::std::net::UdpSocket::from_raw_fd(fd) }; + std_sock.set_nonblocking(true)?; + ::tracing::debug!("Listening udp sock: {}", std_sock.local_addr().unwrap()); + + Ok(UdpSocket::from_std(std_sock)?) } diff --git a/src/enc/asym.rs b/src/enc/asym.rs index 32c3aa2..47a3b5a 100644 --- a/src/enc/asym.rs +++ b/src/enc/asym.rs @@ -152,7 +152,7 @@ pub enum KeyExchangeKind { } impl KeyExchangeKind { /// The serialize length of the field - pub fn len() -> usize { + pub const fn len() -> usize { 1 } /// Build a new keypair for key exchange diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 9a01b98..3aeea7b 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -4,6 +4,8 @@ pub mod asym; mod errors; pub mod hkdf; pub mod sym; +#[cfg(test)] +mod tests; pub use errors::Error; diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 5728808..d4204e0 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -25,12 +25,11 @@ pub enum CipherKind { impl CipherKind { /// length of the serialized id for the cipher kind field - pub fn len() -> usize { + pub const fn len() -> usize { 1 } /// required length of the nonce pub fn nonce_len(&self) -> HeadLen { - // TODO: how the hell do I take this from ::chacha20poly1305? HeadLen(Nonce::len()) } /// required length of the key @@ -92,10 +91,7 @@ impl Cipher { } fn nonce_len(&self) -> HeadLen { match self { - Cipher::XChaCha20Poly1305(_) => { - // TODO: how the hell do I take this from ::chacha20poly1305? - HeadLen(::ring::aead::CHACHA20_POLY1305.nonce_len()) - } + Cipher::XChaCha20Poly1305(_) => HeadLen(Nonce::len()), } } fn tag_len(&self) -> TagLen { @@ -117,10 +113,13 @@ impl Cipher { aead::generic_array::GenericArray, AeadInPlace, }; let final_len: usize = { - // FIXME: check min data length - let (nonce_bytes, data_and_tag) = raw_data.split_at_mut(13); + if raw_data.len() <= self.overhead() { + return Err(Error::NotEnoughData(raw_data.len())); + } + let (nonce_bytes, data_and_tag) = + raw_data.split_at_mut(Nonce::len()); let (data_notag, tag_bytes) = data_and_tag.split_at_mut( - data_and_tag.len() + 1 + data_and_tag.len() - ::ring::aead::CHACHA20_POLY1305.tag_len(), ); let nonce = GenericArray::from_slice(nonce_bytes); @@ -172,10 +171,7 @@ impl Cipher { &mut data[Nonce::len()..data_len_notag], ) { Ok(tag) => { - data[data_len_notag..] - // add tag - //data.get_tag_slice() - .copy_from_slice(tag.as_slice()); + data[data_len_notag..].copy_from_slice(tag.as_slice()); Ok(()) } Err(_) => Err(Error::Encrypt), @@ -205,6 +201,10 @@ impl CipherRecv { pub fn nonce_len(&self) -> HeadLen { self.0.nonce_len() } + /// Get the length of the nonce for this cipher + pub fn tag_len(&self) -> TagLen { + self.0.tag_len() + } /// Decrypt a paket. Nonce and Tag are taken from the packet, /// while you need to provide AAD (Additional Authenticated Data) pub fn decrypt<'a>( @@ -285,7 +285,7 @@ struct NonceNum { #[repr(C)] pub union Nonce { num: NonceNum, - raw: [u8; 12], + raw: [u8; Self::len()], } impl ::core::fmt::Debug for Nonce { @@ -303,13 +303,17 @@ impl ::core::fmt::Debug for Nonce { impl Nonce { /// Generate a new random Nonce pub fn new(rand: &Random) -> Self { - let mut raw = [0; 12]; + let mut raw = [0; Self::len()]; rand.fill(&mut raw); Self { raw } } /// Length of this nonce in bytes pub const fn len() -> usize { - return 12; + // FIXME: was:12. xchacha20poly1305 requires 24. + // but we should change keys much earlier than that, and our + // nonces are not random, but sequential. + // we should change keys every 2^30 bytes to be sure (stream max window) + return 24; } /// Get reference to the nonce bytes pub fn as_bytes(&self) -> &[u8] { @@ -319,7 +323,7 @@ impl Nonce { } } /// Create Nonce from array - pub fn from_slice(raw: [u8; 12]) -> Self { + pub fn from_slice(raw: [u8; Self::len()]) -> Self { Self { raw } } /// Go to the next nonce @@ -336,6 +340,7 @@ impl Nonce { } /// Synchronize the mutex acess with a nonce for multithread safety +// TODO: remove mutex, not needed anymore #[derive(Debug)] pub struct NonceSync { nonce: ::std::sync::Mutex, diff --git a/src/enc/tests.rs b/src/enc/tests.rs new file mode 100644 index 0000000..ead07ee --- /dev/null +++ b/src/enc/tests.rs @@ -0,0 +1,135 @@ +use crate::{ + auth, + connection::{handshake::*, ID}, + enc::{self, asym::KeyID}, +}; + +#[test] +fn test_simple_encrypt_decrypt() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + let secret = enc::Secret::new_rand(&rand); + let secret2 = secret.clone(); + + let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand); + let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2); + + let mut data = Vec::new(); + let tot_len = cipher_recv.nonce_len().0 + 1234 + cipher_recv.tag_len().0; + data.resize(tot_len, 0); + rand.fill(&mut data); + data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]); + let last = data.len() - cipher_recv.tag_len().0; + data[last..].copy_from_slice(&[0; 16]); + let orig = data.clone(); + let raw_aad: [u8; 8] = [0, 1, 2, 3, 4, 5, 6, 7]; + let aad = enc::sym::AAD(&raw_aad[..]); + let aad2 = enc::sym::AAD(&raw_aad[..]); + if cipher_send.encrypt(aad, &mut data).is_err() { + assert!(false, "Encrypt failed"); + } + if cipher_recv.decrypt(aad2, &mut data).is_err() { + assert!(false, "Decrypt failed"); + } + data[..enc::sym::Nonce::len()].copy_from_slice(&[0; 24]); + let last = data.len() - cipher_recv.tag_len().0; + data[last..].copy_from_slice(&[0; 16]); + assert!(orig == data, "DIFFERENT!\n{:?}\n{:?}\n", orig, data); +} + +#[test] +fn test_encrypt_decrypt() { + let rand = enc::Random::new(); + let cipher = enc::sym::CipherKind::XChaCha20Poly1305; + let secret = enc::Secret::new_rand(&rand); + let secret2 = secret.clone(); + + let cipher_send = enc::sym::CipherSend::new(cipher, secret, &rand); + let cipher_recv = enc::sym::CipherRecv::new(cipher, secret2); + let nonce_len = cipher_recv.nonce_len(); + let tag_len = cipher_recv.tag_len(); + + let service_key = enc::Secret::new_rand(&rand); + + let data = dirsync::RespInner::ClearText(dirsync::RespData { + client_nonce: dirsync::Nonce::new(&rand), + id: ID::ID(::core::num::NonZeroU64::new(424242).unwrap()), + service_connection_id: ID::ID( + ::core::num::NonZeroU64::new(434343).unwrap(), + ), + service_key, + }); + + let resp = dirsync::Resp { + client_key_id: KeyID(4444), + data, + }; + let encrypt_from = resp.encrypted_offset(); + let encrypt_to = encrypt_from + resp.encrypted_length(nonce_len, tag_len); + + let h_resp = + Handshake::new(HandshakeData::DirSync(dirsync::DirSync::Resp(resp))); + + let mut bytes = Vec::::with_capacity( + h_resp.len(cipher.nonce_len(), cipher.tag_len()), + ); + bytes.resize(h_resp.len(cipher.nonce_len(), cipher.tag_len()), 0); + h_resp.serialize(cipher.nonce_len(), cipher.tag_len(), &mut bytes); + + let raw_aad: [u8; 7] = [0, 1, 2, 3, 4, 5, 6]; + let aad = enc::sym::AAD(&raw_aad[..]); + let aad2 = enc::sym::AAD(&raw_aad[..]); + + let pre_encrypt = bytes.clone(); + // encrypt + if cipher_send + .encrypt(aad, &mut bytes[encrypt_from..encrypt_to]) + .is_err() + { + assert!(false, "Encrypt failed"); + } + if cipher_recv + .decrypt(aad2, &mut bytes[encrypt_from..encrypt_to]) + .is_err() + { + assert!(false, "Decrypt failed"); + } + // make sure Nonce and Tag are 0 + bytes[encrypt_from..(encrypt_from + nonce_len.0)].copy_from_slice(&[0; 24]); + let tag_from = encrypt_to - tag_len.0; + bytes[tag_from..(tag_from + tag_len.0)].copy_from_slice(&[0; 16]); + assert!( + pre_encrypt == bytes, + "{}|{}=\n{:?}\n{:?}", + encrypt_from, + encrypt_to, + pre_encrypt, + bytes + ); + + // decrypt + + let mut deserialized = match Handshake::deserialize(&bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + assert!(false, "{}", e.to_string()); + return; + } + }; + // reparse + if let HandshakeData::DirSync(dirsync::DirSync::Resp(r_a)) = + &mut deserialized.data + { + let enc_start = r_a.encrypted_offset() + cipher.nonce_len().0; + if let Err(e) = r_a.data.deserialize_as_cleartext( + &bytes[enc_start..(bytes.len() - cipher.tag_len().0)], + ) { + assert!(false, "DirSync Resp Inner serialize: {}", e.to_string()); + } + }; + + assert!( + deserialized == h_resp, + "DirSync Resp (de)serialization not working", + ); +} diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 360b004..f352fc0 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -37,7 +37,7 @@ pub(crate) struct RawUdp { } pub(crate) struct ConnectInfo { - pub answer: oneshot::Sender>, + pub answer: oneshot::Sender, pub resolved: dnssec::Record, pub service_id: ServiceID, pub domain: Domain, @@ -57,14 +57,15 @@ pub(crate) enum WorkAnswer { } /// Actual worker implementation. -pub(crate) struct Worker { +#[allow(missing_debug_implementations)] +pub struct Worker { cfg: Config, thread_id: ThreadTracker, // PERF: rand uses syscalls. how to do that async? rand: Random, stop_working: crate::StopWorkingRecvCh, token_check: Option>>, - sockets: Vec, + sockets: Vec>, queue: ::async_channel::Receiver, queue_timeouts_recv: mpsc::UnboundedReceiver, queue_timeouts_send: mpsc::UnboundedSender, @@ -73,64 +74,18 @@ pub(crate) struct Worker { handshakes: HandshakeTracker, } +#[allow(unsafe_code)] +unsafe impl Send for Worker {} + impl Worker { - pub(crate) async fn new_and_loop( - cfg: Config, - thread_id: ThreadTracker, - stop_working: crate::StopWorkingRecvCh, - token_check: Option>>, - socket_addrs: Vec<::std::net::SocketAddr>, - queue: ::async_channel::Receiver, - ) -> ::std::io::Result<()> { - // TODO: get a channel to send back information, and send the error - let mut worker = Self::new( - cfg, - thread_id, - stop_working, - token_check, - socket_addrs, - queue, - ) - .await?; - worker.work_loop().await; - Ok(()) - } pub(crate) async fn new( mut cfg: Config, thread_id: ThreadTracker, stop_working: crate::StopWorkingRecvCh, token_check: Option>>, - socket_addrs: Vec<::std::net::SocketAddr>, + sockets: Vec>, queue: ::async_channel::Receiver, ) -> ::std::io::Result { - // bind all sockets again so that we can easily - // send without sharing resources - // in the future we will want to have a thread-local listener too, - // but before that we need ebpf to pin a connection to a thread - // directly from the kernel - let mut sock_set = ::tokio::task::JoinSet::new(); - socket_addrs.into_iter().for_each(|s_addr| { - sock_set.spawn(async move { - let socket = - connection::socket::bind_udp(s_addr.clone()).await?; - Ok(socket) - }); - }); - // make sure we either add all of them, or none - let mut sockets = Vec::with_capacity(cfg.listen.len()); - while let Some(join_res) = sock_set.join_next().await { - match join_res { - Ok(s_res) => match s_res { - Ok(sock) => sockets.push(sock), - Err(e) => { - ::tracing::error!("Can't rebind socket"); - return Err(e); - } - }, - Err(e) => return Err(e.into()), - } - } - let (queue_timeouts_send, queue_timeouts_recv) = mpsc::unbounded_channel(); let mut handshakes = HandshakeTracker::new( @@ -138,11 +93,24 @@ impl Worker { cfg.ciphers.clone(), cfg.key_exchanges.clone(), ); - let mut keys = Vec::new(); + let mut server_keys = Vec::new(); // make sure the keys are no longer in the config - ::core::mem::swap(&mut keys, &mut cfg.keys); - for k in keys.into_iter() { - handshakes.add_server(k.0, k.1); + ::core::mem::swap(&mut server_keys, &mut cfg.server_keys); + for k in server_keys.into_iter() { + if handshakes.add_server_key(k.id, k.priv_key).is_err() { + return Err(::std::io::Error::new( + ::std::io::ErrorKind::AlreadyExists, + "You can't use the same KeyID for multiple keys", + )); + } + } + for srv in cfg.servers.iter() { + if handshakes.add_server_domain(&srv.fqdn, &srv.keys).is_err() { + return Err(::std::io::Error::new( + ::std::io::ErrorKind::NotFound, + "Specified a KeyID that we don't have", + )); + } } Ok(Self { @@ -160,12 +128,15 @@ impl Worker { handshakes, }) } - pub(crate) async fn work_loop(&mut self) { + /// Continuously loop and process work as needed + pub async fn work_loop(&mut self) { 'mainloop: loop { let work = ::tokio::select! { tell_stopped = self.stop_working.recv() => { - let _ = tell_stopped.unwrap().send( + if let Ok(stop_ch) = tell_stopped { + let _ = stop_ch.send( crate::StopWorking::WorkerStopped).await; + } break; } maybe_timeout = self.queue.recv() => { @@ -302,6 +273,7 @@ impl Worker { } }; let hkdf; + if let PubKey::Exchange(srv_pub) = key.1 { let secret = match priv_key.key_exchange(exchange, srv_pub) { @@ -341,11 +313,13 @@ impl Worker { conn_info.service_id, service_conn_id, conn, + conn_info.answer, + key.0, ) { Ok((client_key_id, hshake)) => (client_key_id, hshake), - Err(_) => { + Err(answer) => { ::tracing::warn!("Too many client handshakes"); - let _ = conn_info.answer.send(Err( + let _ = answer.send(Err( handshake::Error::TooManyClientHandshakes .into(), )); @@ -363,7 +337,7 @@ impl Worker { let req_data = dirsync::ReqData { nonce: dirsync::Nonce::new(&self.rand), client_key_id, - id: auth_recv_id.0, + id: auth_recv_id.0, //FIXME: is zero auth: auth_info, }; let req = dirsync::Req { @@ -374,28 +348,50 @@ impl Worker { exchange_key: pub_key, data: dirsync::ReqInner::ClearText(req_data), }; - let mut raw = Vec::::with_capacity(req.len()); - req.serialize( + let encrypt_start = ID::len() + req.encrypted_offset(); + let encrypt_end = encrypt_start + + req.encrypted_length( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + ); + let h_req = Handshake::new(HandshakeData::DirSync( + DirSync::Req(req), + )); + use connection::{PacketData, ID}; + let packet = Packet { + id: ID::Handshake, + data: PacketData::Handshake(h_req), + }; + + let tot_len = packet.len( + cipher_selected.nonce_len(), + cipher_selected.tag_len(), + ); + let mut raw = Vec::::with_capacity(tot_len); + raw.resize(tot_len, 0); + packet.serialize( cipher_selected.nonce_len(), cipher_selected.tag_len(), &mut raw[..], ); // encrypt - let encrypt_start = req.encrypted_offset(); - let encrypt_end = encrypt_start + req.encrypted_length(); if let Err(e) = hshake.connection.cipher_send.encrypt( sym::AAD(&[]), &mut raw[encrypt_start..encrypt_end], ) { ::tracing::error!("Can't encrypt DirSync Request"); - let _ = conn_info.answer.send(Err(e.into())); + if let Some(client) = + self.handshakes.remove_client(client_key_id) + { + let _ = client.answer.send(Err(e.into())); + }; continue 'mainloop; } // send always from the first socket // FIXME: select based on routing table let sender = self.sockets[0].local_addr().unwrap(); - let dest = UdpServer(addr.as_sockaddr().unwrap()); + let dest = UdpClient(addr.as_sockaddr().unwrap()); // start the timeout right before sending the packet hshake.timeout = Some(::tokio::task::spawn_local( @@ -406,7 +402,7 @@ impl Worker { )); // send packet - self.send_packet(raw, UdpClient(sender), dest).await; + self.send_packet(raw, dest, UdpServer(sender)).await; continue 'mainloop; } @@ -435,17 +431,19 @@ impl Worker { /// Read and do stuff with the raw udp packet async fn recv(&mut self, mut udp: RawUdp) { if udp.packet.id.is_handshake() { - let handshake = match Handshake::deserialize(&udp.data[8..]) { + let handshake = match Handshake::deserialize( + &udp.data[connection::ID::len()..], + ) { Ok(handshake) => handshake, Err(e) => { - ::tracing::warn!("Handshake parsing: {}", e); + ::tracing::debug!("Handshake parsing: {}", e); return; } }; - let action = match self - .handshakes - .recv_handshake(handshake, &mut udp.data[8..]) - { + let action = match self.handshakes.recv_handshake( + handshake, + &mut udp.data[connection::ID::len()..], + ) { Ok(action) => action, Err(err) => { ::tracing::debug!("Handshake recv error {}", err); @@ -454,16 +452,6 @@ impl Worker { }; match action { HandshakeAction::AuthNeeded(authinfo) => { - let token_check = match self.token_check.as_ref() { - Some(token_check) => token_check, - None => { - ::tracing::error!( - "Authentication requested but \ - we have no token checker" - ); - return; - } - }; let req; if let HandshakeData::DirSync(DirSync::Req(r)) = authinfo.handshake.data @@ -477,25 +465,36 @@ impl Worker { let req_data = match req.data { ReqInner::ClearText(req_data) => req_data, _ => { - ::tracing::error!( - "token_check: expected ClearText" - ); + ::tracing::error!("AuthNeeded: expected ClearText"); + assert!(false, "AuthNeeded: unreachable"); return; } }; // FIXME: This part can take a while, // we should just spawn it probably - let is_authenticated = { - let tk_check = token_check.lock().await; - tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await + let maybe_auth_check = { + match &self.token_check { + None => { + if req_data.auth.user == auth::USERID_ANONYMOUS + { + Ok(true) + } else { + Ok(false) + } + } + Some(token_check) => { + let tk_check = token_check.lock().await; + tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await + } + } }; - let is_authenticated = match is_authenticated { + let is_authenticated = match maybe_auth_check { Ok(is_authenticated) => is_authenticated, Err(_) => { ::tracing::error!("error in token auth"); @@ -545,9 +544,9 @@ impl Worker { client_key_id: req_data.client_key_id, data: RespInner::ClearText(resp_data), }; - let offset_to_encrypt = resp.encrypted_offset(); + let encrypt_from = ID::len() + resp.encrypted_offset(); let encrypt_until = - offset_to_encrypt + resp.encrypted_length() + tag_len.0; + encrypt_from + resp.encrypted_length(head_len, tag_len); let resp_handshake = Handshake::new( HandshakeData::DirSync(DirSync::Resp(resp)), ); @@ -556,14 +555,15 @@ impl Worker { id: ID::new_handshake(), data: PacketData::Handshake(resp_handshake), }; - let mut raw_out = - Vec::::with_capacity(packet.len(head_len, tag_len)); + let tot_len = packet.len(head_len, tag_len); + let mut raw_out = Vec::::with_capacity(tot_len); + raw_out.resize(tot_len, 0); packet.serialize(head_len, tag_len, &mut raw_out); - if let Err(e) = auth_conn.cipher_send.encrypt( - aad, - &mut raw_out[offset_to_encrypt..encrypt_until], - ) { + if let Err(e) = auth_conn + .cipher_send + .encrypt(aad, &mut raw_out[encrypt_from..encrypt_until]) + { ::tracing::error!("can't encrypt: {:?}", e); return; } @@ -588,43 +588,46 @@ impl Worker { ::tracing::error!( "ClientConnect on non DS::Resp::ClearText" ); - return; + unreachable!(); } + let auth_srv_conn = IDSend(resp_data.id); let mut conn = cci.connection; - conn.id_send = IDSend(resp_data.id); + conn.id_send = auth_srv_conn; let id_recv = conn.id_recv; let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server if self.connections.track(conn.into()).is_err() { ::tracing::error!("Could not track new connection"); self.connections.remove(id_recv); + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); return; } - if cci.service_id == auth::SERVICEID_AUTH { - // the user asked a single connection - // to the authentication server, without any additional - // service. No more connections to setup - return; + if cci.service_id != auth::SERVICEID_AUTH { + // create and track the connection to the service + // SECURITY: xor with secrets + //FIXME: the Secret should be XORed with the client + // stored secret (if any) + let hkdf = Hkdf::new( + HkdfKind::Sha3, + cci.service_id.as_bytes(), + resp_data.service_key, + ); + let mut service_connection = Connection::new( + hkdf, + cipher, + connection::Role::Client, + &self.rand, + ); + service_connection.id_recv = cci.service_connection_id; + service_connection.id_send = + IDSend(resp_data.service_connection_id); + let _ = + self.connections.track(service_connection.into()); } - // create and track the connection to the service - // SECURITY: xor with secrets - //FIXME: the Secret should be XORed with the client stored - // secret (if any) - let hkdf = Hkdf::new( - HkdfKind::Sha3, - cci.service_id.as_bytes(), - resp_data.service_key, - ); - let mut service_connection = Connection::new( - hkdf, - cipher, - connection::Role::Client, - &self.rand, - ); - service_connection.id_recv = cci.service_connection_id; - service_connection.id_send = - IDSend(resp_data.service_connection_id); - let _ = self.connections.track(service_connection.into()); + let _ = + cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn))); } HandshakeAction::Nothing => {} }; @@ -644,11 +647,12 @@ impl Worker { Some(src_sock) => src_sock, None => { ::tracing::error!( - "Can't send packet: Server changed listening ip!" + "Can't send packet: Server changed listening ip{}!", + server.0 ); return; } }; - let _ = src_sock.send_to(&data, client.0).await; + let res = src_sock.send_to(&data, client.0).await; } } diff --git a/src/lib.rs b/src/lib.rs index d08ca2e..697fff8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,7 +30,7 @@ use crate::{ auth::{Domain, ServiceID, TokenChecker}, connection::{ handshake, - socket::{SocketList, UdpClient, UdpServer}, + socket::{SocketTracker, UdpClient, UdpServer}, AuthServerConnections, Packet, }, inner::{ @@ -86,7 +86,7 @@ pub struct Fenrir { /// library Configuration cfg: Config, /// listening udp sockets - sockets: SocketList, + sockets: Vec, /// DNSSEC resolver, with failovers dnssec: dnssec::Dnssec, /// Broadcast channel to tell workers to stop working @@ -100,9 +100,6 @@ pub struct Fenrir { // manner _thread_pool: Vec<::std::thread::JoinHandle<()>>, _thread_work: Arc>>, - // This can be different from cfg.listen since using port 0 will result - // in a random port assigned by the operative system - _listen_addrs: Vec<::std::net::SocketAddr>, } // TODO: graceful vs immediate stop @@ -127,16 +124,23 @@ impl Fenrir { } fn stop_sync( &mut self, - ) -> Option<(::tokio::sync::mpsc::Receiver, usize, usize)> - { - let listeners_num = self.sockets.list.len(); + ) -> Option<( + ::tokio::sync::mpsc::Receiver, + Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>, + usize, + )> { let workers_num = self._thread_work.len(); - if self.sockets.list.len() > 0 || self._thread_work.len() > 0 { + if self.sockets.len() > 0 || self._thread_work.len() > 0 { let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4); let _ = self.stop_working.send(ch_send); - let _ = self.sockets.rm_all(); + let mut old_listeners = Vec::with_capacity(self.sockets.len()); + ::core::mem::swap(&mut old_listeners, &mut self.sockets); self._thread_pool.clear(); - Some((ch_recv, listeners_num, workers_num)) + let listeners = old_listeners + .into_iter() + .map(|(_, joinable)| joinable) + .collect(); + Some((ch_recv, listeners, workers_num)) } else { None } @@ -144,9 +148,10 @@ impl Fenrir { async fn stop_wait( &mut self, mut ch: ::tokio::sync::mpsc::Receiver, - mut listeners_num: usize, + listeners: Vec<::tokio::task::JoinHandle<::std::io::Result<()>>>, mut workers_num: usize, ) { + let mut listeners_num = listeners.len(); while listeners_num > 0 && workers_num > 0 { match ch.recv().await { Some(stopped) => match stopped { @@ -158,6 +163,11 @@ impl Fenrir { _ => break, } } + for l in listeners.into_iter() { + if let Err(e) = l.await { + ::tracing::error!("Unclean shutdown of listener: {:?}", e); + } + } } /// Create a new Fenrir endpoint /// spawn threads pinned to cpus in our own way with tokio's runtime @@ -167,22 +177,32 @@ impl Fenrir { ) -> Result { let (sender, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; + // bind sockets early so we can change "port 0" (aka: random) + // in the config + let binded_sockets = Self::bind_sockets(&config).await?; + let socket_addrs = binded_sockets + .iter() + .map(|s| s.local_addr().unwrap()) + .collect(); + let cfg = { + let mut tmp = config.clone(); + tmp.listen = socket_addrs; + tmp + }; let mut endpoint = Self { - cfg: config.clone(), - sockets: SocketList::new(), + cfg, + sockets: Vec::with_capacity(config.listen.len()), dnssec, stop_working: sender, token_check: None, conn_auth_srv: Mutex::new(AuthServerConnections::new()), _thread_pool: Vec::new(), _thread_work: Arc::new(Vec::new()), - _listen_addrs: Vec::with_capacity(config.listen.len()), }; - endpoint.start_work_threads_pinned(tokio_rt).await?; - match endpoint.add_sockets().await { - Ok(addrs) => endpoint._listen_addrs = addrs, - Err(e) => return Err(e.into()), - } + endpoint + .start_work_threads_pinned(tokio_rt, binded_sockets.clone()) + .await?; + endpoint.run_listeners(binded_sockets).await?; Ok(endpoint) } /// Create a new Fenrir endpoint @@ -192,41 +212,39 @@ impl Fenrir { /// * make sure that the threads are pinned on the cpu pub async fn with_workers( config: &Config, - ) -> Result< - ( - Self, - Vec>>, - ), - Error, - > { + ) -> Result<(Self, Vec), Error> { let (stop_working, _) = ::tokio::sync::broadcast::channel(1); let dnssec = dnssec::Dnssec::new(&config.resolvers)?; - let cfg = config.clone(); - let sockets = SocketList::new(); - let conn_auth_srv = Mutex::new(AuthServerConnections::new()); - let thread_pool = Vec::new(); - let thread_work = Arc::new(Vec::new()); - let listen_addrs = Vec::with_capacity(config.listen.len()); + // bind sockets early so we can change "port 0" (aka: random) + // in the config + let binded_sockets = Self::bind_sockets(&config).await?; + let socket_addrs = binded_sockets + .iter() + .map(|s| s.local_addr().unwrap()) + .collect(); + let cfg = { + let mut tmp = config.clone(); + tmp.listen = socket_addrs; + tmp + }; let mut endpoint = Self { cfg, - sockets, + sockets: Vec::with_capacity(config.listen.len()), dnssec, stop_working: stop_working.clone(), token_check: None, - conn_auth_srv, - _thread_pool: thread_pool, - _thread_work: thread_work, - _listen_addrs: listen_addrs, + conn_auth_srv: Mutex::new(AuthServerConnections::new()), + _thread_pool: Vec::new(), + _thread_work: Arc::new(Vec::new()), }; let worker_num = config.threads.unwrap().get(); let mut workers = Vec::with_capacity(worker_num); for _ in 0..worker_num { - workers.push(endpoint.start_single_worker().await?); - } - match endpoint.add_sockets().await { - Ok(addrs) => endpoint._listen_addrs = addrs, - Err(e) => return Err(e.into()), + workers.push( + endpoint.start_single_worker(binded_sockets.clone()).await?, + ); } + endpoint.run_listeners(binded_sockets).await?; Ok((endpoint, workers)) } /// Returns the list of the actual addresses we are listening on @@ -234,57 +252,56 @@ impl Fenrir { /// if you specified UDP port 0 a random one has been assigned to you /// by the operating system. pub fn addresses(&self) -> Vec<::std::net::SocketAddr> { - self._listen_addrs.clone() + self.sockets.iter().map(|(s, _)| s.clone()).collect() } - // only call **after** starting all threads - /// Add all UDP sockets found in config - /// and start listening for packets - async fn add_sockets( - &mut self, - ) -> ::std::io::Result> { + // only call **before** starting all threads + /// bind all UDP sockets found in config + async fn bind_sockets(cfg: &Config) -> Result>, Error> { // try to bind multiple sockets in parallel let mut sock_set = ::tokio::task::JoinSet::new(); - self.cfg.listen.iter().for_each(|s_addr| { + cfg.listen.iter().for_each(|s_addr| { let socket_address = s_addr.clone(); - let stop_working = self.stop_working.subscribe(); - let th_work = self._thread_work.clone(); sock_set.spawn(async move { - let s = connection::socket::bind_udp(socket_address).await?; - let arc_s = Arc::new(s); - let join = ::tokio::spawn(Self::listen_udp( - stop_working, - th_work, - arc_s.clone(), - )); - Ok((arc_s, join)) + connection::socket::bind_udp(socket_address).await }); }); - - // make sure we either add all of them, or none - let mut all_socks = Vec::with_capacity(self.cfg.listen.len()); + // make sure we either return all of them, or none + let mut all_socks = Vec::with_capacity(cfg.listen.len()); while let Some(join_res) = sock_set.join_next().await { match join_res { Ok(s_res) => match s_res { Ok(s) => { - all_socks.push(s); + all_socks.push(Arc::new(s)); } Err(e) => { - return Err(e); + return Err(e.into()); } }, Err(e) => { - return Err(e.into()); + return Err(Error::Setup(e.to_string())); } } } - - let mut ret = Vec::with_capacity(self.cfg.listen.len()); - for (arc_s, join) in all_socks.into_iter() { - ret.push(arc_s.local_addr().unwrap()); - self.sockets.add_socket(arc_s, join).await; + assert!(all_socks.len() == cfg.listen.len(), "missing socks"); + Ok(all_socks) + } + // only call **after** starting all threads + /// spawn all listeners + async fn run_listeners( + &mut self, + socks: Vec>, + ) -> Result<(), Error> { + for sock in socks.into_iter() { + let sockaddr = sock.local_addr().unwrap(); + let stop_working = self.stop_working.subscribe(); + let th_work = self._thread_work.clone(); + let joinable = ::tokio::spawn(async move { + Self::listen_udp(stop_working, th_work, sock.clone()).await + }); + self.sockets.push((sockaddr, joinable)); } - Ok(ret) + Ok(()) } /// Run a dedicated loop to read packets on the listening socket @@ -301,12 +318,15 @@ impl Fenrir { let (bytes, sock_sender) = ::tokio::select! { tell_stopped = stop_working.recv() => { drop(socket); - let _ = tell_stopped.unwrap() - .send(StopWorking::ListenerStopped).await; + if let Ok(stop_ch) = tell_stopped { + let _ = stop_ch + .send(StopWorking::ListenerStopped).await; + } return Ok(()); } result = socket.recv_from(&mut buffer) => { - result? + let (bytes, from) = result?; + (bytes, UdpClient(from)) } }; let data: Vec = buffer[..bytes].to_vec(); @@ -324,17 +344,15 @@ impl Fenrir { use connection::packet::ConnectionID; match packet.id { ConnectionID::Handshake => { - let send_port = sock_sender.port() as u64; - ((send_port % queues_num) - 1) as usize - } - ConnectionID::ID(id) => { - ((id.get() % queues_num) - 1) as usize + let send_port = sock_sender.0.port() as u64; + (send_port % queues_num) as usize } + ConnectionID::ID(id) => (id.get() % queues_num) as usize, } }; let _ = work_queues[thread_idx] .send(Work::Recv(RawUdp { - src: UdpClient(sock_sender), + src: sock_sender, dst: sock_receiver, packet, data, @@ -431,7 +449,7 @@ impl Fenrir { .unwrap(); // and tell that thread to connect somewhere - let (send, recv) = ::tokio::sync::oneshot::channel(); + let (send, mut recv) = ::tokio::sync::oneshot::channel(); let _ = self._thread_work[thread_idx] .send(Work::Connect(ConnectInfo { answer: send, @@ -450,10 +468,15 @@ impl Fenrir { conn_auth_lock.remove_reserved(&resolved); Err(e) } - Ok((pubkey, id_send)) => { + Ok((key_id, id_send)) => { + let key = resolved + .public_keys + .iter() + .find(|k| k.0 == key_id) + .unwrap(); let mut conn_auth_lock = self.conn_auth_srv.lock().await; - conn_auth_lock.add(&pubkey, id_send, &resolved); + conn_auth_lock.add(&key.1, id_send, &resolved); //FIXME: user needs to somehow track the connection Ok(()) @@ -472,13 +495,11 @@ impl Fenrir { } } - // needs to be called before add_sockets + // needs to be called before run_listeners async fn start_single_worker( &mut self, - ) -> ::std::result::Result< - impl futures::Future>, - Error, - > { + socks: Vec>, + ) -> ::std::result::Result { let thread_idx = self._thread_work.len() as u16; let max_threads = self.cfg.threads.unwrap().get() as u16; if thread_idx >= max_threads { @@ -496,17 +517,18 @@ impl Fenrir { total: max_threads, }; let (work_send, work_recv) = ::async_channel::unbounded::(); - let worker = Worker::new_and_loop( + let worker = Worker::new( self.cfg.clone(), thread_id, self.stop_working.subscribe(), self.token_check.clone(), - self.cfg.listen.clone(), + socks, work_recv, - ); + ) + .await?; // don't keep around private keys too much if (thread_idx + 1) == max_threads { - self.cfg.keys.clear(); + self.cfg.server_keys.clear(); } loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { @@ -533,6 +555,7 @@ impl Fenrir { async fn start_work_threads_pinned( &mut self, tokio_rt: Arc<::tokio::runtime::Runtime>, + sockets: Vec>, ) -> ::std::result::Result<(), Error> { use ::std::sync::Mutex; let hw_topology = match ::hwloc2::Topology::new() { @@ -568,7 +591,7 @@ impl Fenrir { let (work_send, work_recv) = ::async_channel::unbounded::(); let th_stop_working = self.stop_working.subscribe(); let th_token_check = self.token_check.clone(); - let th_socket_addrs = self.cfg.listen.clone(); + let th_sockets = sockets.clone(); let thread_id = ThreadTracker { total: cores as u16, id: 1 + (core as u16), @@ -598,17 +621,22 @@ impl Fenrir { // finally run the main worker. // make sure things stay on this thread let tk_local = ::tokio::task::LocalSet::new(); - let _ = tk_local.block_on( - &th_tokio_rt, - Worker::new_and_loop( + let _ = tk_local.block_on(&th_tokio_rt, async move { + let mut worker = match Worker::new( th_config, thread_id, th_stop_working, th_token_check, - th_socket_addrs, + th_sockets, work_recv, - ), - ); + ) + .await + { + Ok(worker) => worker, + Err(_) => return, + }; + worker.work_loop().await + }); }); loop { let queues_lock = match Arc::get_mut(&mut self._thread_work) { @@ -627,7 +655,7 @@ impl Fenrir { self._thread_pool.push(join_handle); } // don't keep around private keys too much - self.cfg.keys.clear(); + self.cfg.server_keys.clear(); Ok(()) } } diff --git a/src/tests.rs b/src/tests.rs index 7d19b4f..acf57cc 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -21,23 +21,46 @@ async fn test_connection_dirsync() { cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap()); cfg }; + let test_domain: Domain = "example.com".into(); let cfg_server = { let mut cfg = cfg_client.clone(); - cfg.keys = [(KeyID(42), priv_exchange_key, pub_exchange_key)].to_vec(); + cfg.server_keys = [config::ServerKey { + id: KeyID(42), + priv_key: priv_exchange_key, + pub_key: pub_exchange_key, + }] + .to_vec(); + cfg.servers = [config::AuthServer { + fqdn: test_domain.clone(), + keys: [KeyID(42)].to_vec(), + }] + .to_vec(); cfg }; let (server, mut srv_workers) = Fenrir::with_workers(&cfg_server).await.unwrap(); - - let srv_worker = srv_workers.pop().unwrap(); - let local_thread = ::tokio::task::LocalSet::new(); - local_thread.spawn_local(async move { srv_worker.await }); - let (client, mut cli_workers) = Fenrir::with_workers(&cfg_client).await.unwrap(); - let cli_worker = cli_workers.pop().unwrap(); - local_thread.spawn_local(async move { cli_worker.await }); + let mut srv_worker = srv_workers.pop().unwrap(); + let mut cli_worker = cli_workers.pop().unwrap(); + + ::std::thread::spawn(move || { + let rt = ::tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + let local_thread = ::tokio::task::LocalSet::new(); + local_thread.spawn_local(async move { + srv_worker.work_loop().await; + }); + + local_thread.spawn_local(async move { + ::tokio::time::sleep(::std::time::Duration::from_millis(100)).await; + cli_worker.work_loop().await; + }); + rt.block_on(local_thread); + }); use crate::{ connection::handshake::HandshakeID, @@ -63,17 +86,17 @@ async fn test_connection_dirsync() { ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(), }; - server.graceful_stop().await; - client.graceful_stop().await; - return; + ::tokio::time::sleep(::std::time::Duration::from_millis(500)).await; + match client + .connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH) + .await + { + Ok(()) => {} + Err(e) => { + assert!(false, "Err on client connection: {:?}", e); + } + } - let _ = client - .connect_resolved( - dnssec_record, - &Domain("example.com".to_owned()), - auth::SERVICEID_AUTH, - ) - .await; server.graceful_stop().await; client.graceful_stop().await; } From 376e8fb83361d692016303cb1a41a565afaaee65 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sat, 17 Jun 2023 14:06:57 +0200 Subject: [PATCH 34/34] Remove some warnings Signed-off-by: Luca Fulchir --- src/connection/handshake/tracker.rs | 5 ----- src/dnssec/record.rs | 2 +- src/enc/hkdf.rs | 2 +- src/enc/mod.rs | 2 +- src/inner/worker.rs | 5 +---- src/lib.rs | 2 +- 6 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index 63ee31d..aaa1163 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -133,8 +133,6 @@ pub(crate) struct AuthNeededInfo { pub handshake: Handshake, /// hkdf generated from the handshake pub hkdf: Hkdf, - /// cipher to be used in both directions - pub cipher: CipherKind, } /// Client information needed to fully establish the conenction @@ -336,12 +334,9 @@ impl HandshakeTracker { } } - let cipher = req.cipher; - return Ok(HandshakeAction::AuthNeeded(AuthNeededInfo { handshake, hkdf, - cipher, })); } DirSync::Resp(resp) => { diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 9426b3e..a995cea 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -530,7 +530,7 @@ impl Record { bytes_parsed = bytes_parsed + pubkey_length; continue; } - Err(e) => { + Err(_) => { return Err(Error::UnsupportedData(bytes_parsed)); } }; diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index d81276f..e52e236 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -125,7 +125,7 @@ impl HkdfSha3 { let mut out: [u8; 32] = [0; 32]; #[allow(unsafe_code)] unsafe { - self.inner.hkdf.expand(context, &mut out); + let _ = self.inner.hkdf.expand(context, &mut out); } out.into() } diff --git a/src/enc/mod.rs b/src/enc/mod.rs index 3aeea7b..663c72d 100644 --- a/src/enc/mod.rs +++ b/src/enc/mod.rs @@ -27,7 +27,7 @@ impl Random { } /// Fill a buffer with randomness pub fn fill(&self, out: &mut [u8]) { - self.rnd.fill(out); + let _ = self.rnd.fill(out); } /// return the underlying ring SystemRandom pub fn ring_rnd(&self) -> &::ring::rand::SystemRandom { diff --git a/src/inner/worker.rs b/src/inner/worker.rs index f352fc0..083a9ac 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -52,9 +52,6 @@ pub(crate) enum Work { DropHandshake(KeyID), Recv(RawUdp), } -pub(crate) enum WorkAnswer { - UNUSED, -} /// Actual worker implementation. #[allow(missing_debug_implementations)] @@ -653,6 +650,6 @@ impl Worker { return; } }; - let res = src_sock.send_to(&data, client.0).await; + let _res = src_sock.send_to(&data, client.0).await; } } diff --git a/src/lib.rs b/src/lib.rs index 697fff8..fbca09a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -449,7 +449,7 @@ impl Fenrir { .unwrap(); // and tell that thread to connect somewhere - let (send, mut recv) = ::tokio::sync::oneshot::channel(); + let (send, recv) = ::tokio::sync::oneshot::channel(); let _ = self._thread_work[thread_idx] .send(Work::Connect(ConnectInfo { answer: send,