From d8a27bf969ebfb2ef7322f43a564192233c3e333 Mon Sep 17 00:00:00 2001 From: Luca Fulchir Date: Sun, 25 Jun 2023 19:22:40 +0200 Subject: [PATCH] Stream ROB: Reconstruct data TCP-like this was more convoluted than I thought. maybe someone will simplify this. Signed-off-by: Luca Fulchir --- flake.lock | 24 +- src/connection/handshake/tracker.rs | 8 +- src/connection/mod.rs | 71 ++++- src/connection/stream/errors.rs | 4 +- src/connection/stream/mod.rs | 71 ++++- src/connection/stream/rob.rs | 29 -- src/connection/stream/rob/mod.rs | 204 ++++++++++++ src/connection/stream/rob/tests.rs | 249 +++++++++++++++ src/inner/mod.rs | 26 ++ src/inner/worker.rs | 471 +++++++++++++++------------- 10 files changed, 878 insertions(+), 279 deletions(-) delete mode 100644 src/connection/stream/rob.rs create mode 100644 src/connection/stream/rob/mod.rs create mode 100644 src/connection/stream/rob/tests.rs diff --git a/flake.lock b/flake.lock index 8eb84e1..85c58c2 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1685518550, - "narHash": "sha256-o2d0KcvaXzTrPRIo0kOLV0/QXHhDQ5DTi+OxcjO8xqY=", + "lastModified": 1687171271, + "narHash": "sha256-BJlq+ozK2B1sJDQXS3tzJM5a+oVZmi1q0FlBK/Xqv7M=", "owner": "numtide", "repo": "flake-utils", - "rev": "a1720a10a6cfe8234c0e93907ffe81be440f4cef", + "rev": "abfb11bd1aec8ced1c9bb9adfe68018230f4fb3c", "type": "github" }, "original": { @@ -38,11 +38,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1686921029, - "narHash": "sha256-J1bX9plPCFhTSh6E3TWn9XSxggBh/zDD4xigyaIQBy8=", + "lastModified": 1687555006, + "narHash": "sha256-GD2Kqb/DXQBRJcHqkM2qFZqbVenyO7Co/80JHRMg2U0=", "owner": "nixos", "repo": "nixpkgs", - "rev": "c7ff1b9b95620ce8728c0d7bd501c458e6da9e04", + "rev": "33223d479ffde3d05ac16c6dff04ae43cc27e577", "type": "github" }, "original": { @@ -54,11 +54,11 @@ }, "nixpkgs-unstable": { "locked": { - "lastModified": 1686960236, - "narHash": "sha256-AYCC9rXNLpUWzD9hm+askOfpliLEC9kwAo7ITJc4HIw=", + "lastModified": 1687502512, + "narHash": "sha256-dBL/01TayOSZYxtY4cMXuNCBk8UMLoqRZA+94xiFpJA=", "owner": "nixos", "repo": "nixpkgs", - "rev": "04af42f3b31dba0ef742d254456dc4c14eedac86", + "rev": "3ae20aa58a6c0d1ca95c9b11f59a2d12eebc511f", "type": "github" }, "original": { @@ -98,11 +98,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1687055571, - "narHash": "sha256-UvLoO6u5n9TzY80BpM4DaacxvyJl7u9mm9CA72d309g=", + "lastModified": 1687660699, + "narHash": "sha256-crI/CA/OJc778I5qJhwhhl8/PKKzc0D7vvVxOtjfvSo=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "2de557c780dcb127128ae987fca9d6c2b0d7dc0f", + "rev": "b3bd1d49f1ae609c1d68a66bba7a95a9a4256031", "type": "github" }, "original": { diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index c40158e..fbedc34 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -37,7 +37,7 @@ pub(crate) struct Client { pub(crate) service_id: ServiceID, pub(crate) service_conn_id: IDRecv, pub(crate) connection: Connection, - pub(crate) timeout: Option<::tokio::task::JoinHandle<()>>, + pub(crate) timeout: Option<::tokio::time::Instant>, pub(crate) answer: oneshot::Sender, pub(crate) srv_key_id: KeyID, } @@ -150,6 +150,8 @@ pub(crate) struct ClientConnectInfo { pub(crate) service_connection_id: IDRecv, /// Parsed handshake packet pub(crate) handshake: Handshake, + /// Old timeout for the handshake completion + pub(crate) old_timeout: ::tokio::time::Instant, /// Connection pub(crate) connection: 
Connection, /// where to wake up the waiting client @@ -374,13 +376,11 @@ impl Tracker { } let hshake = self.hshake_cli.remove(resp.client_key_id).unwrap(); - if let Some(timeout) = hshake.timeout { - timeout.abort(); - } return Ok(Action::ClientConnect(ClientConnectInfo { service_id: hshake.service_id, service_connection_id: hshake.service_conn_id, handshake, + old_timeout: hshake.timeout.unwrap(), connection: hshake.connection, answer: hshake.answer, srv_key_id: hshake.srv_key_id, diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 337b48f..dbe62e1 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -14,7 +14,7 @@ use ::std::{ pub use crate::connection::{handshake::Handshake, packet::Packet}; use crate::{ - connection::socket::{UdpClient, UdpServer}, + connection::socket::UdpClient, dnssec, enc::{ self, @@ -26,6 +26,16 @@ use crate::{ inner::{worker, ThreadTracker}, }; +/// Connaction errors +#[derive(::thiserror::Error, Debug, Copy, Clone)] +pub(crate) enum Error { + /// Can't decrypt packet + #[error("Decrypt error: {0}")] + Decrypt(#[from] crate::enc::Error), + #[error("Chunk parsing: {0}")] + Parse(#[from] stream::Error), +} + /// Fenrir Connection ID /// /// 0 is special as it represents the handshake @@ -192,7 +202,7 @@ enum TimerKind { pub(crate) enum Enqueue { NoSuchStream, TimerWait, - Immediate, + Immediate(::tokio::time::Instant), } /// A single connection and its data @@ -215,6 +225,8 @@ pub(crate) struct Connection { /// send queue for each Stream send_queue: BTreeMap, last_stream_sent: stream::ID, + /// receive queue for each Stream + recv_queue: BTreeMap, } impl Connection { @@ -246,12 +258,46 @@ impl Connection { hkdf, cipher_recv, cipher_send, - mtu: 1280, + mtu: 1200, next_timer: TimerKind::None, send_queue: BTreeMap::new(), last_stream_sent: stream::ID(0), + recv_queue: BTreeMap::new(), } } + pub(crate) fn recv(&mut self, mut udp: crate::RawUdp) -> Result<(), Error> { + let mut data = &mut udp.data[ID::len()..]; + let aad = enc::sym::AAD(&[]); + self.cipher_recv.decrypt(aad, &mut data)?; + let mut bytes_parsed = 0; + let mut chunks = Vec::with_capacity(2); + loop { + let chunk = match stream::Chunk::deserialize(&data[bytes_parsed..]) + { + Ok(chunk) => chunk, + Err(e) => { + return Err(e.into()); + } + }; + bytes_parsed = bytes_parsed + chunk.len(); + chunks.push(chunk); + if bytes_parsed == data.len() { + break; + } + } + for chunk in chunks.into_iter() { + let stream = match self.recv_queue.get_mut(&chunk.id) { + Some(stream) => stream, + None => { + ::tracing::debug!("Ignoring chunk for unknown stream::ID"); + continue; + } + }; + stream.recv(chunk); + } + // FIXME: report if we need to return data to the user + Ok(()) + } pub(crate) fn enqueue( &mut self, stream: stream::ID, @@ -262,11 +308,13 @@ impl Connection { Some(stream) => stream, }; stream.enqueue(data); + let instant; let ret; self.next_timer = match self.next_timer { TimerKind::None | TimerKind::Keepalive(_) => { - ret = Enqueue::Immediate; - TimerKind::SendData(::tokio::time::Instant::now()) + instant = ::tokio::time::Instant::now(); + ret = Enqueue::Immediate(instant); + TimerKind::SendData(instant) } TimerKind::SendData(old_timer) => { // There already is some data to be sent @@ -282,7 +330,7 @@ impl Connection { &mut self, raw: &'a mut [u8], ) -> Result<&'a [u8], enc::Error> { - assert!(raw.len() >= 1200, "I should have at least 1200 MTU"); + assert!(raw.len() >= self.mtu, "I should have at least 1200 MTU"); if self.send_queue.len() == 0 { return 
Err(enc::Error::NotEnoughData(0)); } @@ -378,6 +426,17 @@ impl ConnList { ret.connections.resize_with(INITIAL_CAP, || None); ret } + pub fn get_id_mut(&mut self, id: ID) -> Option<&mut Connection> { + let conn_id = match id { + ID::Handshake => { + return None; + } + ID::ID(conn_id) => conn_id, + }; + let id_in_thread: usize = + (conn_id.get() / (self.thread_id.total as u64)) as usize; + (&mut self.connections[id_in_thread]).into() + } pub fn get_mut( &mut self, tracker: UserConnTracker, diff --git a/src/connection/stream/errors.rs b/src/connection/stream/errors.rs index 133d976..07dcbd8 100644 --- a/src/connection/stream/errors.rs +++ b/src/connection/stream/errors.rs @@ -1,10 +1,12 @@ //! Errors while parsing streams - /// Crypto errors #[derive(::thiserror::Error, Debug, Copy, Clone)] pub enum Error { /// Error while parsing key material #[error("Not enough data for stream chunk: {0}")] NotEnoughData(usize), + /// Sequence outside of the window + #[error("Sequence out of the sliding window")] + OutOfWindow, } diff --git a/src/connection/stream/mod.rs b/src/connection/stream/mod.rs index 778c9c5..96971cc 100644 --- a/src/connection/stream/mod.rs +++ b/src/connection/stream/mod.rs @@ -48,6 +48,30 @@ impl ChunkLen { } } +//TODO: make pub? +#[derive(Debug, Copy, Clone)] +pub(crate) struct SequenceStart(pub(crate) Sequence); +impl SequenceStart { + pub(crate) fn plus_u32(&self, other: u32) -> Sequence { + self.0.plus_u32(other) + } + pub(crate) fn offset(&self, seq: Sequence) -> usize { + if self.0 .0 <= seq.0 { + (seq.0 - self.0 .0).0 as usize + } else { + (seq.0 + (Sequence::max().0 - self.0 .0)).0 as usize + } + } +} +// SequenceEnd is INCLUSIVE +#[derive(Debug, Copy, Clone)] +pub(crate) struct SequenceEnd(pub(crate) Sequence); +impl SequenceEnd { + pub(crate) fn plus_u32(&self, other: u32) -> Sequence { + self.0.plus_u32(other) + } +} + /// Sequence number to rebuild the stream correctly #[derive(Debug, Copy, Clone)] pub struct Sequence(pub ::core::num::Wrapping); @@ -56,14 +80,52 @@ impl Sequence { const SEQ_NOFLAG: u32 = 0x3FFFFFFF; /// return a new sequence number, starting at random pub fn new(rand: &Random) -> Self { - let seq: u32 = 0; - rand.fill(&mut seq.to_le_bytes()); + let mut raw_seq: [u8; 4] = [0; 4]; + rand.fill(&mut raw_seq); + let seq = u32::from_le_bytes(raw_seq); Self(::core::num::Wrapping(seq & Self::SEQ_NOFLAG)) } /// Length of the serialized field pub const fn len() -> usize { 4 } + /// Maximum possible sequence + pub const fn max() -> Self { + Self(::core::num::Wrapping(Self::SEQ_NOFLAG)) + } + pub(crate) fn is_between( + &self, + start: SequenceStart, + end: SequenceEnd, + ) -> bool { + if start.0 .0 < end.0 .0 { + start.0 .0 <= self.0 && self.0 <= end.0 .0 + } else { + start.0 .0 <= self.0 || self.0 <= end.0 .0 + } + } + pub(crate) fn remaining_window(&self, end: SequenceEnd) -> u32 { + if self.0 <= end.0 .0 { + (end.0 .0 .0 - self.0 .0) + 1 + } else { + end.0 .0 .0 + 1 + (Self::max().0 - self.0).0 + } + } + pub(crate) fn plus_u32(self, other: u32) -> Self { + Self(::core::num::Wrapping( + (self.0 .0 + other) & Self::SEQ_NOFLAG, + )) + } +} + +impl ::core::ops::Add for Sequence { + type Output = Self; + + fn add(self, other: Self) -> Self { + Self(::core::num::Wrapping( + (self.0 + other.0).0 & Self::SEQ_NOFLAG, + )) + } } /// Chunk of data representing a stream @@ -192,6 +254,11 @@ impl Stream { data: Tracker::new(kind, rand), } } + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { + match &mut self.data { + Tracker::ROB(tracker) => 
tracker.recv(chunk), + } + } } /// Track what has been sent and what has been ACK'd from a stream diff --git a/src/connection/stream/rob.rs b/src/connection/stream/rob.rs deleted file mode 100644 index 5d28f59..0000000 --- a/src/connection/stream/rob.rs +++ /dev/null @@ -1,29 +0,0 @@ -//! Implementation of the Reliable, Ordered, Bytestream transmission model -//! AKA: TCP-like - -use crate::{ - connection::stream::{Chunk, Error, Sequence}, - enc::Random, -}; - -/// Reliable, Ordered, Bytestream stream tracker -/// AKA: TCP-like -#[derive(Debug, Clone)] -pub(crate) struct ReliableOrderedBytestream { - window_start: Sequence, - window_len: usize, - data: Vec, -} - -impl ReliableOrderedBytestream { - pub(crate) fn new(rand: &Random) -> Self { - Self { - window_start: Sequence::new(rand), - window_len: 1048576, // 1MB. should be enough for anybody. (lol) - data: Vec::new(), - } - } - pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { - todo!() - } -} diff --git a/src/connection/stream/rob/mod.rs b/src/connection/stream/rob/mod.rs new file mode 100644 index 0000000..1bfd159 --- /dev/null +++ b/src/connection/stream/rob/mod.rs @@ -0,0 +1,204 @@ +//! Implementation of the Reliable, Ordered, Bytestream transmission model +//! AKA: TCP-like + +use crate::{ + connection::stream::{Chunk, Error, Sequence, SequenceEnd, SequenceStart}, + enc::Random, +}; + +#[cfg(test)] +mod tests; + +/// Reliable, Ordered, Bytestream stream tracker +/// AKA: TCP-like +#[derive(Debug, Clone)] +pub(crate) struct ReliableOrderedBytestream { + pub(crate) window_start: SequenceStart, + window_end: SequenceEnd, + pivot: u32, + data: Vec, + missing: Vec<(Sequence, Sequence)>, +} + +impl ReliableOrderedBytestream { + pub(crate) fn new(rand: &Random) -> Self { + let window_len = 1048576; // 1MB. should be enough for anybody. (lol) + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + missing: [(window_start.0, window_end.0)].to_vec(), + } + } + pub(crate) fn with_window_size(rand: &Random, size: u32) -> Self { + assert!( + size < Sequence::max().0 .0, + "Max window size is {}", + Sequence::max().0 .0 + ); + let window_len = size; // 1MB. should be enough for anybody. 
(lol) + let window_start = SequenceStart(Sequence::new(rand)); + let window_end = SequenceEnd(window_start.0.plus_u32(window_len - 1)); + let mut data = Vec::with_capacity(window_len as usize); + data.resize(data.capacity(), 0); + + Self { + window_start, + window_end, + pivot: window_len, + data, + missing: [(window_start.0, window_end.0)].to_vec(), + } + } + pub(crate) fn window_size(&self) -> u32 { + self.data.len() as u32 + } + pub(crate) fn get(&mut self) -> Vec { + if self.missing.len() == 0 { + let (first, second) = self.data.split_at(self.pivot as usize); + let mut ret = Vec::with_capacity(self.data.len()); + ret.extend_from_slice(first); + ret.extend_from_slice(second); + self.window_start = + SequenceStart(self.window_start.plus_u32(ret.len() as u32)); + self.window_end = + SequenceEnd(self.window_end.plus_u32(ret.len() as u32)); + self.data.clear(); + return ret; + } + let data_len = self.window_start.offset(self.missing[0].0); + let last_missing_idx = self.missing.len() - 1; + let mut last_missing = &mut self.missing[last_missing_idx]; + last_missing.1 = last_missing.1.plus_u32(data_len as u32); + self.window_start = + SequenceStart(self.window_start.plus_u32(data_len as u32)); + self.window_end = + SequenceEnd(self.window_end.plus_u32(data_len as u32)); + + let mut ret = Vec::with_capacity(data_len); + let (first, second) = self.data[..].split_at(self.pivot as usize); + let first_len = ::core::cmp::min(data_len, first.len()); + let second_len = data_len - first_len; + + ret.extend_from_slice(&first[..first_len]); + ret.extend_from_slice(&second[..second_len]); + + self.pivot = + ((self.pivot as usize + data_len) % self.data.len()) as u32; + ret + } + pub(crate) fn recv(&mut self, chunk: Chunk) -> Result<(), Error> { + if !chunk + .sequence + .is_between(self.window_start, self.window_end) + { + return Err(Error::OutOfWindow); + } + // make sure we consider only the bytes inside the sliding window + let maxlen = ::std::cmp::min( + chunk.sequence.remaining_window(self.window_end) as usize, + chunk.data.len(), + ); + if maxlen == 0 { + // or empty chunk, but we don't care + return Err(Error::OutOfWindow); + } + // translate Sequences to offsets in self.data + let data = &chunk.data[..maxlen]; + let offset = self.window_start.offset(chunk.sequence); + let offset_end = offset + chunk.data.len() - 1; + + // Find the chunks we are missing that we can copy, + // and fix the missing tracker + let mut copy_ranges = Vec::new(); + let mut to_delete = Vec::new(); + let mut to_add = Vec::new(); + // note: te included ranges are (INCLUSIVE, INCLUSIVE) + for (idx, el) in self.missing.iter_mut().enumerate() { + let missing_from = self.window_start.offset(el.0); + if missing_from > offset_end { + break; + } + let missing_to = self.window_start.offset(el.1); + if missing_to < offset { + continue; + } + if missing_from >= offset && missing_from <= offset_end { + if missing_to <= offset_end { + // [.....chunk.....] + // [..missing..] + to_delete.push(idx); + copy_ranges.push((missing_from, missing_to)); + } else { + // [....chunk....] + // [...missing...] + copy_ranges.push((missing_from, offset_end)); + el.0 = + el.0.plus_u32(((offset_end - missing_from) + 1) as u32); + } + } else if missing_from < offset { + if missing_to > offset_end { + // [..chunk..] + // [....missing....] 
+ // chunk is in the middle of a missing fragment + to_add.push(( + el.0.plus_u32(((offset_end - missing_from) + 1) as u32), + el.1, + )); + el.1 = el.0.plus_u32(((offset - missing_from) - 1) as u32); + copy_ranges.push((offset, offset_end)); + } else if offset <= missing_to { + // [....chunk....] + // [...missing...] + // chunk + copy_ranges.push((offset, (missing_to - 0))); + el.1 = + el.0.plus_u32(((offset_end - missing_from) - 1) as u32); + } + } + } + self.missing.append(&mut to_add); + self.missing + .sort_by(|(from_a, _), (from_b, _)| from_a.0 .0.cmp(&from_b.0 .0)); + { + let mut deleted = 0; + for idx in to_delete.into_iter() { + self.missing.remove(idx + deleted); + deleted = deleted + 1; + } + } + // copy only the missing data + let (first, second) = self.data[..].split_at_mut(self.pivot as usize); + for (from, to) in copy_ranges.into_iter() { + let to = to + 1; + if from <= first.len() { + let first_from = from; + let first_to = ::core::cmp::min(first.len(), to); + let data_first_from = from - offset; + let data_first_to = first_to - offset; + first[first_from..first_to] + .copy_from_slice(&data[data_first_from..data_first_to]); + + let second_to = to - first_to; + let data_second_to = data_first_to + second_to; + second[..second_to] + .copy_from_slice(&data[data_first_to..data_second_to]); + } else { + let second_from = from - first.len(); + let second_to = to - first.len(); + let data_from = from - offset; + let data_to = to - offset; + second[second_from..second_to] + .copy_from_slice(&data[data_from..data_to]); + } + } + + Ok(()) + } +} diff --git a/src/connection/stream/rob/tests.rs b/src/connection/stream/rob/tests.rs new file mode 100644 index 0000000..20cb508 --- /dev/null +++ b/src/connection/stream/rob/tests.rs @@ -0,0 +1,249 @@ +use crate::{ + connection::stream::{self, rob::*, Chunk}, + enc::Random, +}; + +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_sequential() { + let rand = Random::new(); + let mut rob = ReliableOrderedBytestream::with_window_size(&rand, 1048576); + + let mut data = Vec::with_capacity(1024); + data.resize(data.capacity(), 0); + rand.fill(&mut data[..]); + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..512], + }; + let got = rob.get(); + assert!(&got[..] == &[], "rob: got data?"); + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..512] == &got[..], + "ROB1: DIFF: {:?} {:?}", + &data[..512].len(), + &got[..].len() + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: true, + sequence: start.plus_u32(512), + data: &data[512..], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[512..] 
== &got[..], + "ROB2: DIFF: {:?} {:?}", + &data[512..].len(), + &got[..].len() + ); +} + +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_retransmit() { + let rand = Random::new(); + let max_window: usize = 100; + let mut rob = + ReliableOrderedBytestream::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..60], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(40), + data: &data[40..60], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(80), + data: &data[80..], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..90], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: false, + sequence: start.plus_u32(max_window as u32), + data: &data[max_window..], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: false, + flag_end: true, + sequence: start.plus_u32(90), + data: &data[90..max_window], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..max_window] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..max_window], + &got[..], + ); +} +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_rolling() { + let rand = Random::new(); + let max_window: usize = 100; + let mut rob = + ReliableOrderedBytestream::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..100], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..40] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..40], + &got[..], + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(40), + data: &data[40..], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[40..] 
== &got[..], + "DIFF:\n {:?}\n {:?}", + &data[40..], + &got[..], + ); +} +#[::tracing_test::traced_test] +#[test] +fn test_stream_rob_rolling_second_case() { + let rand = Random::new(); + let max_window: usize = 100; + let mut rob = + ReliableOrderedBytestream::with_window_size(&rand, max_window as u32); + + let mut data = Vec::with_capacity(120); + data.resize(data.capacity(), 0); + for i in 0..data.len() { + data[i] = i as u8; + } + + let start = rob.window_start.0; + + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start, + data: &data[..40], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(50), + data: &data[50..100], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[..40] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[..40], + &got[..], + ); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(40), + data: &data[40..100], + }; + let _ = rob.recv(chunk); + let chunk = Chunk { + id: stream::ID(42), + flag_start: true, + flag_end: false, + sequence: start.plus_u32(100), + data: &data[100..], + }; + let _ = rob.recv(chunk); + let got = rob.get(); + assert!( + &data[40..] == &got[..], + "DIFF:\n {:?}\n {:?}", + &data[40..], + &got[..], + ); +} diff --git a/src/inner/mod.rs b/src/inner/mod.rs index e23d614..6102fde 100644 --- a/src/inner/mod.rs +++ b/src/inner/mod.rs @@ -66,6 +66,32 @@ impl Timers { } } } + pub(crate) fn add( + &mut self, + duration: ::tokio::time::Duration, + work: Work, + ) -> ::tokio::time::Instant { + // the returned time is the key in the map. + // Make sure it is unique. + // + // We can be pretty sure we won't do a lot of stuff + // in a single nanosecond, so if we hit a time that is already present + // just add a nanosecond and retry + let mut time = ::tokio::time::Instant::now() + duration; + let mut work = work; + loop { + if let Some(old_val) = self.times.insert(time, work) { + work = self.times.insert(time, old_val).unwrap(); + time = time + ::std::time::Duration::from_nanos(1); + } else { + break; + } + } + time + } + pub(crate) fn remove(&mut self, time: ::tokio::time::Instant) { + let _ = self.times.remove(&time); + } /// Get all the work from now up until now + SLEEP_RESOLUTION pub(crate) fn get_work(&mut self) -> Vec { let now: ::tokio::time::Instant = ::std::time::Instant::now().into(); diff --git a/src/inner/worker.rs b/src/inner/worker.rs index ba9ec3a..725578d 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -54,7 +54,7 @@ pub(crate) enum Work { DropHandshake(KeyID), Recv(RawUdp), UserSend((UserConnTracker, stream::ID, Vec)), - SendData(UserConnTracker), + SendData((UserConnTracker, ::tokio::time::Instant)), } /// Actual worker implementation. 
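
Note (illustrative, not part of the patch): the ROB tracker above works in a 30-bit sequence space (SEQ_NOFLAG = 0x3FFFFFFF), so both window membership and byte offsets have to account for wrap-around. The following minimal, self-contained sketch shows that kind of arithmetic with plain u32 math; the names SEQ_MAX, in_window and offset_in_window are hypothetical and chosen only for this example.

const SEQ_MAX: u32 = 0x3FFF_FFFF; // sequences use 30 bits; the top 2 bits carry flags

/// true if `seq` lies in the inclusive window [start, end], with wrap-around
fn in_window(seq: u32, start: u32, end: u32) -> bool {
    if start <= end {
        start <= seq && seq <= end
    } else {
        // the window wraps past SEQ_MAX back to 0
        seq >= start || seq <= end
    }
}

/// byte offset of `seq` from the window start, modulo the sequence space
fn offset_in_window(seq: u32, start: u32) -> usize {
    if start <= seq {
        (seq - start) as usize
    } else {
        (seq + (SEQ_MAX - start) + 1) as usize
    }
}

fn main() {
    let start = SEQ_MAX - 10;          // window starts just below the wrap point
    let end = (start + 99) & SEQ_MAX;  // inclusive 100-byte window, wraps to 88
    assert!(in_window(5, start, end)); // sequence 5 has wrapped, still in window
    assert_eq!(offset_in_window(5, start), 16); // 11 bytes to the wrap, then 5 more
}

With the window start placed just below the wrap point, a sequence that has already wrapped still falls inside the window and maps to the expected byte offset.
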
@@ -317,6 +317,8 @@ impl Worker { connection::Role::Client, &self.rand, ); + let dest = UdpClient(addr.as_sockaddr().unwrap()); + conn.send_addr = dest; let auth_recv_id = self.connections.reserve_first(); let service_conn_id = self.connections.reserve_first(); @@ -407,15 +409,13 @@ impl Worker { // send always from the first socket // FIXME: select based on routing table let sender = self.sockets[0].local_addr().unwrap(); - let dest = UdpClient(addr.as_sockaddr().unwrap()); // start the timeout right before sending the packet - hshake.timeout = Some(::tokio::task::spawn_local( - Self::handshake_timeout( - self.queue_timeouts_send.clone(), - client_key_id, - ), - )); + let time_drop = self.work_timers.add( + ::tokio::time::Duration::from_secs(10), + Work::DropHandshake(client_key_id), + ); + hshake.timeout = Some(time_drop); // send packet self.send_packet(raw, dest, UdpServer(sender)).await; @@ -441,10 +441,10 @@ impl Worker { }; use connection::Enqueue; match conn.enqueue(stream, data) { - Enqueue::Immediate => { + Enqueue::Immediate(instant) => { let _ = self .queue_sender - .send(Work::SendData(tracker)) + .send(Work::SendData((tracker, instant))) .await; } Enqueue::TimerWait => {} @@ -455,8 +455,23 @@ impl Worker { } } } - Work::SendData(tracker) => { - let mut raw: Vec = Vec::with_capacity(1280); + Work::SendData((tracker, instant)) => { + // make sure we don't process events before they are + // actually needed. + // This is basically busy waiting with extra steps, + // but we don't want to spawn lots of timers and + // we don't really have a fine-grained sleep that is + // multiplatform + let now = ::tokio::time::Instant::now(); + if instant <= now { + let _ = self + .queue_sender + .send(Work::SendData((tracker, instant))) + .await; + return; + } + + let mut raw: Vec = Vec::with_capacity(1200); raw.resize(raw.capacity(), 0); let conn = match self.connections.get_mut(tracker) { None => return, @@ -479,13 +494,6 @@ impl Worker { } } } - async fn handshake_timeout( - timeout_queue: mpsc::UnboundedSender, - key_id: KeyID, - ) { - ::tokio::time::sleep(::std::time::Duration::from_secs(10)).await; - let _ = timeout_queue.send(Work::DropHandshake(key_id)); - } /// Read and do stuff with the raw udp packet async fn recv(&mut self, mut udp: RawUdp) { if udp.packet.id.is_handshake() { @@ -508,224 +516,237 @@ impl Worker { return; } }; - match action { - handshake::Action::AuthNeeded(authinfo) => { - let req; - if let handshake::Data::DirSync(DirSync::Req(r)) = - authinfo.handshake.data - { - req = r; - } else { - ::tracing::error!("AuthInfo on non DS::Req"); + self.recv_handshake(udp, action).await; + } else { + self.recv_packet(udp); + } + } + /// Receive a non-handshake packet + fn recv_packet(&mut self, udp: RawUdp) { + let conn = match self.connections.get_id_mut(udp.packet.id) { + None => return, + Some(conn) => conn, + }; + if let Err(e) = conn.recv(udp) { + ::tracing::trace!("Conn Recv: {:?}", e.to_string()); + } + } + /// Receive an handshake packet + async fn recv_handshake(&mut self, udp: RawUdp, action: handshake::Action) { + match action { + handshake::Action::AuthNeeded(authinfo) => { + let req; + if let handshake::Data::DirSync(DirSync::Req(r)) = + authinfo.handshake.data + { + req = r; + } else { + ::tracing::error!("AuthInfo on non DS::Req"); + return; + } + let req_data = match req.data { + dirsync::req::State::ClearText(req_data) => req_data, + _ => { + ::tracing::error!("AuthNeeded: expected ClearText"); + assert!(false, "AuthNeeded: unreachable"); return; } - let req_data 
= match req.data { - dirsync::req::State::ClearText(req_data) => req_data, - _ => { - ::tracing::error!("AuthNeeded: expected ClearText"); - assert!(false, "AuthNeeded: unreachable"); - return; - } - }; - // FIXME: This part can take a while, - // we should just spawn it probably - let maybe_auth_check = { - match &self.token_check { - None => { - if req_data.auth.user == auth::USERID_ANONYMOUS - { - Ok(true) - } else { - Ok(false) - } - } - Some(token_check) => { - let tk_check = token_check.lock().await; - tk_check( - req_data.auth.user, - req_data.auth.token, - req_data.auth.service_id, - req_data.auth.domain, - ) - .await + }; + // FIXME: This part can take a while, + // we should just spawn it probably + let maybe_auth_check = { + match &self.token_check { + None => { + if req_data.auth.user == auth::USERID_ANONYMOUS { + Ok(true) + } else { + Ok(false) } } - }; - let is_authenticated = match maybe_auth_check { - Ok(is_authenticated) => is_authenticated, - Err(_) => { - ::tracing::error!("error in token auth"); - // TODO: retry? - return; + Some(token_check) => { + let tk_check = token_check.lock().await; + tk_check( + req_data.auth.user, + req_data.auth.token, + req_data.auth.service_id, + req_data.auth.domain, + ) + .await } - }; - if !is_authenticated { - ::tracing::warn!( - "Wrong authentication for user {:?}", - req_data.auth.user - ); - // TODO: error response + } + }; + let is_authenticated = match maybe_auth_check { + Ok(is_authenticated) => is_authenticated, + Err(_) => { + ::tracing::error!("error in token auth"); + // TODO: retry? return; } - // Client has correctly authenticated - // TODO: contact the service, get the key and - // connection ID - let srv_conn_id = connection::ID::new_rand(&self.rand); - let srv_secret = Secret::new_rand(&self.rand); - let head_len = req.cipher.nonce_len(); - let tag_len = req.cipher.tag_len(); + }; + if !is_authenticated { + ::tracing::warn!( + "Wrong authentication for user {:?}", + req_data.auth.user + ); + // TODO: error response + return; + } + // Client has correctly authenticated + // TODO: contact the service, get the key and + // connection ID + let srv_conn_id = connection::ID::new_rand(&self.rand); + let srv_secret = Secret::new_rand(&self.rand); + let head_len = req.cipher.nonce_len(); + let tag_len = req.cipher.tag_len(); - let mut auth_conn = Connection::new( - authinfo.hkdf, - req.cipher, - connection::Role::Server, + let mut auth_conn = Connection::new( + authinfo.hkdf, + req.cipher, + connection::Role::Server, + &self.rand, + ); + auth_conn.id_send = IDSend(req_data.id); + auth_conn.send_addr = udp.src; + // track connection + let auth_id_recv = self.connections.reserve_first(); + auth_conn.id_recv = auth_id_recv; + + let resp_data = dirsync::resp::Data { + client_nonce: req_data.nonce, + id: auth_conn.id_recv.0, + service_connection_id: srv_conn_id, + service_key: srv_secret, + }; + use crate::enc::sym::AAD; + // no aad for now + let aad = AAD(&mut []); + + let resp = dirsync::resp::Resp { + client_key_id: req_data.client_key_id, + data: dirsync::resp::State::ClearText(resp_data), + }; + let encrypt_from = + connection::ID::len() + resp.encrypted_offset(); + let encrypt_until = + encrypt_from + resp.encrypted_length(head_len, tag_len); + let resp_handshake = Handshake::new(handshake::Data::DirSync( + DirSync::Resp(resp), + )); + let packet = Packet { + id: connection::ID::new_handshake(), + data: packet::Data::Handshake(resp_handshake), + }; + let tot_len = packet.len(head_len, tag_len); + let mut raw_out = 
Vec::::with_capacity(tot_len); + raw_out.resize(tot_len, 0); + packet.serialize(head_len, tag_len, &mut raw_out); + + if let Err(e) = auth_conn + .cipher_send + .encrypt(aad, &mut raw_out[encrypt_from..encrypt_until]) + { + ::tracing::error!("can't encrypt: {:?}", e); + return; + } + self.send_packet(raw_out, udp.src, udp.dst).await; + } + handshake::Action::ClientConnect(cci) => { + self.work_timers.remove(cci.old_timeout); + let ds_resp; + if let handshake::Data::DirSync(DirSync::Resp(resp)) = + cci.handshake.data + { + ds_resp = resp; + } else { + ::tracing::error!("ClientConnect on non DS::Resp"); + return; + } + // track connection + let resp_data; + if let dirsync::resp::State::ClearText(r_data) = ds_resp.data { + resp_data = r_data; + } else { + ::tracing::error!( + "ClientConnect on non DS::Resp::ClearText" + ); + unreachable!(); + } + let auth_id_send = IDSend(resp_data.id); + let mut conn = cci.connection; + conn.id_send = auth_id_send; + let id_recv = conn.id_recv; + let cipher = conn.cipher_recv.kind(); + // track the connection to the authentication server + let track_auth_conn = match self.connections.track(conn) { + Ok(track_auth_conn) => track_auth_conn, + Err(_) => { + ::tracing::error!( + "Could not track new auth srv connection" + ); + self.connections.remove(id_recv); + // FIXME: proper connection closing + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + let authsrv_conn = AuthSrvConn(connection::Conn { + queue: self.queue_sender.clone(), + conn: track_auth_conn, + }); + let mut service_conn = None; + if cci.service_id != auth::SERVICEID_AUTH { + // create and track the connection to the service + // SECURITY: xor with secrets + //FIXME: the Secret should be XORed with the client + // stored secret (if any) + let hkdf = Hkdf::new( + hkdf::Kind::Sha3, + cci.service_id.as_bytes(), + resp_data.service_key, + ); + let mut service_connection = Connection::new( + hkdf, + cipher, + connection::Role::Client, &self.rand, ); - auth_conn.id_send = IDSend(req_data.id); - // track connection - let auth_id_recv = self.connections.reserve_first(); - auth_conn.id_recv = auth_id_recv; - - let resp_data = dirsync::resp::Data { - client_nonce: req_data.nonce, - id: auth_conn.id_recv.0, - service_connection_id: srv_conn_id, - service_key: srv_secret, - }; - use crate::enc::sym::AAD; - // no aad for now - let aad = AAD(&mut []); - - let resp = dirsync::resp::Resp { - client_key_id: req_data.client_key_id, - data: dirsync::resp::State::ClearText(resp_data), - }; - let encrypt_from = - connection::ID::len() + resp.encrypted_offset(); - let encrypt_until = - encrypt_from + resp.encrypted_length(head_len, tag_len); - let resp_handshake = Handshake::new( - handshake::Data::DirSync(DirSync::Resp(resp)), - ); - let packet = Packet { - id: connection::ID::new_handshake(), - data: packet::Data::Handshake(resp_handshake), - }; - let tot_len = packet.len(head_len, tag_len); - let mut raw_out = Vec::::with_capacity(tot_len); - raw_out.resize(tot_len, 0); - packet.serialize(head_len, tag_len, &mut raw_out); - - if let Err(e) = auth_conn - .cipher_send - .encrypt(aad, &mut raw_out[encrypt_from..encrypt_until]) - { - ::tracing::error!("can't encrypt: {:?}", e); - return; - } - self.send_packet(raw_out, udp.src, udp.dst).await; - } - handshake::Action::ClientConnect(cci) => { - let ds_resp; - if let handshake::Data::DirSync(DirSync::Resp(resp)) = - cci.handshake.data - { - ds_resp = resp; - } else { - ::tracing::error!("ClientConnect on non DS::Resp"); - 
return; - } - // track connection - let resp_data; - if let dirsync::resp::State::ClearText(r_data) = - ds_resp.data - { - resp_data = r_data; - } else { - ::tracing::error!( - "ClientConnect on non DS::Resp::ClearText" - ); - unreachable!(); - } - let auth_id_send = IDSend(resp_data.id); - let mut conn = cci.connection; - conn.id_send = auth_id_send; - let id_recv = conn.id_recv; - let cipher = conn.cipher_recv.kind(); - // track the connection to the authentication server - let track_auth_conn = match self.connections.track(conn) { - Ok(track_auth_conn) => track_auth_conn, - Err(_) => { - ::tracing::error!( - "Could not track new auth srv connection" - ); - self.connections.remove(id_recv); - // FIXME: proper connection closing - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } - }; - let authsrv_conn = AuthSrvConn(connection::Conn { + service_connection.id_recv = cci.service_connection_id; + service_connection.id_send = + IDSend(resp_data.service_connection_id); + let track_serv_conn = + match self.connections.track(service_connection) { + Ok(track_serv_conn) => track_serv_conn, + Err(_) => { + ::tracing::error!( + "Could not track new service connection" + ); + self.connections + .remove(cci.service_connection_id); + // FIXME: proper connection closing + // FIXME: drop auth srv connection if we just + // established it + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + service_conn = Some(ServiceConn(connection::Conn { queue: self.queue_sender.clone(), - conn: track_auth_conn, - }); - let mut service_conn = None; - if cci.service_id != auth::SERVICEID_AUTH { - // create and track the connection to the service - // SECURITY: xor with secrets - //FIXME: the Secret should be XORed with the client - // stored secret (if any) - let hkdf = Hkdf::new( - hkdf::Kind::Sha3, - cci.service_id.as_bytes(), - resp_data.service_key, - ); - let mut service_connection = Connection::new( - hkdf, - cipher, - connection::Role::Client, - &self.rand, - ); - service_connection.id_recv = cci.service_connection_id; - service_connection.id_send = - IDSend(resp_data.service_connection_id); - let track_serv_conn = - match self.connections.track(service_connection) { - Ok(track_serv_conn) => track_serv_conn, - Err(_) => { - ::tracing::error!( - "Could not track new service connection" - ); - self.connections - .remove(cci.service_connection_id); - // FIXME: proper connection closing - // FIXME: drop auth srv connection if we just - // established it - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking - .into(), - )); - return; - } - }; - service_conn = Some(ServiceConn(connection::Conn { - queue: self.queue_sender.clone(), - conn: track_serv_conn, - })); - } - let _ = - cci.answer.send(Ok(handshake::tracker::ConnectOk { - auth_key_id: cci.srv_key_id, - auth_id_send, - authsrv_conn, - service_conn, - })); + conn: track_serv_conn, + })); } - handshake::Action::Nothing => {} - }; - } + let _ = cci.answer.send(Ok(handshake::tracker::ConnectOk { + auth_key_id: cci.srv_key_id, + auth_id_send, + authsrv_conn, + service_conn, + })); + } + handshake::Action::Nothing => {} + }; } async fn send_packet( &self,