Stream enqueue and serialize to the packet
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
e3ae166ca9
commit
d8cc5ca974
|
@ -14,8 +14,10 @@ use ::std::{
|
|||
pub use crate::connection::{handshake::Handshake, packet::Packet};
|
||||
|
||||
use crate::{
|
||||
connection::socket::{UdpClient, UdpServer},
|
||||
dnssec,
|
||||
enc::{
|
||||
self,
|
||||
asym::PubKey,
|
||||
hkdf::Hkdf,
|
||||
sym::{self, CipherRecv, CipherSend},
|
||||
|
@ -165,23 +167,6 @@ impl Conn {
|
|||
}
|
||||
}
|
||||
|
||||
/// A single connection and its data
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Connection {
|
||||
/// Receiving Conn ID
|
||||
pub(crate) id_recv: IDRecv,
|
||||
/// Sending Conn ID
|
||||
pub(crate) id_send: IDSend,
|
||||
/// The main hkdf used for all secrets in this connection
|
||||
hkdf: Hkdf,
|
||||
/// Cipher for decrypting data
|
||||
pub(crate) cipher_recv: CipherRecv,
|
||||
/// Cipher for encrypting data
|
||||
pub(crate) cipher_send: CipherSend,
|
||||
/// send queue for each Stream
|
||||
send_queue: BTreeMap<stream::ID, Vec<Vec<u8>>>,
|
||||
}
|
||||
|
||||
/// Role: track the connection direction
|
||||
///
|
||||
/// The Role is used to select the correct secrets, and track the direction
|
||||
|
@ -197,6 +182,41 @@ pub enum Role {
|
|||
Client,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum TimerKind {
|
||||
None,
|
||||
SendData(::tokio::time::Instant),
|
||||
Keepalive(::tokio::time::Instant),
|
||||
}
|
||||
|
||||
pub(crate) enum Enqueue {
|
||||
NoSuchStream,
|
||||
TimerWait,
|
||||
Immediate,
|
||||
}
|
||||
|
||||
/// A single connection and its data
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Connection {
|
||||
/// Receiving Conn ID
|
||||
pub(crate) id_recv: IDRecv,
|
||||
/// Sending Conn ID
|
||||
pub(crate) id_send: IDSend,
|
||||
/// Sending address
|
||||
pub(crate) send_addr: UdpClient,
|
||||
/// The main hkdf used for all secrets in this connection
|
||||
hkdf: Hkdf,
|
||||
/// Cipher for decrypting data
|
||||
pub(crate) cipher_recv: CipherRecv,
|
||||
/// Cipher for encrypting data
|
||||
pub(crate) cipher_send: CipherSend,
|
||||
mtu: usize,
|
||||
next_timer: TimerKind,
|
||||
/// send queue for each Stream
|
||||
send_queue: BTreeMap<stream::ID, stream::SendTracker>,
|
||||
last_stream_sent: stream::ID,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub(crate) fn new(
|
||||
hkdf: Hkdf,
|
||||
|
@ -215,21 +235,119 @@ impl Connection {
|
|||
let cipher_recv = CipherRecv::new(cipher, secret_recv);
|
||||
let cipher_send = CipherSend::new(cipher, secret_send, rand);
|
||||
|
||||
use ::std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
Self {
|
||||
id_recv: IDRecv(ID::Handshake),
|
||||
id_send: IDSend(ID::Handshake),
|
||||
send_addr: UdpClient(SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
|
||||
31337,
|
||||
)),
|
||||
hkdf,
|
||||
cipher_recv,
|
||||
cipher_send,
|
||||
mtu: 1280,
|
||||
next_timer: TimerKind::None,
|
||||
send_queue: BTreeMap::new(),
|
||||
last_stream_sent: stream::ID(0),
|
||||
}
|
||||
}
|
||||
pub(crate) fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
|
||||
pub(crate) fn enqueue(
|
||||
&mut self,
|
||||
stream: stream::ID,
|
||||
data: Vec<u8>,
|
||||
) -> Enqueue {
|
||||
let stream = match self.send_queue.get_mut(&stream) {
|
||||
None => return,
|
||||
None => return Enqueue::NoSuchStream,
|
||||
Some(stream) => stream,
|
||||
};
|
||||
stream.push(data);
|
||||
stream.enqueue(data);
|
||||
let ret;
|
||||
self.next_timer = match self.next_timer {
|
||||
TimerKind::None | TimerKind::Keepalive(_) => {
|
||||
ret = Enqueue::Immediate;
|
||||
TimerKind::SendData(::tokio::time::Instant::now())
|
||||
}
|
||||
TimerKind::SendData(old_timer) => {
|
||||
// There already is some data to be sent
|
||||
// wait for this timer,
|
||||
// or risk going over max transmission rate
|
||||
ret = Enqueue::TimerWait;
|
||||
TimerKind::SendData(old_timer)
|
||||
}
|
||||
};
|
||||
ret
|
||||
}
|
||||
pub(crate) fn write_pkt<'a>(
|
||||
&mut self,
|
||||
raw: &'a mut [u8],
|
||||
) -> Result<&'a [u8], enc::Error> {
|
||||
assert!(raw.len() >= 1200, "I should have at least 1200 MTU");
|
||||
if self.send_queue.len() == 0 {
|
||||
return Err(enc::Error::NotEnoughData(0));
|
||||
}
|
||||
raw[..ID::len()]
|
||||
.copy_from_slice(&self.id_send.0.as_u64().to_le_bytes());
|
||||
let data_from = ID::len() + self.cipher_send.nonce_len().0;
|
||||
let data_max_to = raw.len() - self.cipher_send.tag_len().0;
|
||||
let mut chunk_from = data_from;
|
||||
let mut available_len = data_max_to - data_from;
|
||||
|
||||
use std::ops::Bound::{Excluded, Included};
|
||||
let last_stream = self.last_stream_sent;
|
||||
|
||||
// Loop over our streams, write them to the packet.
|
||||
// Notes:
|
||||
// * to avoid starvation, just round-robin them all for now
|
||||
// * we can enqueue multiple times the same stream
|
||||
// This is useful especially for Datagram streams
|
||||
'queueloop: {
|
||||
for (id, stream) in self
|
||||
.send_queue
|
||||
.range_mut((Included(last_stream), Included(stream::ID::max())))
|
||||
{
|
||||
if available_len < stream::Chunk::headers_len() + 1 {
|
||||
break 'queueloop;
|
||||
}
|
||||
let bytes =
|
||||
stream.serialize(*id, &mut raw[chunk_from..data_max_to]);
|
||||
if bytes == 0 {
|
||||
break 'queueloop;
|
||||
}
|
||||
available_len = available_len - bytes;
|
||||
chunk_from = chunk_from + bytes;
|
||||
self.last_stream_sent = *id;
|
||||
}
|
||||
if available_len > 0 {
|
||||
for (id, stream) in self.send_queue.range_mut((
|
||||
Included(stream::ID::min()),
|
||||
Excluded(last_stream),
|
||||
)) {
|
||||
if available_len < stream::Chunk::headers_len() + 1 {
|
||||
break 'queueloop;
|
||||
}
|
||||
let bytes = stream
|
||||
.serialize(*id, &mut raw[chunk_from..data_max_to]);
|
||||
if bytes == 0 {
|
||||
break 'queueloop;
|
||||
}
|
||||
available_len = available_len - bytes;
|
||||
chunk_from = chunk_from + bytes;
|
||||
self.last_stream_sent = *id;
|
||||
}
|
||||
}
|
||||
}
|
||||
if chunk_from == data_from {
|
||||
return Err(enc::Error::NotEnoughData(0));
|
||||
}
|
||||
let data_to = chunk_from + self.cipher_send.tag_len().0;
|
||||
|
||||
// encrypt
|
||||
let aad = sym::AAD(&[]);
|
||||
match self.cipher_send.encrypt(aad, &mut raw[data_from..data_to]) {
|
||||
Ok(_) => Ok(&raw[..data_to]),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,14 @@ impl ID {
|
|||
pub const fn len() -> usize {
|
||||
2
|
||||
}
|
||||
/// Minimum possible Stream ID (u16::MIN)
|
||||
pub const fn min() -> Self {
|
||||
Self(u16::MIN)
|
||||
}
|
||||
/// Maximum possible Stream ID (u16::MAX)
|
||||
pub const fn max() -> Self {
|
||||
Self(u16::MAX)
|
||||
}
|
||||
}
|
||||
|
||||
/// length of the chunk
|
||||
|
@ -79,6 +87,10 @@ impl<'a> Chunk<'a> {
|
|||
const FLAGS_EXCLUDED_BITMASK: u8 = 0x3F;
|
||||
const FLAG_START_BITMASK: u8 = 0x80;
|
||||
const FLAG_END_BITMASK: u8 = 0x40;
|
||||
/// Return the length of the header of a Chunk
|
||||
pub const fn headers_len() -> usize {
|
||||
ID::len() + ChunkLen::len() + Sequence::len()
|
||||
}
|
||||
/// Returns the total length of the chunk, including headers
|
||||
pub fn len(&self) -> usize {
|
||||
ID::len() + ChunkLen::len() + Sequence::len() + self.data.len()
|
||||
|
@ -181,3 +193,68 @@ impl Stream {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Track what has been sent and what has been ACK'd from a stream
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct SendTracker {
|
||||
queue: Vec<Vec<u8>>,
|
||||
sent: Vec<usize>,
|
||||
ackd: Vec<usize>,
|
||||
chunk_started: bool,
|
||||
is_datagram: bool,
|
||||
next_sequence: Sequence,
|
||||
}
|
||||
impl SendTracker {
|
||||
pub(crate) fn new(rand: &Random) -> Self {
|
||||
Self {
|
||||
queue: Vec::with_capacity(4),
|
||||
sent: Vec::with_capacity(4),
|
||||
ackd: Vec::with_capacity(4),
|
||||
chunk_started: false,
|
||||
is_datagram: false,
|
||||
next_sequence: Sequence::new(rand),
|
||||
}
|
||||
}
|
||||
/// Enqueue user data to be sent
|
||||
pub(crate) fn enqueue(&mut self, data: Vec<u8>) {
|
||||
self.queue.push(data);
|
||||
self.sent.push(0);
|
||||
self.ackd.push(0);
|
||||
}
|
||||
/// Write the user data to the buffer and mark it as sent
|
||||
pub(crate) fn get(&mut self, out: &mut [u8]) -> usize {
|
||||
let data = match self.queue.get(0) {
|
||||
Some(data) => data,
|
||||
None => return 0,
|
||||
};
|
||||
let len = ::std::cmp::min(out.len(), data.len());
|
||||
out[..len].copy_from_slice(&data[self.sent[0]..len]);
|
||||
self.sent[0] = self.sent[0] + len;
|
||||
len
|
||||
}
|
||||
/// Mark the sent data as successfully received from the receiver
|
||||
pub(crate) fn ack(&mut self, size: usize) {
|
||||
todo!()
|
||||
}
|
||||
pub(crate) fn serialize(&mut self, id: ID, raw: &mut [u8]) -> usize {
|
||||
let max_data_len = raw.len() - Chunk::headers_len();
|
||||
let data_len = ::std::cmp::min(max_data_len, self.queue[0].len());
|
||||
let flag_start = !self.chunk_started;
|
||||
let flag_end = self.is_datagram && data_len == self.queue[0].len();
|
||||
let chunk = Chunk {
|
||||
id,
|
||||
flag_start,
|
||||
flag_end,
|
||||
sequence: self.next_sequence,
|
||||
data: &self.queue[0][..data_len],
|
||||
};
|
||||
self.next_sequence = Sequence(
|
||||
self.next_sequence.0 + ::core::num::Wrapping(data_len as u32),
|
||||
);
|
||||
if chunk.flag_end {
|
||||
self.chunk_started = false;
|
||||
}
|
||||
chunk.serialize(raw);
|
||||
data_len
|
||||
}
|
||||
}
|
||||
|
|
|
@ -241,6 +241,14 @@ impl CipherSend {
|
|||
pub fn kind(&self) -> Kind {
|
||||
self.cipher.kind()
|
||||
}
|
||||
/// Get the length of the nonce for this cipher
|
||||
pub fn nonce_len(&self) -> NonceLen {
|
||||
self.cipher.nonce_len()
|
||||
}
|
||||
/// Get the length of the nonce for this cipher
|
||||
pub fn tag_len(&self) -> TagLen {
|
||||
self.cipher.tag_len()
|
||||
}
|
||||
}
|
||||
|
||||
/// XChaCha20Poly1305 cipher
|
||||
|
|
|
@ -16,6 +16,7 @@ use crate::{
|
|||
},
|
||||
dnssec,
|
||||
enc::{
|
||||
self,
|
||||
asym::{self, KeyID, PrivKey, PubKey},
|
||||
hkdf::{self, Hkdf},
|
||||
sym, Random, Secret,
|
||||
|
@ -53,6 +54,7 @@ pub(crate) enum Work {
|
|||
DropHandshake(KeyID),
|
||||
Recv(RawUdp),
|
||||
UserSend((UserConnTracker, stream::ID, Vec<u8>)),
|
||||
SendData(UserConnTracker),
|
||||
}
|
||||
|
||||
/// Actual worker implementation.
|
||||
|
@ -437,7 +439,42 @@ impl Worker {
|
|||
None => return,
|
||||
Some(conn) => conn,
|
||||
};
|
||||
conn.send(stream, data);
|
||||
use connection::Enqueue;
|
||||
match conn.enqueue(stream, data) {
|
||||
Enqueue::Immediate => {
|
||||
let _ = self
|
||||
.queue_sender
|
||||
.send(Work::SendData(tracker))
|
||||
.await;
|
||||
}
|
||||
Enqueue::TimerWait => {}
|
||||
Enqueue::NoSuchStream => {
|
||||
::tracing::error!(
|
||||
"Trying to send on unknown stream"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Work::SendData(tracker) => {
|
||||
let mut raw: Vec<u8> = Vec::with_capacity(1280);
|
||||
raw.resize(raw.capacity(), 0);
|
||||
let conn = match self.connections.get_mut(tracker) {
|
||||
None => return,
|
||||
Some(conn) => conn,
|
||||
};
|
||||
let pkt = match conn.write_pkt(&mut raw) {
|
||||
Ok(pkt) => pkt,
|
||||
Err(enc::Error::NotEnoughData(0)) => return,
|
||||
Err(e) => {
|
||||
::tracing::error!("Packet generation: {:?}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let dest = conn.send_addr;
|
||||
let src = UdpServer(self.sockets[0].local_addr().unwrap());
|
||||
let len = pkt.len();
|
||||
raw.truncate(len);
|
||||
let _ = self.send_packet(raw, dest, src);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue