//! Handhsake handling use crate::{ auth::{Domain, ServiceID}, connection::{ handshake::{self, Error, Handshake}, Conn, IDRecv, IDSend, }, enc::{ self, asym::{self, KeyID, PrivKey, PubKey}, hkdf::{self, Hkdf}, sym::{self, CipherRecv}, }, inner::ThreadTracker, }; use ::tokio::sync::oneshot; pub(crate) struct Server { pub id: KeyID, pub key: PrivKey, pub domains: Vec, } pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>; pub(crate) struct Client { pub service_id: ServiceID, pub service_conn_id: IDRecv, pub connection: Conn, pub timeout: Option<::tokio::task::JoinHandle<()>>, pub answer: oneshot::Sender, pub srv_key_id: KeyID, } /// Tracks the keys used by the client and the handshake /// they are associated with pub(crate) struct ClientList { used: Vec<::bitmaps::Bitmap<1024>>, // index = KeyID keys: Vec>, list: Vec>, } impl ClientList { pub(crate) fn new() -> Self { Self { used: [::bitmaps::Bitmap::<1024>::new()].to_vec(), keys: Vec::with_capacity(16), list: Vec::with_capacity(16), } } pub(crate) fn get(&self, id: KeyID) -> Option<&Client> { if id.0 as usize >= self.list.len() { return None; } self.list[id.0 as usize].as_ref() } pub(crate) fn remove(&mut self, id: KeyID) -> Option { if id.0 as usize >= self.list.len() { return None; } let used_vec_idx = id.0 as usize / 1024; let used_bitmap_idx = id.0 as usize % 1024; let used_iter = match self.used.get_mut(used_vec_idx) { Some(used_iter) => used_iter, None => return None, }; used_iter.set(used_bitmap_idx, false); self.keys[id.0 as usize] = None; let mut owned = None; ::core::mem::swap(&mut self.list[id.0 as usize], &mut owned); owned } pub(crate) fn add( &mut self, priv_key: PrivKey, pub_key: PubKey, service_id: ServiceID, service_conn_id: IDRecv, connection: Conn, answer: oneshot::Sender, srv_key_id: KeyID, ) -> Result<(KeyID, &mut Client), oneshot::Sender> { let maybe_free_key_idx = self.used.iter().enumerate().find_map(|(idx, bmap)| { match bmap.first_false_index() { Some(false_idx) => Some(((idx * 1024), false_idx)), None => None, } }); let free_key_idx = match maybe_free_key_idx { Some((idx, false_idx)) => { let free_key_idx = idx * 1024 + false_idx; if free_key_idx > KeyID::MAX as usize { return Err(answer); } self.used[idx].set(false_idx, true); free_key_idx } None => { let mut bmap = ::bitmaps::Bitmap::<1024>::new(); bmap.set(0, true); self.used.push(bmap); self.used.len() * 1024 } }; if self.keys.len() >= free_key_idx { self.keys.push(None); self.list.push(None); } self.keys[free_key_idx] = Some((priv_key, pub_key)); self.list[free_key_idx] = Some(Client { service_id, service_conn_id, connection, timeout: None, answer, srv_key_id, }); Ok(( KeyID(free_key_idx as u16), self.list[free_key_idx].as_mut().unwrap(), )) } } /// Information needed to reply after the key exchange #[derive(Debug, Clone)] pub(crate) struct AuthNeededInfo { /// Parsed handshake packet pub handshake: Handshake, /// hkdf generated from the handshake pub hkdf: Hkdf, } /// Client information needed to fully establish the conenction #[derive(Debug)] pub(crate) struct ClientConnectInfo { /// The service ID that we are connecting to pub service_id: ServiceID, /// The service ID that we are connecting to pub service_connection_id: IDRecv, /// Parsed handshake packet pub handshake: Handshake, /// Conn pub connection: Conn, /// where to wake up the waiting client pub answer: oneshot::Sender, /// server public key id that we used on the handshake pub srv_key_id: KeyID, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] pub(crate) enum Action { /// Parsing finished, all ok, nothing to do Nothing, /// Packet parsed, now go perform authentication AuthNeeded(AuthNeededInfo), /// the client can fully establish a connection with this info ClientConnect(ClientConnectInfo), } /// Tracking of handhsakes and conenctions /// Note that we have multiple Handshake trackers, pinned to different cores /// Each of them will handle a subset of all handshakes. /// Each handshake is routed to a different tracker by checking /// core = (udp_src_sender_port % total_threads) - 1 pub(crate) struct Tracker { thread_id: ThreadTracker, key_exchanges: Vec, ciphers: Vec, /// ephemeral keys used server side in key exchange keys_srv: Vec, /// ephemeral keys used client side in key exchange hshake_cli: ClientList, } impl Tracker { pub(crate) fn new( thread_id: ThreadTracker, ciphers: Vec, key_exchanges: Vec, ) -> Self { Self { thread_id, ciphers, key_exchanges, keys_srv: Vec::new(), hshake_cli: ClientList::new(), } } pub(crate) fn add_server_key( &mut self, id: KeyID, key: PrivKey, ) -> Result<(), ()> { if self.keys_srv.iter().find(|&k| k.id == id).is_some() { return Err(()); } self.keys_srv.push(Server { id, key, domains: Vec::new(), }); self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0)); Ok(()) } pub(crate) fn add_server_domain( &mut self, domain: &Domain, key_ids: &[KeyID], ) -> Result<(), ()> { // check that all the key ids are present for id in key_ids.iter() { if self.keys_srv.iter().find(|k| k.id == *id).is_none() { return Err(()); } } // add the domain to those keys for id in key_ids.iter() { if let Some(srv) = self.keys_srv.iter_mut().find(|k| k.id == *id) { srv.domains.push(domain.clone()); } } Ok(()) } pub(crate) fn add_client( &mut self, priv_key: PrivKey, pub_key: PubKey, service_id: ServiceID, service_conn_id: IDRecv, connection: Conn, answer: oneshot::Sender, srv_key_id: KeyID, ) -> Result<(KeyID, &mut Client), oneshot::Sender> { self.hshake_cli.add( priv_key, pub_key, service_id, service_conn_id, connection, answer, srv_key_id, ) } pub(crate) fn remove_client(&mut self, key_id: KeyID) -> Option { self.hshake_cli.remove(key_id) } pub(crate) fn timeout_client( &mut self, key_id: KeyID, ) -> Option<[IDRecv; 2]> { if let Some(hshake) = self.hshake_cli.remove(key_id) { let _ = hshake.answer.send(Err(Error::Timeout.into())); Some([hshake.connection.id_recv, hshake.service_conn_id]) } else { None } } pub(crate) fn recv_handshake( &mut self, mut handshake: Handshake, handshake_raw: &mut [u8], ) -> Result { use handshake::dirsync::DirSync; match handshake.data { handshake::Data::DirSync(ref mut ds) => match ds { DirSync::Req(ref mut req) => { if !self.key_exchanges.contains(&req.exchange) { return Err(enc::Error::UnsupportedKeyExchange.into()); } if !self.ciphers.contains(&req.cipher) { return Err(enc::Error::UnsupportedCipher.into()); } let has_key = self.keys_srv.iter().find(|k| { if k.id == req.key_id { // Directory synchronized can only use keys // for key exchange, not signing keys if let PrivKey::Exchange(_) = k.key { return true; } } false }); let ephemeral_key; match has_key { Some(s_k) => { if let PrivKey::Exchange(ref k) = &s_k.key { ephemeral_key = k; } else { unreachable!(); } } None => { return Err(handshake::Error::UnknownKeyID.into()) } } let shared_key = match ephemeral_key .key_exchange(req.exchange, req.exchange_key) { Ok(shared_key) => shared_key, Err(e) => return Err(handshake::Error::Key(e).into()), }; let hkdf = Hkdf::new(hkdf::Kind::Sha3, b"fenrir", shared_key); let secret_recv = hkdf.get_secret(b"to_server"); let cipher_recv = CipherRecv::new(req.cipher, secret_recv); use crate::enc::sym::AAD; let aad = AAD(&mut []); // no aad for now let encrypt_from = req.encrypted_offset(); let encrypt_to = encrypt_from + req.encrypted_length( cipher_recv.nonce_len(), cipher_recv.tag_len(), ); match cipher_recv.decrypt( aad, &mut handshake_raw[encrypt_from..encrypt_to], ) { Ok(cleartext) => { req.data.deserialize_as_cleartext(cleartext)?; } Err(e) => { return Err(handshake::Error::Key(e).into()); } } return Ok(Action::AuthNeeded(AuthNeededInfo { handshake, hkdf, })); } DirSync::Resp(resp) => { let hshake = match self.hshake_cli.get(resp.client_key_id) { Some(hshake) => hshake, None => { ::tracing::debug!( "No such client key id: {:?}", resp.client_key_id ); return Err(handshake::Error::UnknownKeyID.into()); } }; let cipher_recv = &hshake.connection.cipher_recv; use crate::enc::sym::AAD; // no aad for now let aad = AAD(&mut []); let data_from = resp.encrypted_offset(); let data_to = data_from + resp.encrypted_length( cipher_recv.nonce_len(), cipher_recv.tag_len(), ); let mut raw_data = &mut handshake_raw[data_from..data_to]; match cipher_recv.decrypt(aad, &mut raw_data) { Ok(cleartext) => { resp.data.deserialize_as_cleartext(&cleartext)?; } Err(e) => { return Err(handshake::Error::Key(e).into()); } } let hshake = self.hshake_cli.remove(resp.client_key_id).unwrap(); if let Some(timeout) = hshake.timeout { timeout.abort(); } return Ok(Action::ClientConnect(ClientConnectInfo { service_id: hshake.service_id, service_connection_id: hshake.service_conn_id, handshake, connection: hshake.connection, answer: hshake.answer, srv_key_id: hshake.srv_key_id, })); } }, } } }