From d8cc5ca97499973f3b02d3328c1fa7f3b4638889 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Thu, 22 Jun 2023 20:12:50 +0200 Subject: [PATCH] Stream enqueue and serialize to the packet Signed-off-by: Luca Fulchir --- src/connection/mod.rs | 158 ++++++++++++++++++++++++++++++----- src/connection/stream/mod.rs | 77 +++++++++++++++++ src/enc/sym.rs | 8 ++ src/inner/worker.rs | 39 ++++++++- 4 files changed, 261 insertions(+), 21 deletions(-) diff --git a/src/connection/mod.rs b/src/connection/mod.rs index ea5dfcb..337b48f 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -14,8 +14,10 @@ use ::std::{ pub use crate::connection::{handshake::Handshake, packet::Packet}; use crate::{ + connection::socket::{UdpClient, UdpServer}, dnssec, enc::{ + self, asym::PubKey, hkdf::Hkdf, sym::{self, CipherRecv, CipherSend}, @@ -165,23 +167,6 @@ impl Conn { } } -/// A single connection and its data -#[derive(Debug)] -pub(crate) struct Connection { - /// Receiving Conn ID - pub(crate) id_recv: IDRecv, - /// Sending Conn ID - pub(crate) id_send: IDSend, - /// The main hkdf used for all secrets in this connection - hkdf: Hkdf, - /// Cipher for decrypting data - pub(crate) cipher_recv: CipherRecv, - /// Cipher for encrypting data - pub(crate) cipher_send: CipherSend, - /// send queue for each Stream - send_queue: BTreeMap>>, -} - /// Role: track the connection direction /// /// The Role is used to select the correct secrets, and track the direction @@ -197,6 +182,41 @@ pub enum Role { Client, } +#[derive(Debug)] +enum TimerKind { + None, + SendData(::tokio::time::Instant), + Keepalive(::tokio::time::Instant), +} + +pub(crate) enum Enqueue { + NoSuchStream, + TimerWait, + Immediate, +} + +/// A single connection and its data +#[derive(Debug)] +pub(crate) struct Connection { + /// Receiving Conn ID + pub(crate) id_recv: IDRecv, + /// Sending Conn ID + pub(crate) id_send: IDSend, + /// Sending address + pub(crate) send_addr: UdpClient, + /// The main hkdf used for all secrets in this connection + hkdf: Hkdf, + /// Cipher for decrypting data + pub(crate) cipher_recv: CipherRecv, + /// Cipher for encrypting data + pub(crate) cipher_send: CipherSend, + mtu: usize, + next_timer: TimerKind, + /// send queue for each Stream + send_queue: BTreeMap, + last_stream_sent: stream::ID, +} + impl Connection { pub(crate) fn new( hkdf: Hkdf, @@ -215,21 +235,119 @@ impl Connection { let cipher_recv = CipherRecv::new(cipher, secret_recv); let cipher_send = CipherSend::new(cipher, secret_send, rand); + use ::std::net::{IpAddr, Ipv4Addr, SocketAddr}; Self { id_recv: IDRecv(ID::Handshake), id_send: IDSend(ID::Handshake), + send_addr: UdpClient(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 31337, + )), hkdf, cipher_recv, cipher_send, + mtu: 1280, + next_timer: TimerKind::None, send_queue: BTreeMap::new(), + last_stream_sent: stream::ID(0), } } - pub(crate) fn send(&mut self, stream: stream::ID, data: Vec) { + pub(crate) fn enqueue( + &mut self, + stream: stream::ID, + data: Vec, + ) -> Enqueue { let stream = match self.send_queue.get_mut(&stream) { - None => return, + None => return Enqueue::NoSuchStream, Some(stream) => stream, }; - stream.push(data); + stream.enqueue(data); + let ret; + self.next_timer = match self.next_timer { + TimerKind::None | TimerKind::Keepalive(_) => { + ret = Enqueue::Immediate; + TimerKind::SendData(::tokio::time::Instant::now()) + } + TimerKind::SendData(old_timer) => { + // There already is some data to be sent + // wait for this timer, + // or risk going over max transmission rate + ret = Enqueue::TimerWait; + TimerKind::SendData(old_timer) + } + }; + ret + } + pub(crate) fn write_pkt<'a>( + &mut self, + raw: &'a mut [u8], + ) -> Result<&'a [u8], enc::Error> { + assert!(raw.len() >= 1200, "I should have at least 1200 MTU"); + if self.send_queue.len() == 0 { + return Err(enc::Error::NotEnoughData(0)); + } + raw[..ID::len()] + .copy_from_slice(&self.id_send.0.as_u64().to_le_bytes()); + let data_from = ID::len() + self.cipher_send.nonce_len().0; + let data_max_to = raw.len() - self.cipher_send.tag_len().0; + let mut chunk_from = data_from; + let mut available_len = data_max_to - data_from; + + use std::ops::Bound::{Excluded, Included}; + let last_stream = self.last_stream_sent; + + // Loop over our streams, write them to the packet. + // Notes: + // * to avoid starvation, just round-robin them all for now + // * we can enqueue multiple times the same stream + // This is useful especially for Datagram streams + 'queueloop: { + for (id, stream) in self + .send_queue + .range_mut((Included(last_stream), Included(stream::ID::max()))) + { + if available_len < stream::Chunk::headers_len() + 1 { + break 'queueloop; + } + let bytes = + stream.serialize(*id, &mut raw[chunk_from..data_max_to]); + if bytes == 0 { + break 'queueloop; + } + available_len = available_len - bytes; + chunk_from = chunk_from + bytes; + self.last_stream_sent = *id; + } + if available_len > 0 { + for (id, stream) in self.send_queue.range_mut(( + Included(stream::ID::min()), + Excluded(last_stream), + )) { + if available_len < stream::Chunk::headers_len() + 1 { + break 'queueloop; + } + let bytes = stream + .serialize(*id, &mut raw[chunk_from..data_max_to]); + if bytes == 0 { + break 'queueloop; + } + available_len = available_len - bytes; + chunk_from = chunk_from + bytes; + self.last_stream_sent = *id; + } + } + } + if chunk_from == data_from { + return Err(enc::Error::NotEnoughData(0)); + } + let data_to = chunk_from + self.cipher_send.tag_len().0; + + // encrypt + let aad = sym::AAD(&[]); + match self.cipher_send.encrypt(aad, &mut raw[data_from..data_to]) { + Ok(_) => Ok(&raw[..data_to]), + Err(e) => Err(e), + } } } diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 58dd76e..778c9c5 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -27,6 +27,14 @@ impl ID { pub const fn len() -> usize { 2 } + /// Minimum possible Stream ID (u16::MIN) + pub const fn min() -> Self { + Self(u16::MIN) + } + /// Maximum possible Stream ID (u16::MAX) + pub const fn max() -> Self { + Self(u16::MAX) + } } /// length of the chunk @@ -79,6 +87,10 @@ impl<'a> Chunk<'a> { const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F; const FLAG_START_BITMASK: u8 = 0x80; const FLAG_END_BITMASK: u8 = 0x40; + /// Return the length of the header of a Chunk + pub const fn headers_len() -> usize { + ID::len() + ChunkLen::len() + Sequence::len() + } /// Returns the total length of the chunk, including headers pub fn len(&self) -> usize { ID::len() + ChunkLen::len() + Sequence::len() + self.data.len() @@ -181,3 +193,68 @@ impl Stream { } } } + +/// Track what has been sent and what has been ACK'd from a stream +#[derive(Debug)] +pub(crate) struct SendTracker { + queue: Vec>, + sent: Vec, + ackd: Vec, + chunk_started: bool, + is_datagram: bool, + next_sequence: Sequence, +} +impl SendTracker { + pub(crate) fn new(rand: &Random) -> Self { + Self { + queue: Vec::with_capacity(4), + sent: Vec::with_capacity(4), + ackd: Vec::with_capacity(4), + chunk_started: false, + is_datagram: false, + next_sequence: Sequence::new(rand), + } + } + /// Enqueue user data to be sent + pub(crate) fn enqueue(&mut self, data: Vec) { + self.queue.push(data); + self.sent.push(0); + self.ackd.push(0); + } + /// Write the user data to the buffer and mark it as sent + pub(crate) fn get(&mut self, out: &mut [u8]) -> usize { + let data = match self.queue.get(0) { + Some(data) => data, + None => return 0, + }; + let len = ::std::cmp::min(out.len(), data.len()); + out[..len].copy_from_slice(&data[self.sent[0]..len]); + self.sent[0] = self.sent[0] + len; + len + } + /// Mark the sent data as successfully received from the receiver + pub(crate) fn ack(&mut self, size: usize) { + todo!() + } + pub(crate) fn serialize(&mut self, id: ID, raw: &mut [u8]) -> usize { + let max_data_len = raw.len() - Chunk::headers_len(); + let data_len = ::std::cmp::min(max_data_len, self.queue[0].len()); + let flag_start = !self.chunk_started; + let flag_end = self.is_datagram && data_len == self.queue[0].len(); + let chunk = Chunk { + id, + flag_start, + flag_end, + sequence: self.next_sequence, + data: &self.queue[0][..data_len], + }; + self.next_sequence = Sequence( + self.next_sequence.0 + ::core::num::Wrapping(data_len as u32), + ); + if chunk.flag_end { + self.chunk_started = false; + } + chunk.serialize(raw); + data_len + } +} diff --git a/src/enc/sym.rs b/src/enc/sym.rs index 14d712d..6c48f73 100644 --- a/src/enc/sym.rs +++ b/src/enc/sym.rs @@ -241,6 +241,14 @@ impl CipherSend { pub fn kind(&self) -> Kind { self.cipher.kind() } + /// Get the length of the nonce for this cipher + pub fn nonce_len(&self) -> NonceLen { + self.cipher.nonce_len() + } + /// Get the length of the nonce for this cipher + pub fn tag_len(&self) -> TagLen { + self.cipher.tag_len() + } } /// XChaCha20Poly1305 cipher diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 0bb41d0..ba9ec3a 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -16,6 +16,7 @@ use crate::{ }, dnssec, enc::{ + self, asym::{self, KeyID, PrivKey, PubKey}, hkdf::{self, Hkdf}, sym, Random, Secret, @@ -53,6 +54,7 @@ pub(crate) enum Work { DropHandshake(KeyID), Recv(RawUdp), UserSend((UserConnTracker, stream::ID, Vec)), + SendData(UserConnTracker), } /// Actual worker implementation. @@ -437,7 +439,42 @@ impl Worker { None => return, Some(conn) => conn, }; - conn.send(stream, data); + use connection::Enqueue; + match conn.enqueue(stream, data) { + Enqueue::Immediate => { + let _ = self + .queue_sender + .send(Work::SendData(tracker)) + .await; + } + Enqueue::TimerWait => {} + Enqueue::NoSuchStream => { + ::tracing::error!( + "Trying to send on unknown stream" + ); + } + } + } + Work::SendData(tracker) => { + let mut raw: Vec = Vec::with_capacity(1280); + raw.resize(raw.capacity(), 0); + let conn = match self.connections.get_mut(tracker) { + None => return, + Some(conn) => conn, + }; + let pkt = match conn.write_pkt(&mut raw) { + Ok(pkt) => pkt, + Err(enc::Error::NotEnoughData(0)) => return, + Err(e) => { + ::tracing::error!("Packet generation: {:?}", e); + return; + } + }; + let dest = conn.send_addr; + let src = UdpServer(self.sockets[0].local_addr().unwrap()); + let len = pkt.len(); + raw.truncate(len); + let _ = self.send_packet(raw, dest, src); } } }