User conn tracking, enqueue data, timers
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
2fe91d5dd3
commit
9c67210e3e
|
@ -5,7 +5,11 @@ pub mod packet;
|
|||
pub mod socket;
|
||||
pub mod stream;
|
||||
|
||||
use ::std::{collections::HashMap, rc::Rc, vec::Vec};
|
||||
use ::core::num::Wrapping;
|
||||
use ::std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
vec::Vec,
|
||||
};
|
||||
|
||||
pub use crate::connection::{handshake::Handshake, packet::Packet};
|
||||
|
||||
|
@ -125,12 +129,12 @@ impl ProtocolVersion {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
|
||||
pub(crate) struct UserConnTracker(usize);
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
|
||||
pub(crate) struct UserConnTracker(Wrapping<usize>);
|
||||
impl UserConnTracker {
|
||||
fn advance(&mut self) -> Self {
|
||||
let old = self.0;
|
||||
self.0 = self.0 + 1;
|
||||
self.0 = self.0 + Wrapping(1);
|
||||
UserConnTracker(old)
|
||||
}
|
||||
}
|
||||
|
@ -151,8 +155,13 @@ pub struct Conn {
|
|||
|
||||
impl Conn {
|
||||
/// Queue some data to be sent in this connection
|
||||
pub fn send(&mut self, stream: stream::ID, _data: Vec<u8>) {
|
||||
todo!()
|
||||
// TODO: send_and_wait, that wait for recipient ACK
|
||||
pub async fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
|
||||
use crate::inner::worker::Work;
|
||||
let _ = self
|
||||
.queue
|
||||
.send(Work::UserSend((self.conn, stream, data)))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -160,15 +169,17 @@ impl Conn {
|
|||
#[derive(Debug)]
|
||||
pub(crate) struct Connection {
|
||||
/// Receiving Conn ID
|
||||
pub id_recv: IDRecv,
|
||||
pub(crate) id_recv: IDRecv,
|
||||
/// Sending Conn ID
|
||||
pub id_send: IDSend,
|
||||
pub(crate) id_send: IDSend,
|
||||
/// The main hkdf used for all secrets in this connection
|
||||
pub hkdf: Hkdf,
|
||||
hkdf: Hkdf,
|
||||
/// Cipher for decrypting data
|
||||
pub cipher_recv: CipherRecv,
|
||||
pub(crate) cipher_recv: CipherRecv,
|
||||
/// Cipher for encrypting data
|
||||
pub cipher_send: CipherSend,
|
||||
pub(crate) cipher_send: CipherSend,
|
||||
/// send queue for each Stream
|
||||
send_queue: BTreeMap<stream::ID, Vec<Vec<u8>>>,
|
||||
}
|
||||
|
||||
/// Role: track the connection direction
|
||||
|
@ -210,14 +221,22 @@ impl Connection {
|
|||
hkdf,
|
||||
cipher_recv,
|
||||
cipher_send,
|
||||
send_queue: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
pub(crate) fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
|
||||
let stream = match self.send_queue.get_mut(&stream) {
|
||||
None => return,
|
||||
Some(stream) => stream,
|
||||
};
|
||||
stream.push(data);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ConnList {
|
||||
thread_id: ThreadTracker,
|
||||
connections: Vec<Option<Rc<Connection>>>,
|
||||
user_tracker: HashMap<UserConnTracker, usize>,
|
||||
connections: Vec<Option<Connection>>,
|
||||
user_tracker: BTreeMap<UserConnTracker, usize>,
|
||||
last_tracked: UserConnTracker,
|
||||
/// Bitmap to track which connection ids are used or free
|
||||
ids_used: Vec<::bitmaps::Bitmap<1024>>,
|
||||
|
@ -234,13 +253,27 @@ impl ConnList {
|
|||
let mut ret = Self {
|
||||
thread_id,
|
||||
connections: Vec::with_capacity(INITIAL_CAP),
|
||||
user_tracker: HashMap::with_capacity(INITIAL_CAP),
|
||||
last_tracked: UserConnTracker(0),
|
||||
user_tracker: BTreeMap::new(),
|
||||
last_tracked: UserConnTracker(Wrapping(0)),
|
||||
ids_used: vec![bitmap_id],
|
||||
};
|
||||
ret.connections.resize_with(INITIAL_CAP, || None);
|
||||
ret
|
||||
}
|
||||
pub fn get_mut(
|
||||
&mut self,
|
||||
tracker: UserConnTracker,
|
||||
) -> Option<&mut Connection> {
|
||||
let idx = if let Some(idx) = self.user_tracker.get(&tracker) {
|
||||
*idx
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
match &mut self.connections[idx] {
|
||||
None => None,
|
||||
Some(conn) => Some(conn),
|
||||
}
|
||||
}
|
||||
pub fn len(&self) -> usize {
|
||||
let mut total: usize = 0;
|
||||
for bitmap in self.ids_used.iter() {
|
||||
|
@ -293,7 +326,7 @@ impl ConnList {
|
|||
/// NOTE: does NOT check if the connection has been previously reserved!
|
||||
pub(crate) fn track(
|
||||
&mut self,
|
||||
conn: Rc<Connection>,
|
||||
conn: Connection,
|
||||
) -> Result<UserConnTracker, ()> {
|
||||
let conn_id = match conn.id_recv {
|
||||
IDRecv(ID::Handshake) => {
|
||||
|
@ -304,8 +337,14 @@ impl ConnList {
|
|||
let id_in_thread: usize =
|
||||
(conn_id.get() / (self.thread_id.total as u64)) as usize;
|
||||
self.connections[id_in_thread] = Some(conn);
|
||||
let tracked = self.last_tracked.advance();
|
||||
let mut tracked;
|
||||
loop {
|
||||
tracked = self.last_tracked.advance();
|
||||
if self.user_tracker.get(&tracked).is_none() {
|
||||
let _ = self.user_tracker.insert(tracked, id_in_thread);
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(tracked)
|
||||
}
|
||||
pub(crate) fn remove(&mut self, id: IDRecv) {
|
||||
|
|
|
@ -19,7 +19,7 @@ pub enum Kind {
|
|||
}
|
||||
|
||||
/// Id of the stream
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub struct ID(pub u16);
|
||||
|
||||
impl ID {
|
||||
|
|
|
@ -4,6 +4,10 @@
|
|||
|
||||
pub(crate) mod worker;
|
||||
|
||||
use crate::inner::worker::Work;
|
||||
use ::std::{collections::BTreeMap, vec::Vec};
|
||||
use ::tokio::time::Instant;
|
||||
|
||||
/// Track the total number of threads and our index
|
||||
/// 65K cpus should be enough for anybody
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
|
@ -12,3 +16,83 @@ pub(crate) struct ThreadTracker {
|
|||
/// Note: starts from 1
|
||||
pub id: u16,
|
||||
}
|
||||
|
||||
pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration =
|
||||
if cfg!(linux) || cfg!(macos) {
|
||||
::std::time::Duration::from_millis(1)
|
||||
} else {
|
||||
// windows
|
||||
::std::time::Duration::from_millis(16)
|
||||
};
|
||||
|
||||
pub(crate) async fn set_minimum_sleep_resolution() {
|
||||
let nanosleep = ::std::time::Duration::from_nanos(1);
|
||||
let mut tests: usize = 3;
|
||||
|
||||
while tests > 0 {
|
||||
let pre_sleep = ::std::time::Instant::now();
|
||||
::tokio::time::sleep(nanosleep).await;
|
||||
let post_sleep = ::std::time::Instant::now();
|
||||
let slept_for = post_sleep - pre_sleep;
|
||||
#[allow(unsafe_code)]
|
||||
unsafe {
|
||||
if slept_for < SLEEP_RESOLUTION {
|
||||
SLEEP_RESOLUTION = slept_for;
|
||||
}
|
||||
}
|
||||
tests = tests - 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Sleeping has a higher resolution that we would like for packet pacing.
|
||||
/// So we sleep for however log we need, then chunk up all the work here
|
||||
/// we will end up chunking the work in SLEEP_RESOLUTION, then we will busy wait
|
||||
/// for more precise timing
|
||||
pub(crate) struct Timers {
|
||||
times: BTreeMap<Instant, Work>,
|
||||
}
|
||||
|
||||
impl Timers {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
times: BTreeMap::new(),
|
||||
}
|
||||
}
|
||||
pub(crate) fn get_next(&self) -> ::tokio::time::Sleep {
|
||||
match self.times.keys().next() {
|
||||
Some(entry) => ::tokio::time::sleep_until((*entry).into()),
|
||||
None => {
|
||||
::tokio::time::sleep(::std::time::Duration::from_secs(3600))
|
||||
}
|
||||
}
|
||||
}
|
||||
/// Get all the work from now up until now + SLEEP_RESOLUTION
|
||||
pub(crate) fn get_work(&mut self) -> Vec<Work> {
|
||||
let now: ::tokio::time::Instant = ::std::time::Instant::now().into();
|
||||
let mut ret = Vec::with_capacity(4);
|
||||
let mut count_rm = 0;
|
||||
#[allow(unsafe_code)]
|
||||
let next_instant = unsafe { now + SLEEP_RESOLUTION };
|
||||
let mut iter = self.times.iter_mut().peekable();
|
||||
loop {
|
||||
match iter.peek() {
|
||||
None => break,
|
||||
Some(next) => {
|
||||
if *next.0 > next_instant {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0));
|
||||
let mut entry = iter.next().unwrap();
|
||||
::core::mem::swap(&mut work, &mut entry.1);
|
||||
ret.push(work);
|
||||
count_rm = count_rm + 1;
|
||||
}
|
||||
while count_rm > 0 {
|
||||
self.times.pop_first();
|
||||
count_rm = count_rm - 1;
|
||||
}
|
||||
ret
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,8 @@ use crate::{
|
|||
},
|
||||
packet::{self, Packet},
|
||||
socket::{UdpClient, UdpServer},
|
||||
AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
|
||||
stream, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
|
||||
UserConnTracker,
|
||||
},
|
||||
dnssec,
|
||||
enc::{
|
||||
|
@ -51,6 +52,7 @@ pub(crate) enum Work {
|
|||
Connect(ConnectInfo),
|
||||
DropHandshake(KeyID),
|
||||
Recv(RawUdp),
|
||||
UserSend((UserConnTracker, stream::ID, Vec<u8>)),
|
||||
}
|
||||
|
||||
/// Actual worker implementation.
|
||||
|
@ -70,6 +72,7 @@ pub struct Worker {
|
|||
thread_channels: Vec<::async_channel::Sender<Work>>,
|
||||
connections: ConnList,
|
||||
handshakes: handshake::Tracker,
|
||||
work_timers: super::Timers,
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
|
@ -126,12 +129,15 @@ impl Worker {
|
|||
thread_channels: Vec::new(),
|
||||
connections: ConnList::new(thread_id),
|
||||
handshakes,
|
||||
work_timers: super::Timers::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Continuously loop and process work as needed
|
||||
pub async fn work_loop(&mut self) {
|
||||
'mainloop: loop {
|
||||
let next_timer = self.work_timers.get_next();
|
||||
::tokio::pin!(next_timer);
|
||||
let work = ::tokio::select! {
|
||||
tell_stopped = self.stop_working.recv() => {
|
||||
if let Ok(stop_ch) = tell_stopped {
|
||||
|
@ -140,6 +146,13 @@ impl Worker {
|
|||
}
|
||||
break;
|
||||
}
|
||||
() = &mut next_timer => {
|
||||
let work_list = self.work_timers.get_work();
|
||||
for w in work_list.into_iter() {
|
||||
let _ = self.queue_sender.send(w).await;
|
||||
}
|
||||
continue 'mainloop;
|
||||
}
|
||||
maybe_timeout = self.queue.recv() => {
|
||||
match maybe_timeout {
|
||||
Ok(work) => work,
|
||||
|
@ -419,6 +432,13 @@ impl Worker {
|
|||
Work::Recv(pkt) => {
|
||||
self.recv(pkt).await;
|
||||
}
|
||||
Work::UserSend((tracker, stream, data)) => {
|
||||
let conn = match self.connections.get_mut(tracker) {
|
||||
None => return,
|
||||
Some(conn) => conn,
|
||||
};
|
||||
conn.send(stream, data);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -596,10 +616,9 @@ impl Worker {
|
|||
let id_recv = conn.id_recv;
|
||||
let cipher = conn.cipher_recv.kind();
|
||||
// track the connection to the authentication server
|
||||
let track_auth_conn =
|
||||
match self.connections.track(conn.into()) {
|
||||
let track_auth_conn = match self.connections.track(conn) {
|
||||
Ok(track_auth_conn) => track_auth_conn,
|
||||
Err(e) => {
|
||||
Err(_) => {
|
||||
::tracing::error!(
|
||||
"Could not track new auth srv connection"
|
||||
);
|
||||
|
@ -635,12 +654,10 @@ impl Worker {
|
|||
service_connection.id_recv = cci.service_connection_id;
|
||||
service_connection.id_send =
|
||||
IDSend(resp_data.service_connection_id);
|
||||
let track_serv_conn = match self
|
||||
.connections
|
||||
.track(service_connection.into())
|
||||
{
|
||||
let track_serv_conn =
|
||||
match self.connections.track(service_connection) {
|
||||
Ok(track_serv_conn) => track_serv_conn,
|
||||
Err(e) => {
|
||||
Err(_) => {
|
||||
::tracing::error!(
|
||||
"Could not track new service connection"
|
||||
);
|
||||
|
@ -650,7 +667,8 @@ impl Worker {
|
|||
// FIXME: drop auth srv connection if we just
|
||||
// established it
|
||||
let _ = cci.answer.send(Err(
|
||||
handshake::Error::InternalTracking.into(),
|
||||
handshake::Error::InternalTracking
|
||||
.into(),
|
||||
));
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -176,6 +176,7 @@ impl Fenrir {
|
|||
config: &Config,
|
||||
tokio_rt: Arc<::tokio::runtime::Runtime>,
|
||||
) -> Result<Self, Error> {
|
||||
inner::set_minimum_sleep_resolution().await;
|
||||
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
||||
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
||||
// bind sockets early so we can change "port 0" (aka: random)
|
||||
|
@ -214,6 +215,7 @@ impl Fenrir {
|
|||
pub async fn with_workers(
|
||||
config: &Config,
|
||||
) -> Result<(Self, Vec<Worker>), Error> {
|
||||
inner::set_minimum_sleep_resolution().await;
|
||||
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
|
||||
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
||||
// bind sockets early so we can change "port 0" (aka: random)
|
||||
|
|
Loading…
Reference in New Issue