User conn tracking, enqueue data, timers
Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
2fe91d5dd3
commit
9c67210e3e
|
@ -5,7 +5,11 @@ pub mod packet;
|
||||||
pub mod socket;
|
pub mod socket;
|
||||||
pub mod stream;
|
pub mod stream;
|
||||||
|
|
||||||
use ::std::{collections::HashMap, rc::Rc, vec::Vec};
|
use ::core::num::Wrapping;
|
||||||
|
use ::std::{
|
||||||
|
collections::{BTreeMap, HashMap},
|
||||||
|
vec::Vec,
|
||||||
|
};
|
||||||
|
|
||||||
pub use crate::connection::{handshake::Handshake, packet::Packet};
|
pub use crate::connection::{handshake::Handshake, packet::Packet};
|
||||||
|
|
||||||
|
@ -125,12 +129,12 @@ impl ProtocolVersion {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
|
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
|
||||||
pub(crate) struct UserConnTracker(usize);
|
pub(crate) struct UserConnTracker(Wrapping<usize>);
|
||||||
impl UserConnTracker {
|
impl UserConnTracker {
|
||||||
fn advance(&mut self) -> Self {
|
fn advance(&mut self) -> Self {
|
||||||
let old = self.0;
|
let old = self.0;
|
||||||
self.0 = self.0 + 1;
|
self.0 = self.0 + Wrapping(1);
|
||||||
UserConnTracker(old)
|
UserConnTracker(old)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -151,8 +155,13 @@ pub struct Conn {
|
||||||
|
|
||||||
impl Conn {
|
impl Conn {
|
||||||
/// Queue some data to be sent in this connection
|
/// Queue some data to be sent in this connection
|
||||||
pub fn send(&mut self, stream: stream::ID, _data: Vec<u8>) {
|
// TODO: send_and_wait, that wait for recipient ACK
|
||||||
todo!()
|
pub async fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
|
||||||
|
use crate::inner::worker::Work;
|
||||||
|
let _ = self
|
||||||
|
.queue
|
||||||
|
.send(Work::UserSend((self.conn, stream, data)))
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,15 +169,17 @@ impl Conn {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct Connection {
|
pub(crate) struct Connection {
|
||||||
/// Receiving Conn ID
|
/// Receiving Conn ID
|
||||||
pub id_recv: IDRecv,
|
pub(crate) id_recv: IDRecv,
|
||||||
/// Sending Conn ID
|
/// Sending Conn ID
|
||||||
pub id_send: IDSend,
|
pub(crate) id_send: IDSend,
|
||||||
/// The main hkdf used for all secrets in this connection
|
/// The main hkdf used for all secrets in this connection
|
||||||
pub hkdf: Hkdf,
|
hkdf: Hkdf,
|
||||||
/// Cipher for decrypting data
|
/// Cipher for decrypting data
|
||||||
pub cipher_recv: CipherRecv,
|
pub(crate) cipher_recv: CipherRecv,
|
||||||
/// Cipher for encrypting data
|
/// Cipher for encrypting data
|
||||||
pub cipher_send: CipherSend,
|
pub(crate) cipher_send: CipherSend,
|
||||||
|
/// send queue for each Stream
|
||||||
|
send_queue: BTreeMap<stream::ID, Vec<Vec<u8>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Role: track the connection direction
|
/// Role: track the connection direction
|
||||||
|
@ -210,14 +221,22 @@ impl Connection {
|
||||||
hkdf,
|
hkdf,
|
||||||
cipher_recv,
|
cipher_recv,
|
||||||
cipher_send,
|
cipher_send,
|
||||||
|
send_queue: BTreeMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub(crate) fn send(&mut self, stream: stream::ID, data: Vec<u8>) {
|
||||||
|
let stream = match self.send_queue.get_mut(&stream) {
|
||||||
|
None => return,
|
||||||
|
Some(stream) => stream,
|
||||||
|
};
|
||||||
|
stream.push(data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct ConnList {
|
pub(crate) struct ConnList {
|
||||||
thread_id: ThreadTracker,
|
thread_id: ThreadTracker,
|
||||||
connections: Vec<Option<Rc<Connection>>>,
|
connections: Vec<Option<Connection>>,
|
||||||
user_tracker: HashMap<UserConnTracker, usize>,
|
user_tracker: BTreeMap<UserConnTracker, usize>,
|
||||||
last_tracked: UserConnTracker,
|
last_tracked: UserConnTracker,
|
||||||
/// Bitmap to track which connection ids are used or free
|
/// Bitmap to track which connection ids are used or free
|
||||||
ids_used: Vec<::bitmaps::Bitmap<1024>>,
|
ids_used: Vec<::bitmaps::Bitmap<1024>>,
|
||||||
|
@ -234,13 +253,27 @@ impl ConnList {
|
||||||
let mut ret = Self {
|
let mut ret = Self {
|
||||||
thread_id,
|
thread_id,
|
||||||
connections: Vec::with_capacity(INITIAL_CAP),
|
connections: Vec::with_capacity(INITIAL_CAP),
|
||||||
user_tracker: HashMap::with_capacity(INITIAL_CAP),
|
user_tracker: BTreeMap::new(),
|
||||||
last_tracked: UserConnTracker(0),
|
last_tracked: UserConnTracker(Wrapping(0)),
|
||||||
ids_used: vec![bitmap_id],
|
ids_used: vec![bitmap_id],
|
||||||
};
|
};
|
||||||
ret.connections.resize_with(INITIAL_CAP, || None);
|
ret.connections.resize_with(INITIAL_CAP, || None);
|
||||||
ret
|
ret
|
||||||
}
|
}
|
||||||
|
pub fn get_mut(
|
||||||
|
&mut self,
|
||||||
|
tracker: UserConnTracker,
|
||||||
|
) -> Option<&mut Connection> {
|
||||||
|
let idx = if let Some(idx) = self.user_tracker.get(&tracker) {
|
||||||
|
*idx
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
match &mut self.connections[idx] {
|
||||||
|
None => None,
|
||||||
|
Some(conn) => Some(conn),
|
||||||
|
}
|
||||||
|
}
|
||||||
pub fn len(&self) -> usize {
|
pub fn len(&self) -> usize {
|
||||||
let mut total: usize = 0;
|
let mut total: usize = 0;
|
||||||
for bitmap in self.ids_used.iter() {
|
for bitmap in self.ids_used.iter() {
|
||||||
|
@ -293,7 +326,7 @@ impl ConnList {
|
||||||
/// NOTE: does NOT check if the connection has been previously reserved!
|
/// NOTE: does NOT check if the connection has been previously reserved!
|
||||||
pub(crate) fn track(
|
pub(crate) fn track(
|
||||||
&mut self,
|
&mut self,
|
||||||
conn: Rc<Connection>,
|
conn: Connection,
|
||||||
) -> Result<UserConnTracker, ()> {
|
) -> Result<UserConnTracker, ()> {
|
||||||
let conn_id = match conn.id_recv {
|
let conn_id = match conn.id_recv {
|
||||||
IDRecv(ID::Handshake) => {
|
IDRecv(ID::Handshake) => {
|
||||||
|
@ -304,8 +337,14 @@ impl ConnList {
|
||||||
let id_in_thread: usize =
|
let id_in_thread: usize =
|
||||||
(conn_id.get() / (self.thread_id.total as u64)) as usize;
|
(conn_id.get() / (self.thread_id.total as u64)) as usize;
|
||||||
self.connections[id_in_thread] = Some(conn);
|
self.connections[id_in_thread] = Some(conn);
|
||||||
let tracked = self.last_tracked.advance();
|
let mut tracked;
|
||||||
let _ = self.user_tracker.insert(tracked, id_in_thread);
|
loop {
|
||||||
|
tracked = self.last_tracked.advance();
|
||||||
|
if self.user_tracker.get(&tracked).is_none() {
|
||||||
|
let _ = self.user_tracker.insert(tracked, id_in_thread);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(tracked)
|
Ok(tracked)
|
||||||
}
|
}
|
||||||
pub(crate) fn remove(&mut self, id: IDRecv) {
|
pub(crate) fn remove(&mut self, id: IDRecv) {
|
||||||
|
|
|
@ -19,7 +19,7 @@ pub enum Kind {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Id of the stream
|
/// Id of the stream
|
||||||
#[derive(Debug, Copy, Clone)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
pub struct ID(pub u16);
|
pub struct ID(pub u16);
|
||||||
|
|
||||||
impl ID {
|
impl ID {
|
||||||
|
|
|
@ -4,6 +4,10 @@
|
||||||
|
|
||||||
pub(crate) mod worker;
|
pub(crate) mod worker;
|
||||||
|
|
||||||
|
use crate::inner::worker::Work;
|
||||||
|
use ::std::{collections::BTreeMap, vec::Vec};
|
||||||
|
use ::tokio::time::Instant;
|
||||||
|
|
||||||
/// Track the total number of threads and our index
|
/// Track the total number of threads and our index
|
||||||
/// 65K cpus should be enough for anybody
|
/// 65K cpus should be enough for anybody
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
@ -12,3 +16,83 @@ pub(crate) struct ThreadTracker {
|
||||||
/// Note: starts from 1
|
/// Note: starts from 1
|
||||||
pub id: u16,
|
pub id: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) static mut SLEEP_RESOLUTION: ::std::time::Duration =
|
||||||
|
if cfg!(linux) || cfg!(macos) {
|
||||||
|
::std::time::Duration::from_millis(1)
|
||||||
|
} else {
|
||||||
|
// windows
|
||||||
|
::std::time::Duration::from_millis(16)
|
||||||
|
};
|
||||||
|
|
||||||
|
pub(crate) async fn set_minimum_sleep_resolution() {
|
||||||
|
let nanosleep = ::std::time::Duration::from_nanos(1);
|
||||||
|
let mut tests: usize = 3;
|
||||||
|
|
||||||
|
while tests > 0 {
|
||||||
|
let pre_sleep = ::std::time::Instant::now();
|
||||||
|
::tokio::time::sleep(nanosleep).await;
|
||||||
|
let post_sleep = ::std::time::Instant::now();
|
||||||
|
let slept_for = post_sleep - pre_sleep;
|
||||||
|
#[allow(unsafe_code)]
|
||||||
|
unsafe {
|
||||||
|
if slept_for < SLEEP_RESOLUTION {
|
||||||
|
SLEEP_RESOLUTION = slept_for;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tests = tests - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sleeping has a higher resolution that we would like for packet pacing.
|
||||||
|
/// So we sleep for however log we need, then chunk up all the work here
|
||||||
|
/// we will end up chunking the work in SLEEP_RESOLUTION, then we will busy wait
|
||||||
|
/// for more precise timing
|
||||||
|
pub(crate) struct Timers {
|
||||||
|
times: BTreeMap<Instant, Work>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Timers {
|
||||||
|
pub(crate) fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
times: BTreeMap::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub(crate) fn get_next(&self) -> ::tokio::time::Sleep {
|
||||||
|
match self.times.keys().next() {
|
||||||
|
Some(entry) => ::tokio::time::sleep_until((*entry).into()),
|
||||||
|
None => {
|
||||||
|
::tokio::time::sleep(::std::time::Duration::from_secs(3600))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Get all the work from now up until now + SLEEP_RESOLUTION
|
||||||
|
pub(crate) fn get_work(&mut self) -> Vec<Work> {
|
||||||
|
let now: ::tokio::time::Instant = ::std::time::Instant::now().into();
|
||||||
|
let mut ret = Vec::with_capacity(4);
|
||||||
|
let mut count_rm = 0;
|
||||||
|
#[allow(unsafe_code)]
|
||||||
|
let next_instant = unsafe { now + SLEEP_RESOLUTION };
|
||||||
|
let mut iter = self.times.iter_mut().peekable();
|
||||||
|
loop {
|
||||||
|
match iter.peek() {
|
||||||
|
None => break,
|
||||||
|
Some(next) => {
|
||||||
|
if *next.0 > next_instant {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut work = Work::DropHandshake(crate::enc::asym::KeyID(0));
|
||||||
|
let mut entry = iter.next().unwrap();
|
||||||
|
::core::mem::swap(&mut work, &mut entry.1);
|
||||||
|
ret.push(work);
|
||||||
|
count_rm = count_rm + 1;
|
||||||
|
}
|
||||||
|
while count_rm > 0 {
|
||||||
|
self.times.pop_first();
|
||||||
|
count_rm = count_rm - 1;
|
||||||
|
}
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,7 +11,8 @@ use crate::{
|
||||||
},
|
},
|
||||||
packet::{self, Packet},
|
packet::{self, Packet},
|
||||||
socket::{UdpClient, UdpServer},
|
socket::{UdpClient, UdpServer},
|
||||||
AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
|
stream, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
|
||||||
|
UserConnTracker,
|
||||||
},
|
},
|
||||||
dnssec,
|
dnssec,
|
||||||
enc::{
|
enc::{
|
||||||
|
@ -51,6 +52,7 @@ pub(crate) enum Work {
|
||||||
Connect(ConnectInfo),
|
Connect(ConnectInfo),
|
||||||
DropHandshake(KeyID),
|
DropHandshake(KeyID),
|
||||||
Recv(RawUdp),
|
Recv(RawUdp),
|
||||||
|
UserSend((UserConnTracker, stream::ID, Vec<u8>)),
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Actual worker implementation.
|
/// Actual worker implementation.
|
||||||
|
@ -70,6 +72,7 @@ pub struct Worker {
|
||||||
thread_channels: Vec<::async_channel::Sender<Work>>,
|
thread_channels: Vec<::async_channel::Sender<Work>>,
|
||||||
connections: ConnList,
|
connections: ConnList,
|
||||||
handshakes: handshake::Tracker,
|
handshakes: handshake::Tracker,
|
||||||
|
work_timers: super::Timers,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(unsafe_code)]
|
#[allow(unsafe_code)]
|
||||||
|
@ -126,12 +129,15 @@ impl Worker {
|
||||||
thread_channels: Vec::new(),
|
thread_channels: Vec::new(),
|
||||||
connections: ConnList::new(thread_id),
|
connections: ConnList::new(thread_id),
|
||||||
handshakes,
|
handshakes,
|
||||||
|
work_timers: super::Timers::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Continuously loop and process work as needed
|
/// Continuously loop and process work as needed
|
||||||
pub async fn work_loop(&mut self) {
|
pub async fn work_loop(&mut self) {
|
||||||
'mainloop: loop {
|
'mainloop: loop {
|
||||||
|
let next_timer = self.work_timers.get_next();
|
||||||
|
::tokio::pin!(next_timer);
|
||||||
let work = ::tokio::select! {
|
let work = ::tokio::select! {
|
||||||
tell_stopped = self.stop_working.recv() => {
|
tell_stopped = self.stop_working.recv() => {
|
||||||
if let Ok(stop_ch) = tell_stopped {
|
if let Ok(stop_ch) = tell_stopped {
|
||||||
|
@ -140,6 +146,13 @@ impl Worker {
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
() = &mut next_timer => {
|
||||||
|
let work_list = self.work_timers.get_work();
|
||||||
|
for w in work_list.into_iter() {
|
||||||
|
let _ = self.queue_sender.send(w).await;
|
||||||
|
}
|
||||||
|
continue 'mainloop;
|
||||||
|
}
|
||||||
maybe_timeout = self.queue.recv() => {
|
maybe_timeout = self.queue.recv() => {
|
||||||
match maybe_timeout {
|
match maybe_timeout {
|
||||||
Ok(work) => work,
|
Ok(work) => work,
|
||||||
|
@ -419,6 +432,13 @@ impl Worker {
|
||||||
Work::Recv(pkt) => {
|
Work::Recv(pkt) => {
|
||||||
self.recv(pkt).await;
|
self.recv(pkt).await;
|
||||||
}
|
}
|
||||||
|
Work::UserSend((tracker, stream, data)) => {
|
||||||
|
let conn = match self.connections.get_mut(tracker) {
|
||||||
|
None => return,
|
||||||
|
Some(conn) => conn,
|
||||||
|
};
|
||||||
|
conn.send(stream, data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -596,21 +616,20 @@ impl Worker {
|
||||||
let id_recv = conn.id_recv;
|
let id_recv = conn.id_recv;
|
||||||
let cipher = conn.cipher_recv.kind();
|
let cipher = conn.cipher_recv.kind();
|
||||||
// track the connection to the authentication server
|
// track the connection to the authentication server
|
||||||
let track_auth_conn =
|
let track_auth_conn = match self.connections.track(conn) {
|
||||||
match self.connections.track(conn.into()) {
|
Ok(track_auth_conn) => track_auth_conn,
|
||||||
Ok(track_auth_conn) => track_auth_conn,
|
Err(_) => {
|
||||||
Err(e) => {
|
::tracing::error!(
|
||||||
::tracing::error!(
|
"Could not track new auth srv connection"
|
||||||
"Could not track new auth srv connection"
|
);
|
||||||
);
|
self.connections.remove(id_recv);
|
||||||
self.connections.remove(id_recv);
|
// FIXME: proper connection closing
|
||||||
// FIXME: proper connection closing
|
let _ = cci.answer.send(Err(
|
||||||
let _ = cci.answer.send(Err(
|
handshake::Error::InternalTracking.into(),
|
||||||
handshake::Error::InternalTracking.into(),
|
));
|
||||||
));
|
return;
|
||||||
return;
|
}
|
||||||
}
|
};
|
||||||
};
|
|
||||||
let authsrv_conn = AuthSrvConn(connection::Conn {
|
let authsrv_conn = AuthSrvConn(connection::Conn {
|
||||||
queue: self.queue_sender.clone(),
|
queue: self.queue_sender.clone(),
|
||||||
conn: track_auth_conn,
|
conn: track_auth_conn,
|
||||||
|
@ -635,26 +654,25 @@ impl Worker {
|
||||||
service_connection.id_recv = cci.service_connection_id;
|
service_connection.id_recv = cci.service_connection_id;
|
||||||
service_connection.id_send =
|
service_connection.id_send =
|
||||||
IDSend(resp_data.service_connection_id);
|
IDSend(resp_data.service_connection_id);
|
||||||
let track_serv_conn = match self
|
let track_serv_conn =
|
||||||
.connections
|
match self.connections.track(service_connection) {
|
||||||
.track(service_connection.into())
|
Ok(track_serv_conn) => track_serv_conn,
|
||||||
{
|
Err(_) => {
|
||||||
Ok(track_serv_conn) => track_serv_conn,
|
::tracing::error!(
|
||||||
Err(e) => {
|
"Could not track new service connection"
|
||||||
::tracing::error!(
|
);
|
||||||
"Could not track new service connection"
|
self.connections
|
||||||
);
|
.remove(cci.service_connection_id);
|
||||||
self.connections
|
// FIXME: proper connection closing
|
||||||
.remove(cci.service_connection_id);
|
// FIXME: drop auth srv connection if we just
|
||||||
// FIXME: proper connection closing
|
// established it
|
||||||
// FIXME: drop auth srv connection if we just
|
let _ = cci.answer.send(Err(
|
||||||
// established it
|
handshake::Error::InternalTracking
|
||||||
let _ = cci.answer.send(Err(
|
.into(),
|
||||||
handshake::Error::InternalTracking.into(),
|
));
|
||||||
));
|
return;
|
||||||
return;
|
}
|
||||||
}
|
};
|
||||||
};
|
|
||||||
service_conn = Some(ServiceConn(connection::Conn {
|
service_conn = Some(ServiceConn(connection::Conn {
|
||||||
queue: self.queue_sender.clone(),
|
queue: self.queue_sender.clone(),
|
||||||
conn: track_serv_conn,
|
conn: track_serv_conn,
|
||||||
|
|
|
@ -176,6 +176,7 @@ impl Fenrir {
|
||||||
config: &Config,
|
config: &Config,
|
||||||
tokio_rt: Arc<::tokio::runtime::Runtime>,
|
tokio_rt: Arc<::tokio::runtime::Runtime>,
|
||||||
) -> Result<Self, Error> {
|
) -> Result<Self, Error> {
|
||||||
|
inner::set_minimum_sleep_resolution().await;
|
||||||
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
||||||
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
||||||
// bind sockets early so we can change "port 0" (aka: random)
|
// bind sockets early so we can change "port 0" (aka: random)
|
||||||
|
@ -214,6 +215,7 @@ impl Fenrir {
|
||||||
pub async fn with_workers(
|
pub async fn with_workers(
|
||||||
config: &Config,
|
config: &Config,
|
||||||
) -> Result<(Self, Vec<Worker>), Error> {
|
) -> Result<(Self, Vec<Worker>), Error> {
|
||||||
|
inner::set_minimum_sleep_resolution().await;
|
||||||
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
|
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
|
||||||
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
|
||||||
// bind sockets early so we can change "port 0" (aka: random)
|
// bind sockets early so we can change "port 0" (aka: random)
|
||||||
|
|
Loading…
Reference in New Issue