Stream enqueue and serialize to the packet

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-06-22 20:12:50 +02:00
parent 9c67210e3e
commit a810fc9a9e
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
4 changed files with 261 additions and 21 deletions

View File

@ -14,8 +14,10 @@ use ::std::{
pub use crate::connection::{handshake::Handshake, packet::Packet}; pub use crate::connection::{handshake::Handshake, packet::Packet};
use crate::{ use crate::{
connection::socket::{UdpClient, UdpServer},
dnssec, dnssec,
enc::{ enc::{
self,
asym::PubKey, asym::PubKey,
hkdf::Hkdf, hkdf::Hkdf,
sym::{self, CipherRecv, CipherSend}, sym::{self, CipherRecv, CipherSend},
@ -165,23 +167,6 @@ impl Conn {
} }
} }
/// A single connection and its data
#[derive(Debug)]
pub(crate) struct Connection {
/// Receiving Conn ID
pub(crate) id_recv: IDRecv,
/// Sending Conn ID
pub(crate) id_send: IDSend,
/// The main hkdf used for all secrets in this connection
hkdf: Hkdf,
/// Cipher for decrypting data
pub(crate) cipher_recv: CipherRecv,
/// Cipher for encrypting data
pub(crate) cipher_send: CipherSend,
/// send queue for each Stream
send_queue: BTreeMap<stream::ID, Vec<Vec<u8>>>,
}
/// Role: track the connection direction /// Role: track the connection direction
/// ///
/// The Role is used to select the correct secrets, and track the direction /// The Role is used to select the correct secrets, and track the direction
@ -197,6 +182,41 @@ pub enum Role {
Client, Client,
} }
#[derive(Debug)]
enum TimerKind {
None,
SendData(::tokio::time::Instant),
Keepalive(::tokio::time::Instant),
}
pub(crate) enum Enqueue {
NoSuchStream,
TimerWait,
Immediate,
}
/// A single connection and its data
#[derive(Debug)]
pub(crate) struct Connection {
/// Receiving Conn ID
pub(crate) id_recv: IDRecv,
/// Sending Conn ID
pub(crate) id_send: IDSend,
/// Sending address
pub(crate) send_addr: UdpClient,
/// The main hkdf used for all secrets in this connection
hkdf: Hkdf,
/// Cipher for decrypting data
pub(crate) cipher_recv: CipherRecv,
/// Cipher for encrypting data
pub(crate) cipher_send: CipherSend,
mtu: usize,
next_timer: TimerKind,
/// send queue for each Stream
send_queue: BTreeMap<stream::ID, stream::SendTracker>,
last_stream_sent: stream::ID,
}
impl Connection { impl Connection {
pub(crate) fn new( pub(crate) fn new(
hkdf: Hkdf, hkdf: Hkdf,
@ -215,21 +235,119 @@ impl Connection {
let cipher_recv = CipherRecv::new(cipher, secret_recv); let cipher_recv = CipherRecv::new(cipher, secret_recv);
let cipher_send = CipherSend::new(cipher, secret_send, rand); let cipher_send = CipherSend::new(cipher, secret_send, rand);
use ::std::net::{IpAddr, Ipv4Addr, SocketAddr};
Self { Self {
id_recv: IDRecv(ID::Handshake), id_recv: IDRecv(ID::Handshake),
id_send: IDSend(ID::Handshake), id_send: IDSend(ID::Handshake),
send_addr: UdpClient(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
31337,
)),
hkdf, hkdf,
cipher_recv, cipher_recv,
cipher_send, cipher_send,
mtu: 1280,
next_timer: TimerKind::None,
send_queue: BTreeMap::new(), send_queue: BTreeMap::new(),
last_stream_sent: stream::ID(0),
} }
} }
pub(crate) fn send(&mut self, stream: stream::ID, data: Vec<u8>) { pub(crate) fn enqueue(
&mut self,
stream: stream::ID,
data: Vec<u8>,
) -> Enqueue {
let stream = match self.send_queue.get_mut(&stream) { let stream = match self.send_queue.get_mut(&stream) {
None => return, None => return Enqueue::NoSuchStream,
Some(stream) => stream, Some(stream) => stream,
}; };
stream.push(data); stream.enqueue(data);
let ret;
self.next_timer = match self.next_timer {
TimerKind::None | TimerKind::Keepalive(_) => {
ret = Enqueue::Immediate;
TimerKind::SendData(::tokio::time::Instant::now())
}
TimerKind::SendData(old_timer) => {
// There already is some data to be sent
// wait for this timer,
// or risk going over max transmission rate
ret = Enqueue::TimerWait;
TimerKind::SendData(old_timer)
}
};
ret
}
pub(crate) fn write_pkt<'a>(
&mut self,
raw: &'a mut [u8],
) -> Result<&'a [u8], enc::Error> {
assert!(raw.len() >= 1200, "I should have at least 1200 MTU");
if self.send_queue.len() == 0 {
return Err(enc::Error::NotEnoughData(0));
}
raw[..ID::len()]
.copy_from_slice(&self.id_send.0.as_u64().to_le_bytes());
let data_from = ID::len() + self.cipher_send.nonce_len().0;
let data_max_to = raw.len() - self.cipher_send.tag_len().0;
let mut chunk_from = data_from;
let mut available_len = data_max_to - data_from;
use std::ops::Bound::{Excluded, Included};
let last_stream = self.last_stream_sent;
// Loop over our streams, write them to the packet.
// Notes:
// * to avoid starvation, just round-robin them all for now
// * we can enqueue multiple times the same stream
// This is useful especially for Datagram streams
'queueloop: {
for (id, stream) in self
.send_queue
.range_mut((Included(last_stream), Included(stream::ID::max())))
{
if available_len < stream::Chunk::headers_len() + 1 {
break 'queueloop;
}
let bytes =
stream.serialize(*id, &mut raw[chunk_from..data_max_to]);
if bytes == 0 {
break 'queueloop;
}
available_len = available_len - bytes;
chunk_from = chunk_from + bytes;
self.last_stream_sent = *id;
}
if available_len > 0 {
for (id, stream) in self.send_queue.range_mut((
Included(stream::ID::min()),
Excluded(last_stream),
)) {
if available_len < stream::Chunk::headers_len() + 1 {
break 'queueloop;
}
let bytes = stream
.serialize(*id, &mut raw[chunk_from..data_max_to]);
if bytes == 0 {
break 'queueloop;
}
available_len = available_len - bytes;
chunk_from = chunk_from + bytes;
self.last_stream_sent = *id;
}
}
}
if chunk_from == data_from {
return Err(enc::Error::NotEnoughData(0));
}
let data_to = chunk_from + self.cipher_send.tag_len().0;
// encrypt
let aad = sym::AAD(&[]);
match self.cipher_send.encrypt(aad, &mut raw[data_from..data_to]) {
Ok(_) => Ok(&raw[..data_to]),
Err(e) => Err(e),
}
} }
} }

View File

@ -27,6 +27,14 @@ impl ID {
pub const fn len() -> usize { pub const fn len() -> usize {
2 2
} }
/// Minimum possible Stream ID (u16::MIN)
pub const fn min() -> Self {
Self(u16::MIN)
}
/// Maximum possible Stream ID (u16::MAX)
pub const fn max() -> Self {
Self(u16::MAX)
}
} }
/// length of the chunk /// length of the chunk
@ -79,6 +87,10 @@ impl<'a> Chunk<'a> {
const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F; const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F;
const FLAG_START_BITMASK: u8 = 0x80; const FLAG_START_BITMASK: u8 = 0x80;
const FLAG_END_BITMASK: u8 = 0x40; const FLAG_END_BITMASK: u8 = 0x40;
/// Return the length of the header of a Chunk
pub const fn headers_len() -> usize {
ID::len() + ChunkLen::len() + Sequence::len()
}
/// Returns the total length of the chunk, including headers /// Returns the total length of the chunk, including headers
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
ID::len() + ChunkLen::len() + Sequence::len() + self.data.len() ID::len() + ChunkLen::len() + Sequence::len() + self.data.len()
@ -181,3 +193,68 @@ impl Stream {
} }
} }
} }
/// Track what has been sent and what has been ACK'd from a stream
#[derive(Debug)]
pub(crate) struct SendTracker {
queue: Vec<Vec<u8>>,
sent: Vec<usize>,
ackd: Vec<usize>,
chunk_started: bool,
is_datagram: bool,
next_sequence: Sequence,
}
impl SendTracker {
pub(crate) fn new(rand: &Random) -> Self {
Self {
queue: Vec::with_capacity(4),
sent: Vec::with_capacity(4),
ackd: Vec::with_capacity(4),
chunk_started: false,
is_datagram: false,
next_sequence: Sequence::new(rand),
}
}
/// Enqueue user data to be sent
pub(crate) fn enqueue(&mut self, data: Vec<u8>) {
self.queue.push(data);
self.sent.push(0);
self.ackd.push(0);
}
/// Write the user data to the buffer and mark it as sent
pub(crate) fn get(&mut self, out: &mut [u8]) -> usize {
let data = match self.queue.get(0) {
Some(data) => data,
None => return 0,
};
let len = ::std::cmp::min(out.len(), data.len());
out[..len].copy_from_slice(&data[self.sent[0]..len]);
self.sent[0] = self.sent[0] + len;
len
}
/// Mark the sent data as successfully received from the receiver
pub(crate) fn ack(&mut self, size: usize) {
todo!()
}
pub(crate) fn serialize(&mut self, id: ID, raw: &mut [u8]) -> usize {
let max_data_len = raw.len() - Chunk::headers_len();
let data_len = ::std::cmp::min(max_data_len, self.queue[0].len());
let flag_start = !self.chunk_started;
let flag_end = self.is_datagram && data_len == self.queue[0].len();
let chunk = Chunk {
id,
flag_start,
flag_end,
sequence: self.next_sequence,
data: &self.queue[0][..data_len],
};
self.next_sequence = Sequence(
self.next_sequence.0 + ::core::num::Wrapping(data_len as u32),
);
if chunk.flag_end {
self.chunk_started = false;
}
chunk.serialize(raw);
data_len
}
}

View File

@ -241,6 +241,14 @@ impl CipherSend {
pub fn kind(&self) -> Kind { pub fn kind(&self) -> Kind {
self.cipher.kind() self.cipher.kind()
} }
/// Get the length of the nonce for this cipher
pub fn nonce_len(&self) -> NonceLen {
self.cipher.nonce_len()
}
/// Get the length of the nonce for this cipher
pub fn tag_len(&self) -> TagLen {
self.cipher.tag_len()
}
} }
/// XChaCha20Poly1305 cipher /// XChaCha20Poly1305 cipher

View File

@ -16,6 +16,7 @@ use crate::{
}, },
dnssec, dnssec,
enc::{ enc::{
self,
asym::{self, KeyID, PrivKey, PubKey}, asym::{self, KeyID, PrivKey, PubKey},
hkdf::{self, Hkdf}, hkdf::{self, Hkdf},
sym, Random, Secret, sym, Random, Secret,
@ -53,6 +54,7 @@ pub(crate) enum Work {
DropHandshake(KeyID), DropHandshake(KeyID),
Recv(RawUdp), Recv(RawUdp),
UserSend((UserConnTracker, stream::ID, Vec<u8>)), UserSend((UserConnTracker, stream::ID, Vec<u8>)),
SendData(UserConnTracker),
} }
/// Actual worker implementation. /// Actual worker implementation.
@ -437,7 +439,42 @@ impl Worker {
None => return, None => return,
Some(conn) => conn, Some(conn) => conn,
}; };
conn.send(stream, data); use connection::Enqueue;
match conn.enqueue(stream, data) {
Enqueue::Immediate => {
let _ = self
.queue_sender
.send(Work::SendData(tracker))
.await;
}
Enqueue::TimerWait => {}
Enqueue::NoSuchStream => {
::tracing::error!(
"Trying to send on unknown stream"
);
}
}
}
Work::SendData(tracker) => {
let mut raw: Vec<u8> = Vec::with_capacity(1280);
raw.resize(raw.capacity(), 0);
let conn = match self.connections.get_mut(tracker) {
None => return,
Some(conn) => conn,
};
let pkt = match conn.write_pkt(&mut raw) {
Ok(pkt) => pkt,
Err(enc::Error::NotEnoughData(0)) => return,
Err(e) => {
::tracing::error!("Packet generation: {:?}", e);
return;
}
};
let dest = conn.send_addr;
let src = UdpServer(self.sockets[0].local_addr().unwrap());
let len = pkt.len();
raw.truncate(len);
let _ = self.send_packet(raw, dest, src);
} }
} }
} }