diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 2f8e3c9..7c1d733 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -15,7 +15,7 @@ use crate::{ dnssec, enc::{ asym::PubKey, - hkdf::HkdfSha3, + hkdf::Hkdf, sym::{CipherKind, CipherRecv, CipherSend}, Random, }, @@ -55,7 +55,7 @@ pub struct Connection { /// Sending Connection ID pub id_send: IDSend, /// The main hkdf used for all secrets in this connection - pub hkdf: HkdfSha3, + pub hkdf: Hkdf, /// Cipher for decrypting data pub cipher_recv: CipherRecv, /// Cipher for encrypting data @@ -76,7 +76,7 @@ pub enum Role { impl Connection { pub(crate) fn new( - hkdf: HkdfSha3, + hkdf: Hkdf, cipher: CipherKind, role: Role, rand: &Random, diff --git a/src/dnssec/record.rs b/src/dnssec/record.rs index 48e8edd..51f7fd3 100644 --- a/src/dnssec/record.rs +++ b/src/dnssec/record.rs @@ -3,27 +3,51 @@ //! //! Encoding and decoding in base85, RFC1924 //! -//! //! Basic encoding idea: -//! * 1 byte: half-bytes +//! * 1 byte: divided in two: //! * half: num of addresses //! * half: num of pubkeys +//! * 1 byte: divided in half: +//! * half: number of key exchanges +//! * half: number of Hkdfs +//! * 1 byte: divided in half: +//! * half: number of ciphers +//! * half: nothing //! [ # list of addresses //! * 1 byte: bitfield //! * 0..1 ipv4/ipv6 //! * 2..4 priority (for failover) //! * 5..7 weight between priority -//! * 1 byte: public key id +//! * 1 byte: divided in half: +//! * half: num of public key ids +//! * half: num of handhskae ids //! * 2 bytes: UDP port +//! * [ 1 byte per public key id ] +//! * [ 1 byte per handshake id ] //! * X bytes: IP //! ] //! [ # list of pubkeys //! * 1 byte: pubkey id +//! * 1 byte: pubkey length //! * 1 byte: pubkey type //! * Y bytes: pubkey //! ] +//! [ # list of supported key exchanges +//! * 1 byte for each cipher +//! ] +//! [ # list of supported HDKFs +//! * 1 byte for each hkdf +//! ] +//! [ # list of supported ciphers +//! * 1 byte for each cipher +//! ] -use crate::enc::{self, asym::PubKey}; +use crate::enc::{ + self, + asym::{KeyExchange, PubKey}, + hkdf::HkdfKind, + sym::CipherKind, +}; use ::core::num::NonZeroU16; use ::num_traits::FromPrimitive; use ::std::{net::IpAddr, vec::Vec}; @@ -220,6 +244,10 @@ impl Address { bitfield |= self.weight as u8; raw.push(bitfield); + let len_combined: u8 = self.public_key_ids.len() as u8; + let len_combined = len_combined << 4; + let len_combined = len_combined | self.handshake_ids.len() as u8; + raw.push(len_combined); raw.extend_from_slice( &(match self.port { @@ -228,10 +256,12 @@ impl Address { }), ); - raw.push(self.public_key_ids.len() as u8); for id in self.public_key_ids.iter() { raw.push(id.0); } + for id in self.handshake_ids.iter() { + raw.push(*id as u8); + } match self.ip { IpAddr::V4(ip) => { @@ -250,19 +280,11 @@ impl Address { } let ip_type = raw[0] >> 6; let is_ipv6: bool; - let total_length: usize; match ip_type { 0 => { is_ipv6 = false; - total_length = 8; - } - 1 => { - total_length = 20; - if raw.len() < total_length { - return Err(Error::NotEnoughData(1)); - } - is_ipv6 = true } + 1 => is_ipv6 = true, _ => return Err(Error::UnsupportedData(0)), } let raw_priority = (raw[0] << 2) >> 5; @@ -270,28 +292,33 @@ impl Address { let priority = AddressPriority::from_u8(raw_priority).unwrap(); let weight = AddressWeight::from_u8(raw_weight).unwrap(); + // UDP port let raw_port = u16::from_le_bytes([raw[1], raw[2]]); + let port = if raw_port == 0 { + None + } else { + Some(NonZeroU16::new(raw_port).unwrap()) + }; // Add publickey ids - let num_pubkey_ids = raw[3] as usize; - if raw.len() < 3 + num_pubkey_ids { + let num_pubkey_ids = (raw[3] >> 4) as usize; + let num_handshake_ids = (raw[3] & 0x0F) as usize; + if raw.len() <= 3 + num_pubkey_ids + num_handshake_ids { return Err(Error::NotEnoughData(3)); } + let mut bytes_parsed = 4; let mut public_key_ids = Vec::with_capacity(num_pubkey_ids); - - for raw_pubkey_id in raw[4..num_pubkey_ids].iter() { + for raw_pubkey_id in + raw[bytes_parsed..(bytes_parsed + num_pubkey_ids)].iter() + { public_key_ids.push(PublicKeyID(*raw_pubkey_id)); } // add handshake ids - let next_ptr = 3 + num_pubkey_ids; - let num_handshake_ids = raw[next_ptr] as usize; - if raw.len() < next_ptr + num_handshake_ids { - return Err(Error::NotEnoughData(next_ptr)); - } + bytes_parsed = bytes_parsed + num_pubkey_ids; let mut handshake_ids = Vec::with_capacity(num_handshake_ids); for raw_handshake_id in - raw[next_ptr..(next_ptr + num_pubkey_ids)].iter() + raw[bytes_parsed..(bytes_parsed + num_handshake_ids)].iter() { match HandshakeID::from_u8(*raw_handshake_id) { Some(h_id) => handshake_ids.push(h_id), @@ -304,26 +331,24 @@ impl Address { } } } - let next_ptr = next_ptr + num_pubkey_ids; + bytes_parsed = bytes_parsed + num_handshake_ids; - let port = if raw_port == 0 { - None - } else { - Some(NonZeroU16::new(raw_port).unwrap()) - }; let ip = if is_ipv6 { - let ip_end = next_ptr + 16; + let ip_end = bytes_parsed + 16; if raw.len() < ip_end { - return Err(Error::NotEnoughData(next_ptr)); + return Err(Error::NotEnoughData(bytes_parsed)); } - let raw_ip: [u8; 16] = raw[next_ptr..ip_end].try_into().unwrap(); + let raw_ip: [u8; 16] = + raw[bytes_parsed..ip_end].try_into().unwrap(); + bytes_parsed = bytes_parsed + 16; IpAddr::from(raw_ip) } else { - let ip_end = next_ptr + 4; + let ip_end = bytes_parsed + 4; if raw.len() < ip_end { - return Err(Error::NotEnoughData(next_ptr)); + return Err(Error::NotEnoughData(bytes_parsed)); } - let raw_ip: [u8; 4] = raw[next_ptr..ip_end].try_into().unwrap(); + let raw_ip: [u8; 4] = raw[bytes_parsed..ip_end].try_into().unwrap(); + bytes_parsed = bytes_parsed + 4; IpAddr::from(raw_ip) }; @@ -336,7 +361,7 @@ impl Address { public_key_ids, handshake_ids, }, - total_length, + bytes_parsed, )) } } @@ -353,6 +378,12 @@ pub struct Record { /// List of all authentication servers' addresses. /// Multiple ones can point to the same authentication server pub addresses: Vec
, + /// List of supported key exchanges + pub key_exchanges: Vec, + /// List of supported key exchanges + pub hkdfs: Vec, + /// List of supported ciphers + pub ciphers: Vec, } impl Record { @@ -371,15 +402,27 @@ impl Record { if self.addresses.len() > 16 { return Err(Error::Max16Addresses); } + if self.key_exchanges.len() > 16 { + return Err(Error::Max16KeyExchanges); + } + if self.hkdfs.len() > 16 { + return Err(Error::Max16Hkdfs); + } + if self.ciphers.len() > 16 { + return Err(Error::Max16Ciphers); + } // everything else is all good - let total_size: usize = 1 + let total_size: usize = 3 + self.addresses.iter().map(|a| a.raw_len()).sum::() + self .public_keys .iter() - .map(|(_, key)| 1 + key.kind().pub_len()) - .sum::(); + .map(|(_, key)| 3 + key.kind().pub_len()) + .sum::() + + self.key_exchanges.len() + + self.hkdfs.len() + + self.ciphers.len(); let mut raw = Vec::with_capacity(total_size); @@ -387,33 +430,56 @@ impl Record { let len_combined: u8 = self.addresses.len() as u8; let len_combined = len_combined << 4; let len_combined = len_combined | self.public_keys.len() as u8; - raw.push(len_combined); + // number of key exchanges and hkdfs + let len_combined: u8 = self.key_exchanges.len() as u8; + let len_combined = len_combined << 4; + let len_combined = len_combined | self.hkdfs.len() as u8; + raw.push(len_combined); + let num_of_ciphers: u8 = (self.ciphers.len() as u8) << 4; + raw.push(num_of_ciphers); for address in self.addresses.iter() { address.encode_into(&mut raw); } for (public_key_id, public_key) in self.public_keys.iter() { raw.push(public_key_id.0); + raw.push(public_key.kind().pub_len() as u8); + raw.push(public_key.kind() as u8); public_key.serialize_into(&mut raw); } + for k_x in self.key_exchanges.iter() { + raw.push(*k_x as u8); + } + for h in self.hkdfs.iter() { + raw.push(*h as u8); + } + for c in self.ciphers.iter() { + raw.push(*c as u8); + } Ok(::base85::encode(&raw)) } /// Decode from base85 to the actual object pub fn decode(raw: &[u8]) -> Result { - // bare minimum for 1 address and key - const MIN_RAW_LENGTH: usize = 1 + 8 + 8; + // bare minimum for 1 address, 1 key, 1 key exchange and 1 cipher + const MIN_RAW_LENGTH: usize = 1 + 1 + 1 + 8 + 9 + 1 + 1; if raw.len() <= MIN_RAW_LENGTH { return Err(Error::NotEnoughData(0)); } let mut num_addresses = (raw[0] >> 4) as usize; let mut num_public_keys = (raw[0] & 0x0F) as usize; - let mut bytes_parsed = 1; + let mut num_key_exchanges = (raw[1] >> 4) as usize; + let mut num_hkdfs = (raw[1] & 0x0F) as usize; + let mut num_ciphers = (raw[2] >> 4) as usize; + let mut bytes_parsed = 3; let mut result = Self { addresses: Vec::with_capacity(num_addresses), public_keys: Vec::with_capacity(num_public_keys), + key_exchanges: Vec::with_capacity(num_key_exchanges), + hkdfs: Vec::with_capacity(num_hkdfs), + ciphers: Vec::with_capacity(num_ciphers), }; while num_addresses > 0 { @@ -433,23 +499,97 @@ impl Record { num_addresses = num_addresses - 1; } while num_public_keys > 0 { + if bytes_parsed + 2 >= raw.len() { + return Err(Error::NotEnoughData(bytes_parsed)); + } let id = PublicKeyID(raw[bytes_parsed]); bytes_parsed = bytes_parsed + 1; - let (public_key, bytes) = - match PubKey::deserialize(&raw[bytes_parsed..]) { - Ok(public_key) => public_key, - Err(enc::Error::UnsupportedKey(b)) => { - return Err(Error::UnsupportedData(bytes_parsed + b)) - } - Err(enc::Error::NotEnoughData(b)) => { - return Err(Error::NotEnoughData(bytes_parsed + b)) - } - _ => return Err(Error::UnknownData(bytes_parsed)), - }; + let pubkey_length = raw[bytes_parsed] as usize; + bytes_parsed = bytes_parsed + 1; + if pubkey_length + bytes_parsed >= raw.len() { + return Err(Error::NotEnoughData(bytes_parsed)); + } + let (public_key, bytes) = match PubKey::deserialize( + &raw[bytes_parsed..(bytes_parsed + pubkey_length)], + ) { + Ok(public_key_and_bytes) => public_key_and_bytes, + Err(enc::Error::UnsupportedKey(_)) => { + // continue parsing. This could be a new pubkey type + // that is not supported by an older client + ::tracing::warn!("Unsupported public key type"); + bytes_parsed = bytes_parsed + pubkey_length; + continue; + } + Err(_) => { + return Err(Error::UnsupportedData(bytes_parsed)); + } + }; + if bytes != 1 + pubkey_length { + return Err(Error::UnsupportedData(bytes_parsed)); + } bytes_parsed = bytes_parsed + bytes; result.public_keys.push((id, public_key)); num_public_keys = num_public_keys - 1; } + if bytes_parsed + num_key_exchanges + num_hkdfs + num_ciphers + != raw.len() + { + return Err(Error::NotEnoughData(bytes_parsed)); + } + while num_key_exchanges > 0 { + let key_exchange = match KeyExchange::from_u8(raw[bytes_parsed]) { + Some(key_exchange) => key_exchange, + None => { + // continue parsing. This could be a new key exchange type + // that is not supported by an older client + ::tracing::warn!( + "Unknown Key exchange {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.key_exchanges.push(key_exchange); + num_key_exchanges = num_key_exchanges - 1; + } + while num_hkdfs > 0 { + let hkdf = match HkdfKind::from_u8(raw[bytes_parsed]) { + Some(hkdf) => hkdf, + None => { + // continue parsing. This could be a new hkdf type + // that is not supported by an older client + ::tracing::warn!( + "Unknown hkdf {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.hkdfs.push(hkdf); + num_hkdfs = num_hkdfs - 1; + } + while num_ciphers > 0 { + let cipher = match CipherKind::from_u8(raw[bytes_parsed]) { + Some(cipher) => cipher, + None => { + // continue parsing. This could be a new cipher type + // that is not supported by an older client + ::tracing::warn!( + "Unknown Cipher {}. Ignoring", + raw[bytes_parsed] + ); + bytes_parsed = bytes_parsed + 1; + continue; + } + }; + bytes_parsed = bytes_parsed + 1; + result.ciphers.push(cipher); + num_ciphers = num_ciphers - 1; + } if bytes_parsed != raw.len() { Err(Error::UnknownData(bytes_parsed)) } else { @@ -470,6 +610,15 @@ pub enum Error { /// Too many addresses (max 16) #[error("can't encode more than 16 addresses")] Max16Addresses, + /// Too many key exchanges (max 16) + #[error("can't encode more than 16 key exchanges")] + Max16KeyExchanges, + /// Too many Hkdfs (max 16) + #[error("can't encode more than 16 Hkdfs")] + Max16Hkdfs, + /// Too many ciphers (max 16) + #[error("can't encode more than 16 Ciphers")] + Max16Ciphers, /// We need at least one public key #[error("no public keys found")] NoPublicKeyFound, diff --git a/src/enc/hkdf.rs b/src/enc/hkdf.rs index 15d7eca..a4b6868 100644 --- a/src/enc/hkdf.rs +++ b/src/enc/hkdf.rs @@ -1,12 +1,52 @@ //! Hash-based Key Derivation Function //! We just repackage other crates -use ::hkdf::Hkdf; use ::sha3::Sha3_256; use ::zeroize::Zeroize; use crate::enc::sym::Secret; +/// Kind of HKDF +#[derive(Debug, Copy, Clone, PartialEq, ::num_derive::FromPrimitive)] +#[non_exhaustive] +#[repr(u8)] +pub enum HkdfKind { + /// Sha3 + Sha3 = 0, +} + +/// Generic wrapper on Hkdfs +#[derive(Clone)] +pub enum Hkdf { + /// Sha3 based + Sha3(HkdfSha3), +} + +// Fake debug implementation to avoid leaking secrets +impl ::core::fmt::Debug for Hkdf { + fn fmt( + &self, + f: &mut core::fmt::Formatter<'_>, + ) -> Result<(), ::std::fmt::Error> { + ::core::fmt::Debug::fmt("[hidden hkdf]", f) + } +} + +impl Hkdf { + /// New Hkdf + pub fn new(kind: HkdfKind, salt: &[u8], key: Secret) -> Self { + match kind { + HkdfKind::Sha3 => Self::Sha3(HkdfSha3::new(salt, key)), + } + } + /// Get a secret generated from the key and a given context + pub fn get_secret(&self, context: &[u8]) -> Secret { + match self { + Hkdf::Sha3(sha3) => sha3.get_secret(context), + } + } +} + // Hack & tricks: // HKDF are pretty important, but this lib don't zero out the data. // we can't use #[derive(Zeroing)] either. @@ -14,10 +54,10 @@ use crate::enc::sym::Secret; #[derive(Zeroize)] #[zeroize(drop)] -struct Zeroable([u8; ::core::mem::size_of::>()]); +struct Zeroable([u8; ::core::mem::size_of::<::hkdf::Hkdf>()]); union HkdfInner { - hkdf: ::core::mem::ManuallyDrop>, + hkdf: ::core::mem::ManuallyDrop<::hkdf::Hkdf>, zeroable: ::core::mem::ManuallyDrop, } @@ -50,7 +90,7 @@ pub struct HkdfSha3 { impl HkdfSha3 { /// Instantiate a new HKDF with Sha3-256 pub fn new(salt: &[u8], key: Secret) -> Self { - let hkdf = Hkdf::::new(Some(salt), key.as_ref()); + let hkdf = ::hkdf::Hkdf::::new(Some(salt), key.as_ref()); Self { inner: HkdfInner { hkdf: ::core::mem::ManuallyDrop::new(hkdf), diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 9ee6942..b782258 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -13,7 +13,7 @@ use crate::{ }, enc::{ self, asym, - hkdf::HkdfSha3, + hkdf::{Hkdf, HkdfKind}, sym::{CipherKind, CipherRecv}, }, Error, @@ -26,7 +26,7 @@ pub(crate) struct AuthNeededInfo { /// Parsed handshake packet pub handshake: Handshake, /// hkdf generated from the handshake - pub hkdf: HkdfSha3, + pub hkdf: Hkdf, /// cipher to be used in both directions pub cipher: CipherKind, } @@ -149,7 +149,7 @@ impl HandshakeTracker { Ok(shared_key) => shared_key, Err(e) => return Err(handshake::Error::Key(e).into()), }; - let hkdf = HkdfSha3::new(b"fenrir", shared_key); + let hkdf = Hkdf::new(HkdfKind::Sha3, b"fenrir", shared_key); let secret_recv = hkdf.get_secret(b"to_server"); let cipher_recv = CipherRecv::new(req.cipher, secret_recv); use crate::enc::sym::AAD; diff --git a/src/inner/worker.rs b/src/inner/worker.rs index fcc69ea..1e34b56 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,12 @@ use crate::{ ConnList, Connection, IDSend, Packet, }, dnssec, - enc::{asym::PubKey, hkdf::HkdfSha3, sym::Secret, Random}, + enc::{ + asym::PubKey, + hkdf::{Hkdf, HkdfKind}, + sym::Secret, + Random, + }, inner::{HandshakeAction, HandshakeTracker, ThreadTracker}, }; use ::std::{rc::Rc, sync::Arc, vec::Vec}; @@ -381,7 +386,8 @@ impl Worker { // SECURITY: //FIXME: the Secret should be XORed with the client stored // secret (if any) - let hkdf = HkdfSha3::new( + let hkdf = Hkdf::new( + HkdfKind::Sha3, cci.service_id.as_bytes(), resp_data.service_key, );