libFenrir/src/dnssec/record.rs
Luca Fulchir 376e8fb833
Remove some warnings
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
2023-06-17 14:06:57 +02:00

680 lines
22 KiB
Rust

//!
//! Structs and information to create/parse the _fenrir DNSSEC record
//!
//! Encoding and decoding in base85, RFC1924
//!
//! Basic encoding idea:
//! * 1 byte: divided in two 4-bit halves:
//! * high half: num of addresses
//! * low half: num of pubkeys
//! * 1 byte: divided in half:
//! * high half: number of key exchanges
//! * low half: number of HKDFs
//! * 1 byte: divided in half:
//! * high half: number of ciphers
//! * low half: unused, written as 0
//! [ # list of pubkeys (max: 15, the count is stored in 4 bits)
//! * 2 bytes: pubkey id, little endian
//! * 1 byte: pubkey length
//! * 1 byte: pubkey type
//! * Y bytes: pubkey
//! ]
//! [ # list of addresses
//! * 1 byte: bitfield (bit 0 is the most significant)
//! * 0..1 IP version: 0 = ipv4, 1 = ipv6
//! * 2..4 priority (for failover)
//! * 5..7 weight within the priority group
//! * 1 byte: divided in half:
//! * half: num of public key indexes
//! * half: num of handshake ids
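//! * 2 bytes: UDP port, little endian. 0 means no UDP port (Fenrir over IP)
//! * [ HALF byte per public key idx ] (index in the list of public keys,
//!   high half first; an odd count leaves the last low half at 0)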
//! * [ 1 byte per handshake id ]
//! * X bytes: IP
//! ]
//! [ # list of supported key exchanges
//! * 1 byte for each key exchange
//! ]
//! [ # list of supported HKDFs
//! * 1 byte for each hkdf
//! ]
//! [ # list of supported ciphers
//! * 1 byte for each cipher
//! ]
use crate::{
connection::handshake::HandshakeID,
enc::{
self,
asym::{KeyExchangeKind, KeyID, PubKey},
hkdf::HkdfKind,
sym::CipherKind,
},
};
use ::core::num::NonZeroU16;
use ::num_traits::FromPrimitive;
use ::std::{net::IpAddr, vec::Vec};
/*
* Public key data
*/
/// Index of a public key inside [`Record::public_keys`].
///
/// Addresses reference their public keys by position in the record's key
/// list instead of repeating the full 2-byte key ID: on the wire each index
/// only takes half a byte.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct PubKeyIdx(pub u8);
/*
* Address data
*/
/// Priority group of an address.
///
/// Addresses in `P1` are contacted first; each following group is only
/// used as the next failover step. On the wire this is a 3-bit field
/// holding the discriminant (`P1` = 0 ... `P8` = 7).
#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)]
// public enum: non_exhaustive forces users to add a default case,
// which lets us add more priorities later without breaking them
#[non_exhaustive]
#[repr(u8)]
pub enum AddressPriority {
    /// Contacted first
    P1 = 0,
    /// First failover
    P2 = 1,
    /// Second failover
    P3 = 2,
    /// Third failover
    P4 = 3,
    /// Fourth failover
    P5 = 4,
    /// Fifth failover
    P6 = 5,
    /// Sixth failover
    P7 = 6,
    /// Seventh failover
    P8 = 7,
}
impl TryFrom<&str> for AddressPriority {
    type Error = ::std::io::Error;
    /// Parse a human-readable, 1-based priority ("1" to "8").
    ///
    /// # Errors
    /// `InvalidData` when `raw` is not an integer between 1 and 8.
    fn try_from(raw: &str) -> Result<Self, Self::Error> {
        raw.parse::<u8>()
            .ok()
            // users count from 1, the enum discriminants from 0
            .and_then(|one_based| one_based.checked_sub(1))
            .and_then(AddressPriority::from_u8)
            .ok_or_else(|| {
                ::std::io::Error::new(
                    ::std::io::ErrorKind::InvalidData,
                    "Priority must be between 1 and 8",
                )
            })
    }
}
/// Weight of an address inside its priority group.
///
/// Used to spread traffic over multiple authentication servers:
/// * the client sums all weights in a group
/// * it generates a random number in [0..sum_of_weights]
/// * that number selects the server that takes the connection
///
/// Using the same weight for every address of a group distributes the
/// connections evenly. On the wire this is a 3-bit field holding the
/// discriminant, so weight `N` is stored as `N - 1`.
#[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)]
// public enum: non_exhaustive forces users to add a default case,
// which lets us add more weights later without breaking them
#[non_exhaustive]
#[repr(u8)]
pub enum AddressWeight {
    /// Minimum weight: 1
    W1 = 0,
    /// Light weight: 2
    W2 = 1,
    /// Light weight: 3
    W3 = 2,
    /// Medium weight: 4
    W4 = 3,
    /// Medium weight: 5
    W5 = 4,
    /// Heavy weight: 6
    W6 = 5,
    /// Heavy weight: 7
    W7 = 6,
    /// Maximum weight: 8
    W8 = 7,
}
impl TryFrom<&str> for AddressWeight {
    type Error = ::std::io::Error;
    /// Parse a human-readable, 1-based weight ("1" to "8").
    ///
    /// # Errors
    /// `InvalidData` when `raw` is not an integer between 1 and 8.
    fn try_from(raw: &str) -> Result<Self, Self::Error> {
        raw.parse::<u8>()
            .ok()
            // users count from 1, the enum discriminants from 0
            .and_then(|one_based| one_based.checked_sub(1))
            .and_then(AddressWeight::from_u8)
            .ok_or_else(|| {
                ::std::io::Error::new(
                    ::std::io::ErrorKind::InvalidData,
                    "Weight must be between 1 and 8",
                )
            })
    }
}
/// Address of an authentication server, as stored in the record:
/// * IP address and optional UDP port
/// * priority group, and weight inside that group
/// * handshake IDs supported at this address
/// * indexes into [`Record::public_keys`] of the keys used by this address
#[derive(Clone, Debug)]
pub struct Address {
    /// IP address of the server, either v4 or v6
    pub ip: IpAddr,
    /// UDP port. `None` means this address can only be reached by running
    /// Fenrir directly over IP
    pub port: Option<NonZeroU16>,
    /// Priority group this address belongs to
    pub priority: AddressPriority,
    /// Weight of this address inside its priority group
    pub weight: AddressWeight,
    /// Handshakes supported by this address
    pub handshake_ids: Vec<HandshakeID>,
    /// Indexes into [`Record::public_keys`] of the keys used by this address
    pub public_key_idx: Vec<PubKeyIdx>,
}
impl Address {
    /// Return this address as a socket address.
    ///
    /// Fenrir can also run directly on top of IP, without UDP ports: when
    /// `port` is `None` there is no socket address and `None` is returned.
    pub fn as_sockaddr(&self) -> Option<::std::net::SocketAddr> {
        self.port
            .map(|port| ::std::net::SocketAddr::new(self.ip, port.get()))
    }
    /// Number of bytes written by [`Address::serialize_into`]:
    /// 4 header bytes, one byte per pair of public key indexes (rounded up),
    /// one byte per handshake ID, then 4 (IPv4) or 16 (IPv6) bytes of IP.
    fn len(&self) -> usize {
        let num_pubkey_idx = self.public_key_idx.len();
        let idx_bytes = (num_pubkey_idx / 2) + (num_pubkey_idx % 2);
        let ip_bytes = match self.ip {
            IpAddr::V4(_) => 4,
            IpAddr::V6(_) => 16,
        };
        4 + idx_bytes + self.handshake_ids.len() + ip_bytes
    }
    /// Write this address into `raw`, which must be exactly
    /// [`Address::len`] bytes long.
    ///
    /// Layout:
    /// * byte 0: IP version (top 2 bits, 0 = IPv4, 1 = IPv6), priority
    ///   (next 3 bits), weight (low 3 bits)
    /// * byte 1: number of public key indexes (high nibble) and of
    ///   handshake IDs (low nibble)
    /// * bytes 2..4: UDP port, little endian, 0 when there is no port
    /// * public key indexes, two per byte, high nibble first; with an odd
    ///   count the last low nibble is 0
    /// * one byte per handshake ID
    /// * IP address octets
    ///
    /// The counts and the index values must fit in 4 bits. This is NOT
    /// checked here: larger values would corrupt neighbouring fields, so the
    /// caller has to validate them first.
    ///
    /// # Panics
    /// If `raw.len()` differs from [`Address::len`].
    fn serialize_into(&self, raw: &mut [u8]) {
        let mut bitfield: u8 = match self.ip {
            IpAddr::V4(_) => 0,
            IpAddr::V6(_) => 1,
        };
        bitfield <<= 3;
        bitfield |= self.priority as u8;
        bitfield <<= 3;
        bitfield |= self.weight as u8;
        raw[0] = bitfield;
        raw[1] = ((self.public_key_idx.len() as u8) << 4)
            | self.handshake_ids.len() as u8;
        // 0 is never a valid port, so it encodes "no UDP port"
        let port: u16 = self.port.map_or(0, NonZeroU16::get);
        raw[2..4].copy_from_slice(&port.to_le_bytes());
        let mut written: usize = 4;
        // pack two indexes per byte, since each index fits in 4 bits
        for chunk in self.public_key_idx.chunks(2) {
            let high = chunk[0].0 << 4;
            let low = chunk.get(1).map_or(0, |idx| idx.0);
            raw[written] = high | low;
            written += 1;
        }
        for id in self.handshake_ids.iter() {
            raw[written] = *id as u8;
            written += 1;
        }
        let next_written;
        match self.ip {
            IpAddr::V4(ip) => {
                next_written = written + 4;
                raw[written..next_written].copy_from_slice(&ip.octets());
            }
            IpAddr::V6(ip) => {
                next_written = written + 16;
                raw[written..next_written].copy_from_slice(&ip.octets());
            }
        };
        assert!(
            next_written == raw.len(),
            "write how much? {} - {}",
            next_written,
            raw.len()
        );
    }
    /// Parse one address from the start of `raw`.
    ///
    /// `raw` may continue past the end of this address. On success, returns
    /// the address together with the number of bytes it used.
    /// Unknown handshake IDs are logged and skipped, for forward
    /// compatibility.
    ///
    /// # Errors
    /// * [`Error::NotEnoughData`] if `raw` is shorter than the address
    ///   header declares
    /// * [`Error::UnsupportedData`] if the IP version field is neither IPv4
    ///   nor IPv6
    fn decode_raw(raw: &[u8]) -> Result<(Self, usize), Error> {
        // smallest possible address: 4-byte header + IPv4, with no public
        // key indexes and no handshake IDs
        const MIN_ADDRESS_LEN: usize = 4 + 4;
        if raw.len() < MIN_ADDRESS_LEN {
            return Err(Error::NotEnoughData(0));
        }
        // byte 0: ip version (2 bits), priority (3 bits), weight (3 bits)
        let (is_ipv6, ip_len): (bool, usize) = match raw[0] >> 6 {
            0 => (false, 4),
            1 => (true, 16),
            _ => return Err(Error::UnsupportedData(0)),
        };
        let priority = AddressPriority::from_u8((raw[0] >> 3) & 0x07)
            .expect("3-bit value, AddressPriority has 8 variants");
        let weight = AddressWeight::from_u8(raw[0] & 0x07)
            .expect("3-bit value, AddressWeight has 8 variants");
        // byte 1: number of public key indexes and of handshake ids
        let num_pubkey_idx = (raw[1] >> 4) as usize;
        let num_handshake_ids = (raw[1] & 0x0F) as usize;
        // bytes 2..4: UDP port, 0 means "no port"
        let port = NonZeroU16::new(u16::from_le_bytes([raw[2], raw[3]]));
        let mut bytes_parsed = 4;
        // public key indexes are packed two per byte
        let idx_bytes = (num_pubkey_idx / 2) + (num_pubkey_idx % 2);
        if raw.len() < bytes_parsed + idx_bytes + num_handshake_ids + ip_len {
            return Err(Error::NotEnoughData(3));
        }
        let mut public_key_idx = Vec::with_capacity(num_pubkey_idx);
        for pair in raw[bytes_parsed..(bytes_parsed + idx_bytes)].iter() {
            public_key_idx.push(PubKeyIdx(pair >> 4));
            // with an odd count the last low nibble is just padding
            if public_key_idx.len() < num_pubkey_idx {
                public_key_idx.push(PubKeyIdx(pair & 0x0F));
            }
        }
        bytes_parsed += idx_bytes;
        let mut handshake_ids = Vec::with_capacity(num_handshake_ids);
        for raw_handshake_id in
            raw[bytes_parsed..(bytes_parsed + num_handshake_ids)].iter()
        {
            match HandshakeID::from_u8(*raw_handshake_id) {
                Some(h_id) => handshake_ids.push(h_id),
                None => {
                    // ignore unsupported handshakes: they could come
                    // from a newer version of the protocol
                    ::tracing::warn!(
                        "Unsupported handshake {}. Upgrade?",
                        *raw_handshake_id
                    );
                }
            }
        }
        bytes_parsed += num_handshake_ids;
        // the length check above guarantees the whole IP is available
        let ip_end = bytes_parsed + ip_len;
        let raw_ip = &raw[bytes_parsed..ip_end];
        let ip = if is_ipv6 {
            let octets: [u8; 16] =
                raw_ip.try_into().expect("slice is exactly 16 bytes");
            IpAddr::from(octets)
        } else {
            let octets: [u8; 4] =
                raw_ip.try_into().expect("slice is exactly 4 bytes");
            IpAddr::from(octets)
        };
        bytes_parsed = ip_end;
        Ok((
            Self {
                ip,
                port,
                priority,
                weight,
                public_key_idx,
                handshake_ids,
            },
            bytes_parsed,
        ))
    }
}
/*
* Actual record, putting it all together
*/
/// All the information found in the DNSSEC record
#[derive(Debug, Clone)]
pub struct Record {
    /// Public keys used by any authentication server
    pub public_keys: Vec<(KeyID, PubKey)>,
    /// List of all authentication servers' addresses.
    /// Multiple ones can point to the same authentication server
    pub addresses: Vec<Address>,
    /// List of supported key exchanges
    pub key_exchanges: Vec<KeyExchangeKind>,
    /// List of supported HKDFs
    pub hkdfs: Vec<HkdfKind>,
    /// List of supported ciphers
    pub ciphers: Vec<CipherKind>,
}
impl Record {
/// Simply encode all the record in base85
pub fn encode(&self) -> Result<String, Error> {
// check possible failure scenarios
if self.public_keys.len() == 0 {
return Err(Error::NoPublicKeyFound);
}
if self.public_keys.len() > 16 {
return Err(Error::Max16PublicKeys);
}
if self.addresses.len() == 0 {
return Err(Error::NoAddressFound);
}
if self.addresses.len() > 16 {
return Err(Error::Max16Addresses);
}
if self.key_exchanges.len() > 16 {
return Err(Error::Max16KeyExchanges);
}
if self.hkdfs.len() > 16 {
return Err(Error::Max16Hkdfs);
}
if self.ciphers.len() > 16 {
return Err(Error::Max16Ciphers);
}
// everything else is all good
let total_size: usize = 3
+ self.addresses.iter().map(|a| a.len()).sum::<usize>()
+ self
.public_keys
.iter()
.map(|(_, key)| 3 + key.kind().pub_len())
.sum::<usize>()
+ self.key_exchanges.len()
+ self.hkdfs.len()
+ self.ciphers.len();
let mut raw = Vec::with_capacity(total_size);
raw.resize(total_size, 0);
// amount of data. addresses, then pubkeys. 4 bits each
let len_combined: u8 = self.addresses.len() as u8;
let len_combined = len_combined << 4;
let len_combined = len_combined | self.public_keys.len() as u8;
raw[0] = len_combined;
// number of key exchanges and hkdfs
let len_combined: u8 = self.key_exchanges.len() as u8;
let len_combined = len_combined << 4;
let len_combined = len_combined | self.hkdfs.len() as u8;
raw[1] = len_combined;
let num_of_ciphers: u8 = (self.ciphers.len() as u8) << 4;
raw[2] = num_of_ciphers;
let mut written: usize = 3;
for (public_key_id, public_key) in self.public_keys.iter() {
let key_id_bytes = public_key_id.0.to_le_bytes();
let written_next = written + KeyID::len();
raw[written..written_next].copy_from_slice(&key_id_bytes);
written = written_next;
raw[written] = public_key.len() as u8;
written = written + 1;
let written_next = written + public_key.len();
public_key.serialize_into(&mut raw[written..written_next]);
written = written_next;
}
for address in self.addresses.iter() {
let len = address.len();
let written_next = written + len;
address.serialize_into(&mut raw[written..written_next]);
written = written_next;
}
for k_x in self.key_exchanges.iter() {
raw[written] = *k_x as u8;
written = written + 1;
}
for h in self.hkdfs.iter() {
raw[written] = *h as u8;
written = written + 1;
}
for c in self.ciphers.iter() {
raw[written] = *c as u8;
written = written + 1;
}
Ok(::base85::encode(&raw))
}
/// Decode from base85 to the actual object
pub fn decode(raw: &[u8]) -> Result<Self, Error> {
// bare minimum for lengths, (1 address), (1 key),no cipher negotiation
const MIN_RAW_LENGTH: usize = 3 + (6 + 4) + (4 + 32);
if raw.len() <= MIN_RAW_LENGTH {
return Err(Error::NotEnoughData(0));
}
let mut num_addresses = (raw[0] >> 4) as usize;
let mut num_public_keys = (raw[0] & 0x0F) as usize;
let mut num_key_exchanges = (raw[1] >> 4) as usize;
let mut num_hkdfs = (raw[1] & 0x0F) as usize;
let mut num_ciphers = (raw[2] >> 4) as usize;
let mut bytes_parsed = 3;
let mut result = Self {
addresses: Vec::with_capacity(num_addresses),
public_keys: Vec::with_capacity(num_public_keys),
key_exchanges: Vec::with_capacity(num_key_exchanges),
hkdfs: Vec::with_capacity(num_hkdfs),
ciphers: Vec::with_capacity(num_ciphers),
};
while num_public_keys > 0 {
if bytes_parsed + 3 >= raw.len() {
return Err(Error::NotEnoughData(bytes_parsed));
}
let raw_key_id =
u16::from_le_bytes([raw[bytes_parsed], raw[bytes_parsed + 1]]);
let id = KeyID(raw_key_id);
bytes_parsed = bytes_parsed + KeyID::len();
let pubkey_length = raw[bytes_parsed] as usize;
bytes_parsed = bytes_parsed + 1;
let bytes_next_key = bytes_parsed + pubkey_length;
if bytes_next_key > raw.len() {
return Err(Error::NotEnoughData(bytes_parsed));
}
let (public_key, bytes) =
match PubKey::deserialize(&raw[bytes_parsed..bytes_next_key]) {
Ok((public_key, bytes)) => (public_key, bytes),
Err(enc::Error::UnsupportedKey(_)) => {
// continue parsing. This could be a new pubkey type
// that is not supported by an older client
bytes_parsed = bytes_parsed + pubkey_length;
continue;
}
Err(_) => {
return Err(Error::UnsupportedData(bytes_parsed));
}
};
if bytes != pubkey_length {
return Err(Error::UnsupportedData(bytes_parsed));
}
bytes_parsed = bytes_parsed + bytes;
result.public_keys.push((id, public_key));
num_public_keys = num_public_keys - 1;
}
while num_addresses > 0 {
let (address, bytes) =
match Address::decode_raw(&raw[bytes_parsed..]) {
Ok(address) => address,
Err(Error::UnsupportedData(b)) => {
return Err(Error::UnsupportedData(bytes_parsed + b))
}
Err(Error::NotEnoughData(b)) => {
return Err(Error::NotEnoughData(bytes_parsed + b))
}
Err(e) => return Err(e),
};
bytes_parsed = bytes_parsed + bytes;
result.addresses.push(address);
num_addresses = num_addresses - 1;
}
for addr in result.addresses.iter() {
for idx in addr.public_key_idx.iter() {
if idx.0 as usize >= result.public_keys.len() {
return Err(Error::Max16PublicKeys);
}
if !result.public_keys[idx.0 as usize]
.1
.kind()
.capabilities()
.has_exchange()
{
return Err(Error::UnsupportedData(bytes_parsed));
}
}
}
if bytes_parsed + num_key_exchanges + num_hkdfs + num_ciphers
!= raw.len()
{
return Err(Error::NotEnoughData(bytes_parsed));
}
while num_key_exchanges > 0 {
let key_exchange = match KeyExchangeKind::from_u8(raw[bytes_parsed])
{
Some(key_exchange) => key_exchange,
None => {
// continue parsing. This could be a new key exchange type
// that is not supported by an older client
::tracing::warn!(
"Unknown Key exchange {}. Ignoring",
raw[bytes_parsed]
);
bytes_parsed = bytes_parsed + 1;
continue;
}
};
bytes_parsed = bytes_parsed + 1;
result.key_exchanges.push(key_exchange);
num_key_exchanges = num_key_exchanges - 1;
}
while num_hkdfs > 0 {
let hkdf = match HkdfKind::from_u8(raw[bytes_parsed]) {
Some(hkdf) => hkdf,
None => {
// continue parsing. This could be a new hkdf type
// that is not supported by an older client
::tracing::warn!(
"Unknown hkdf {}. Ignoring",
raw[bytes_parsed]
);
bytes_parsed = bytes_parsed + 1;
continue;
}
};
bytes_parsed = bytes_parsed + 1;
result.hkdfs.push(hkdf);
num_hkdfs = num_hkdfs - 1;
}
while num_ciphers > 0 {
let cipher = match CipherKind::from_u8(raw[bytes_parsed]) {
Some(cipher) => cipher,
None => {
// continue parsing. This could be a new cipher type
// that is not supported by an older client
::tracing::warn!(
"Unknown Cipher {}. Ignoring",
raw[bytes_parsed]
);
bytes_parsed = bytes_parsed + 1;
continue;
}
};
bytes_parsed = bytes_parsed + 1;
result.ciphers.push(cipher);
num_ciphers = num_ciphers - 1;
}
if bytes_parsed != raw.len() {
Err(Error::UnknownData(bytes_parsed))
} else {
Ok(result)
}
}
}
/// Possible errors in encoding or decoding the DNSSEC record
#[derive(::thiserror::Error, Debug)]
pub enum Error {
/// General IO error
#[error("IO error: {0:?}")]
IO(#[from] ::std::io::Error),
/// We need at least one address
#[error("no addresses found")]
NoAddressFound,
/// Too many addresses (max 16)
#[error("can't encode more than 16 addresses")]
Max16Addresses,
/// Too many key exchanges (max 16)
#[error("can't encode more than 16 key exchanges")]
Max16KeyExchanges,
/// Too many Hkdfs (max 16)
#[error("can't encode more than 16 Hkdfs")]
Max16Hkdfs,
/// Too many ciphers (max 16)
#[error("can't encode more than 16 Ciphers")]
Max16Ciphers,
/// We need at least one public key
#[error("no public keys found")]
NoPublicKeyFound,
/// Maximum 16 public keys supported
#[error("can't encode more than 16 public keys")]
Max16PublicKeys,
/// Not enough data to decode something meaningful
#[error("not enough data. Parsed {0} bytes")]
NotEnoughData(usize),
/// Unsupported Data: can't parse
#[error("Unsupported data. Parsed {0} bytes")]
UnsupportedData(usize),
/// Unknown data at the end
#[error("Unknown data after {0} bytes")]
UnknownData(usize),
}