From 28cbe2ae20b6b5ce6a1ede8c45df4b8922292277 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Mon, 22 May 2023 15:05:17 +0200 Subject: [PATCH] more refactoring Signed-off-by: Luca Fulchir --- src/connection/handshake/mod.rs | 7 +- src/connection/mod.rs | 20 ++- src/connection/socket.rs | 42 +++++ src/inner/mod.rs | 17 +- src/lib.rs | 299 ++++++++++++++------------------ 5 files changed, 200 insertions(+), 185 deletions(-) diff --git a/src/connection/handshake/mod.rs b/src/connection/handshake/mod.rs index 74d3be1..99eab75 100644 --- a/src/connection/handshake/mod.rs +++ b/src/connection/handshake/mod.rs @@ -2,12 +2,12 @@ pub mod dirsync; -use ::num_traits::FromPrimitive; - use crate::{ connection::{self, ProtocolVersion}, enc::sym::{HeadLen, TagLen}, }; +use ::num_traits::FromPrimitive; +use ::std::sync::Arc; /// Handshake errors #[derive(::thiserror::Error, Debug, Copy, Clone)] @@ -36,8 +36,7 @@ pub(crate) struct HandshakeServer { pub(crate) struct HandshakeClient { pub id: crate::enc::asym::KeyID, pub key: crate::enc::asym::PrivKey, - pub hkdf: crate::enc::hkdf::HkdfSha3, - pub cipher: crate::enc::sym::CipherKind, + pub connection: Arc, } /// Parsed handshake diff --git a/src/connection/mod.rs b/src/connection/mod.rs index b247cbc..3c74c94 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -16,6 +16,13 @@ use crate::enc::{ sym::{CipherKind, CipherRecv, CipherSend}, }; +/// strong typedef for receiving connection id +#[derive(Debug, Copy, Clone)] +pub struct IDRecv(pub ID); +/// strong typedef for sending connection id +#[derive(Debug, Copy, Clone)] +pub struct IDSend(pub ID); + /// Version of the fenrir protocol in use #[derive(::num_derive::FromPrimitive, Debug, Copy, Clone)] #[repr(u8)] @@ -37,8 +44,10 @@ impl ProtocolVersion { /// A single connection and its data #[derive(Debug)] pub struct Connection { - /// Connection ID - pub id: ID, + /// Receiving Connection ID + pub id_recv: IDRecv, + /// Sending Connection ID + pub id_send: IDSend, /// The main hkdf used for all secrets in this connection pub hkdf: HkdfSha3, /// Cipher for decrypting data @@ -78,7 +87,8 @@ impl Connection { let mut cipher_send = CipherSend::new(cipher, secret_send, rand); Self { - id: ID::Handshake, + id_recv: IDRecv(ID::Handshake), + id_send: IDSend(ID::Handshake), hkdf, cipher_recv, cipher_send, @@ -132,8 +142,8 @@ impl ConnList { id = (self.ids_used.len() as u64) * 1024; self.ids_used.push(new_bitmap); } - let new_id = ID::new_u64(id); - conn.id = new_id; + let new_id = IDRecv(ID::new_u64(id)); + conn.id_recv = new_id; let conn = Arc::new(conn); if (self.connections.len() as u64) < id { self.connections.push(Some(conn.clone())); diff --git a/src/connection/socket.rs b/src/connection/socket.rs index 455c567..b95e7c9 100644 --- a/src/connection/socket.rs +++ b/src/connection/socket.rs @@ -106,3 +106,45 @@ pub(crate) struct UdpClient(pub SocketAddr); /// Strong typedef for a server socket address #[derive(Debug, Copy, Clone)] pub(crate) struct UdpServer(pub SocketAddr); + +/// Enable some common socket options. This is just the unsafe part +fn enable_sock_opt( + fd: ::std::os::fd::RawFd, + option: ::libc::c_int, + value: ::libc::c_int, +) -> ::std::io::Result<()> { + #[allow(unsafe_code)] + unsafe { + #[allow(trivial_casts)] + let val = &value as *const _ as *const ::libc::c_void; + let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; + // always clear the error bit before doing a new syscall + let _ = ::std::io::Error::last_os_error(); + let ret = ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); + if ret != 0 { + return Err(::std::io::Error::last_os_error()); + } + } + Ok(()) +} +/// Add an async udp listener +pub async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { + let socket = UdpSocket::bind(sock).await?; + + use ::std::os::fd::AsRawFd; + let fd = socket.as_raw_fd(); + // can be useful later on for reloads + enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; + enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; + + // We will do path MTU discovery by ourselves, + // always set the "don't fragment" bit + if sock.is_ipv6() { + enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; + } else { + // FIXME: linux only + enable_sock_opt(fd, ::libc::IP_MTU_DISCOVER, ::libc::IP_PMTUDISC_DO)?; + } + + Ok(socket) +} diff --git a/src/inner/mod.rs b/src/inner/mod.rs index 22cafc8..6007868 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -21,7 +21,7 @@ use ::std::{sync::Arc, vec::Vec}; /// Information needed to reply after the key exchange #[derive(Debug, Clone)] pub struct AuthNeededInfo { - /// Parsed handshake + /// Parsed handshake packet pub handshake: Handshake, /// hkdf generated from the handshake pub hkdf: HkdfSha3, @@ -32,12 +32,10 @@ pub struct AuthNeededInfo { /// Client information needed to fully establish the conenction #[derive(Debug)] pub struct ClientConnectInfo { - /// Parsed handshake + /// Parsed handshake packet pub handshake: Handshake, - /// hkdf generated from the handshake - pub hkdf: HkdfSha3, - /// cipher to be used in both directions - pub cipher_recv: CipherRecv, + /// Connection + pub connection: Arc, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] @@ -182,9 +180,7 @@ impl Tracker { return Err(handshake::Error::UnknownKeyID.into()); } let hshake = hshake.unwrap(); - let secret_recv = hshake.hkdf.get_secret(b"to_client"); - let cipher_recv = - CipherRecv::new(hshake.cipher, secret_recv); + let cipher_recv = &hshake.connection.cipher_recv; use crate::enc::sym::AAD; // no aad for now let aad = AAD(&mut []); @@ -202,8 +198,7 @@ impl Tracker { return Ok(HandshakeAction::ClientConnect( ClientConnectInfo { handshake, - hkdf: hshake.hkdf, - cipher_recv, + connection: hshake.connection, }, )); } diff --git a/src/lib.rs b/src/lib.rs index 6972d8b..a4faa07 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,9 +33,12 @@ use ::tokio::{ use crate::{ connection::{ - handshake::{self, Handshake, HandshakeClient, HandshakeServer}, + handshake::{ + self, dirsync::DirSync, Handshake, HandshakeClient, HandshakeData, + HandshakeServer, + }, socket::{SocketList, UdpClient, UdpServer}, - ConnList, Connection, + ConnList, Connection, IDSend, }, enc::{ asym, @@ -167,34 +170,13 @@ impl Fenrir { self.dnssec = None; } - /// Enable some common socket options. This is just the unsafe part - fn enable_sock_opt( - fd: ::std::os::fd::RawFd, - option: ::libc::c_int, - value: ::libc::c_int, - ) -> ::std::io::Result<()> { - #[allow(unsafe_code)] - unsafe { - #[allow(trivial_casts)] - let val = &value as *const _ as *const ::libc::c_void; - let size = ::std::mem::size_of_val(&value) as ::libc::socklen_t; - // always clear the error bit before doing a new syscall - let _ = ::std::io::Error::last_os_error(); - let ret = - ::libc::setsockopt(fd, ::libc::SOL_SOCKET, option, val, size); - if ret != 0 { - return Err(::std::io::Error::last_os_error()); - } - } - Ok(()) - } - /// Add all UDP sockets found in config /// and start listening for packets async fn add_sockets(&self) -> ::std::io::Result<()> { let sockets = self.cfg.listen.iter().map(|s_addr| async { let socket = - ::tokio::spawn(Self::bind_udp(s_addr.clone())).await??; + ::tokio::spawn(connection::socket::bind_udp(s_addr.clone())) + .await??; Ok(socket) }); let sockets = ::futures::future::join_all(sockets).await; @@ -218,32 +200,6 @@ impl Fenrir { Ok(()) } - /// Add an async udp listener - async fn bind_udp(sock: SocketAddr) -> ::std::io::Result { - let socket = UdpSocket::bind(sock).await?; - - use ::std::os::fd::AsRawFd; - let fd = socket.as_raw_fd(); - // can be useful later on for reloads - Self::enable_sock_opt(fd, ::libc::SO_REUSEADDR, 1)?; - Self::enable_sock_opt(fd, ::libc::SO_REUSEPORT, 1)?; - - // We will do path MTU discovery by ourselves, - // always set the "don't fragment" bit - if sock.is_ipv6() { - Self::enable_sock_opt(fd, ::libc::IPV6_DONTFRAG, 1)?; - } else { - // FIXME: linux only - Self::enable_sock_opt( - fd, - ::libc::IP_MTU_DISCOVER, - ::libc::IP_PMTUDISC_DO, - )?; - } - - Ok(socket) - } - /// Run a dedicated loop to read packets on the listening socket async fn listen_udp( mut stop_working: ::tokio::sync::broadcast::Receiver, @@ -350,121 +306,134 @@ impl Fenrir { dirsync::{self, DirSync}, HandshakeData, }; - match authinfo.handshake.data { - HandshakeData::DirSync(ds) => match ds { - DirSync::Req(req) => { - use dirsync::ReqInner; - let req_data = match req.data { - ReqInner::ClearText(req_data) => req_data, - _ => { - ::tracing::error!( - "token_check: expected ClearText" - ); - return; - } - }; - let is_authenticated = match tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await - { - Ok(is_authenticated) => is_authenticated, - Err(_) => { - ::tracing::error!( - "error in token auth" - ); - // TODO: retry? - return; - } - }; - if !is_authenticated { - ::tracing::warn!( - "Wrong authentication for user {:?}", - req_data.auth.user - ); - // TODO: error response - return; - } - // Client has correctly authenticated - // TODO: contact the service, get the key and - // connection ID - let srv_conn_id = - connection::ID::new_rand(&self.rand); - let srv_secret = - enc::sym::Secret::new_rand(&self.rand); - let head_len = req.cipher.nonce_len(); - let tag_len = req.cipher.tag_len(); - - let raw_conn = Connection::new( - authinfo.hkdf, - req.cipher, - connection::Role::Server, - &self.rand, - ); - // track connection - let auth_conn = { - let mut lock = - self.connections.write().await; - lock.reserve_first(raw_conn) - }; - - let resp_data = dirsync::RespData { - client_nonce: req_data.nonce, - id: auth_conn.id, - service_id: srv_conn_id, - service_key: srv_secret, - }; - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - - use dirsync::RespInner; - let resp = dirsync::Resp { - client_key_id: req_data.client_key_id, - data: RespInner::ClearText(resp_data), - }; - let offset_to_encrypt = resp.encrypted_offset(); - let encrypt_until = offset_to_encrypt - + resp.encrypted_length() - + tag_len.0; - let resp_handshake = Handshake::new( - HandshakeData::DirSync(DirSync::Resp(resp)), - ); - use connection::{Packet, PacketData, ID}; - let packet = Packet { - id: ID::new_handshake(), - data: PacketData::Handshake(resp_handshake), - }; - let mut raw_out = - Vec::::with_capacity(packet.len()); - packet.serialize( - head_len, - tag_len, - &mut raw_out, - ); - - if let Err(e) = auth_conn.cipher_send.encrypt( - aad, - &mut raw_out - [offset_to_encrypt..encrypt_until], - ) { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } - self.send_packet(raw_out, udp.src, udp.dst) - .await; - } - DirSync::Resp(resp) => { - todo!() - } - _ => { - todo!() - } - }, + let req; + if let HandshakeData::DirSync(DirSync::Req(r)) = + authinfo.handshake.data + { + req = r; + } else { + ::tracing::error!("AuthInfo on non DS::Req"); + return; } + use dirsync::ReqInner; + let req_data = match req.data { + ReqInner::ClearText(req_data) => req_data, + _ => { + ::tracing::error!( + "token_check: expected ClearText" + ); + return; + } + }; + let is_authenticated = match tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await + { + Ok(is_authenticated) => is_authenticated, + Err(_) => { + ::tracing::error!("error in token auth"); + // TODO: retry? + return; + } + }; + if !is_authenticated { + ::tracing::warn!( + "Wrong authentication for user {:?}", + req_data.auth.user + ); + // TODO: error response + return; + } + // Client has correctly authenticated + // TODO: contact the service, get the key and + // connection ID + let srv_conn_id = connection::ID::new_rand(&self.rand); + let srv_secret = enc::sym::Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); + + let mut raw_conn = Connection::new( + authinfo.hkdf, + req.cipher, + connection::Role::Server, + &self.rand, + ); + raw_conn.id_send = IDSend(req_data.id); + // track connection + let auth_conn = { + let mut lock = self.connections.write().await; + lock.reserve_first(raw_conn) + }; + + let resp_data = dirsync::RespData { + client_nonce: req_data.nonce, + id: auth_conn.id_recv.0, + service_id: srv_conn_id, + service_key: srv_secret, + }; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + + use dirsync::RespInner; + let resp = dirsync::Resp { + client_key_id: req_data.client_key_id, + data: RespInner::ClearText(resp_data), + }; + let offset_to_encrypt = resp.encrypted_offset(); + let encrypt_until = + offset_to_encrypt + resp.encrypted_length() + tag_len.0; + let resp_handshake = Handshake::new( + HandshakeData::DirSync(DirSync::Resp(resp)), + ); + use connection::{Packet, PacketData, ID}; + let packet = Packet { + id: ID::new_handshake(), + data: PacketData::Handshake(resp_handshake), + }; + let mut raw_out = Vec::::with_capacity(packet.len()); + packet.serialize(head_len, tag_len, &mut raw_out); + + if let Err(e) = auth_conn.cipher_send.encrypt( + aad, + &mut raw_out[offset_to_encrypt..encrypt_until], + ) { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + self.send_packet(raw_out, udp.src, udp.dst).await; + return; + } + HandshakeAction::ClientConnect(mut cci) => { + let ds_resp; + if let HandshakeData::DirSync(DirSync::Resp(resp)) = + cci.handshake.data + { + ds_resp = resp; + } else { + ::tracing::error!("ClientConnect on non DS::Resp"); + return; + } + // track connection + use handshake::dirsync; + let resp_data; + if let dirsync::RespInner::ClearText(r_data) = ds_resp.data + { + resp_data = r_data; + } else { + ::tracing::error!( + "ClientConnect on non DS::Resp::ClearText" + ); + return; + } + // FIXME: conn tracking and arc counting + let conn = Arc::get_mut(&mut cci.connection).unwrap(); + conn.id_send = IDSend(resp_data.id); + todo!(); } _ => {} };