Give the user a tracker for conn interactions

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-06-20 18:22:34 +02:00
parent 11d6b4e467
commit 2fe91d5dd3
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
5 changed files with 162 additions and 77 deletions

View File

@ -3,8 +3,9 @@
use crate::{ use crate::{
auth::{Domain, ServiceID}, auth::{Domain, ServiceID},
connection::{ connection::{
self,
handshake::{self, Error, Handshake}, handshake::{self, Error, Handshake},
Conn, IDRecv, IDSend, Connection, IDRecv, IDSend,
}, },
enc::{ enc::{
self, self,
@ -18,20 +19,27 @@ use crate::{
use ::tokio::sync::oneshot; use ::tokio::sync::oneshot;
pub(crate) struct Server { pub(crate) struct Server {
pub id: KeyID, pub(crate) id: KeyID,
pub key: PrivKey, pub(crate) key: PrivKey,
pub domains: Vec<Domain>, pub(crate) domains: Vec<Domain>,
} }
pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>; pub(crate) type ConnectAnswer = Result<ConnectOk, crate::Error>;
#[derive(Debug)]
pub(crate) struct ConnectOk {
pub(crate) auth_key_id: KeyID,
pub(crate) auth_id_send: IDSend,
pub(crate) authsrv_conn: connection::AuthSrvConn,
pub(crate) service_conn: Option<connection::ServiceConn>,
}
pub(crate) struct Client { pub(crate) struct Client {
pub service_id: ServiceID, pub(crate) service_id: ServiceID,
pub service_conn_id: IDRecv, pub(crate) service_conn_id: IDRecv,
pub connection: Conn, pub(crate) connection: Connection,
pub timeout: Option<::tokio::task::JoinHandle<()>>, pub(crate) timeout: Option<::tokio::task::JoinHandle<()>>,
pub answer: oneshot::Sender<ConnectAnswer>, pub(crate) answer: oneshot::Sender<ConnectAnswer>,
pub srv_key_id: KeyID, pub(crate) srv_key_id: KeyID,
} }
/// Tracks the keys used by the client and the handshake /// Tracks the keys used by the client and the handshake
@ -78,7 +86,7 @@ impl ClientList {
pub_key: PubKey, pub_key: PubKey,
service_id: ServiceID, service_id: ServiceID,
service_conn_id: IDRecv, service_conn_id: IDRecv,
connection: Conn, connection: Connection,
answer: oneshot::Sender<ConnectAnswer>, answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID, srv_key_id: KeyID,
) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> { ) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> {
@ -128,26 +136,26 @@ impl ClientList {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct AuthNeededInfo { pub(crate) struct AuthNeededInfo {
/// Parsed handshake packet /// Parsed handshake packet
pub handshake: Handshake, pub(crate) handshake: Handshake,
/// hkdf generated from the handshake /// hkdf generated from the handshake
pub hkdf: Hkdf, pub(crate) hkdf: Hkdf,
} }
/// Client information needed to fully establish the conenction /// Client information needed to fully establish the conenction
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct ClientConnectInfo { pub(crate) struct ClientConnectInfo {
/// The service ID that we are connecting to /// The service ID that we are connecting to
pub service_id: ServiceID, pub(crate) service_id: ServiceID,
/// The service ID that we are connecting to /// The service ID that we are connecting to
pub service_connection_id: IDRecv, pub(crate) service_connection_id: IDRecv,
/// Parsed handshake packet /// Parsed handshake packet
pub handshake: Handshake, pub(crate) handshake: Handshake,
/// Conn /// Connection
pub connection: Conn, pub(crate) connection: Connection,
/// where to wake up the waiting client /// where to wake up the waiting client
pub answer: oneshot::Sender<ConnectAnswer>, pub(crate) answer: oneshot::Sender<ConnectAnswer>,
/// server public key id that we used on the handshake /// server pub(crate)lic key id that we used on the handshake
pub srv_key_id: KeyID, pub(crate) srv_key_id: KeyID,
} }
/// Intermediate actions to be taken while parsing the handshake /// Intermediate actions to be taken while parsing the handshake
#[derive(Debug)] #[derive(Debug)]
@ -231,7 +239,7 @@ impl Tracker {
pub_key: PubKey, pub_key: PubKey,
service_id: ServiceID, service_id: ServiceID,
service_conn_id: IDRecv, service_conn_id: IDRecv,
connection: Conn, connection: Connection,
answer: oneshot::Sender<ConnectAnswer>, answer: oneshot::Sender<ConnectAnswer>,
srv_key_id: KeyID, srv_key_id: KeyID,
) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> { ) -> Result<(KeyID, &mut Client), oneshot::Sender<ConnectAnswer>> {

View File

@ -5,7 +5,7 @@ pub mod packet;
pub mod socket; pub mod socket;
pub mod stream; pub mod stream;
use ::std::{rc::Rc, vec::Vec}; use ::std::{collections::HashMap, rc::Rc, vec::Vec};
pub use crate::connection::{handshake::Handshake, packet::Packet}; pub use crate::connection::{handshake::Handshake, packet::Packet};
@ -17,9 +17,8 @@ use crate::{
sym::{self, CipherRecv, CipherSend}, sym::{self, CipherRecv, CipherSend},
Random, Random,
}, },
inner::ThreadTracker, inner::{worker, ThreadTracker},
}; };
use ::std::rc;
/// Fenrir Connection ID /// Fenrir Connection ID
/// ///
@ -126,13 +125,40 @@ impl ProtocolVersion {
} }
} }
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
pub(crate) struct UserConnTracker(usize);
impl UserConnTracker {
fn advance(&mut self) -> Self {
let old = self.0;
self.0 = self.0 + 1;
UserConnTracker(old)
}
}
/// Connection to an Authentication Server
#[derive(Debug)]
pub struct AuthSrvConn(pub(crate) Conn);
/// Connection to a service
#[derive(Debug)]
pub struct ServiceConn(pub(crate) Conn);
/// The connection, as seen from a user of libFenrir /// The connection, as seen from a user of libFenrir
#[derive(Debug)] #[derive(Debug)]
pub struct Connection(rc::Weak<Conn>); pub struct Conn {
pub(crate) queue: ::async_channel::Sender<worker::Work>,
pub(crate) conn: UserConnTracker,
}
impl Conn {
/// Queue some data to be sent in this connection
pub fn send(&mut self, stream: stream::ID, _data: Vec<u8>) {
todo!()
}
}
/// A single connection and its data /// A single connection and its data
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct Conn { pub(crate) struct Connection {
/// Receiving Conn ID /// Receiving Conn ID
pub id_recv: IDRecv, pub id_recv: IDRecv,
/// Sending Conn ID /// Sending Conn ID
@ -160,7 +186,7 @@ pub enum Role {
Client, Client,
} }
impl Conn { impl Connection {
pub(crate) fn new( pub(crate) fn new(
hkdf: Hkdf, hkdf: Hkdf,
cipher: sym::Kind, cipher: sym::Kind,
@ -190,7 +216,9 @@ impl Conn {
pub(crate) struct ConnList { pub(crate) struct ConnList {
thread_id: ThreadTracker, thread_id: ThreadTracker,
connections: Vec<Option<Rc<Conn>>>, connections: Vec<Option<Rc<Connection>>>,
user_tracker: HashMap<UserConnTracker, usize>,
last_tracked: UserConnTracker,
/// Bitmap to track which connection ids are used or free /// Bitmap to track which connection ids are used or free
ids_used: Vec<::bitmaps::Bitmap<1024>>, ids_used: Vec<::bitmaps::Bitmap<1024>>,
} }
@ -206,6 +234,8 @@ impl ConnList {
let mut ret = Self { let mut ret = Self {
thread_id, thread_id,
connections: Vec::with_capacity(INITIAL_CAP), connections: Vec::with_capacity(INITIAL_CAP),
user_tracker: HashMap::with_capacity(INITIAL_CAP),
last_tracked: UserConnTracker(0),
ids_used: vec![bitmap_id], ids_used: vec![bitmap_id],
}; };
ret.connections.resize_with(INITIAL_CAP, || None); ret.connections.resize_with(INITIAL_CAP, || None);
@ -261,7 +291,10 @@ impl ConnList {
new_id new_id
} }
/// NOTE: does NOT check if the connection has been previously reserved! /// NOTE: does NOT check if the connection has been previously reserved!
pub(crate) fn track(&mut self, conn: Rc<Conn>) -> Result<(), ()> { pub(crate) fn track(
&mut self,
conn: Rc<Connection>,
) -> Result<UserConnTracker, ()> {
let conn_id = match conn.id_recv { let conn_id = match conn.id_recv {
IDRecv(ID::Handshake) => { IDRecv(ID::Handshake) => {
return Err(()); return Err(());
@ -271,7 +304,9 @@ impl ConnList {
let id_in_thread: usize = let id_in_thread: usize =
(conn_id.get() / (self.thread_id.total as u64)) as usize; (conn_id.get() / (self.thread_id.total as u64)) as usize;
self.connections[id_in_thread] = Some(conn); self.connections[id_in_thread] = Some(conn);
Ok(()) let tracked = self.last_tracked.advance();
let _ = self.user_tracker.insert(tracked, id_in_thread);
Ok(tracked)
} }
pub(crate) fn remove(&mut self, id: IDRecv) { pub(crate) fn remove(&mut self, id: IDRecv) {
if let IDRecv(ID::ID(raw_id)) = id { if let IDRecv(ID::ID(raw_id)) = id {
@ -303,7 +338,6 @@ enum MapEntry {
Present(IDSend), Present(IDSend),
Reserved, Reserved,
} }
use ::std::collections::HashMap;
/// Link the public key of the authentication server to a connection id /// Link the public key of the authentication server to a connection id
/// so that we can reuse that connection to ask for more authentications /// so that we can reuse that connection to ask for more authentications

View File

@ -11,7 +11,7 @@ use crate::{
}, },
packet::{self, Packet}, packet::{self, Packet},
socket::{UdpClient, UdpServer}, socket::{UdpClient, UdpServer},
Conn, ConnList, IDSend, AuthSrvConn, ConnList, Connection, IDSend, ServiceConn,
}, },
dnssec, dnssec,
enc::{ enc::{
@ -64,6 +64,7 @@ pub struct Worker {
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<Arc<UdpSocket>>, sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>, queue: ::async_channel::Receiver<Work>,
queue_sender: ::async_channel::Sender<Work>,
queue_timeouts_recv: mpsc::UnboundedReceiver<Work>, queue_timeouts_recv: mpsc::UnboundedReceiver<Work>,
queue_timeouts_send: mpsc::UnboundedSender<Work>, queue_timeouts_send: mpsc::UnboundedSender<Work>,
thread_channels: Vec<::async_channel::Sender<Work>>, thread_channels: Vec<::async_channel::Sender<Work>>,
@ -82,6 +83,7 @@ impl Worker {
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<Arc<UdpSocket>>, sockets: Vec<Arc<UdpSocket>>,
queue: ::async_channel::Receiver<Work>, queue: ::async_channel::Receiver<Work>,
queue_sender: ::async_channel::Sender<Work>,
) -> ::std::io::Result<Self> { ) -> ::std::io::Result<Self> {
let (queue_timeouts_send, queue_timeouts_recv) = let (queue_timeouts_send, queue_timeouts_recv) =
mpsc::unbounded_channel(); mpsc::unbounded_channel();
@ -118,6 +120,7 @@ impl Worker {
token_check, token_check,
sockets, sockets,
queue, queue,
queue_sender,
queue_timeouts_recv, queue_timeouts_recv,
queue_timeouts_send, queue_timeouts_send,
thread_channels: Vec::new(), thread_channels: Vec::new(),
@ -293,7 +296,7 @@ impl Worker {
// are PubKey::Exchange // are PubKey::Exchange
unreachable!() unreachable!()
} }
let mut conn = Conn::new( let mut conn = Connection::new(
hkdf, hkdf,
cipher_selected, cipher_selected,
connection::Role::Client, connection::Role::Client,
@ -515,7 +518,7 @@ impl Worker {
let head_len = req.cipher.nonce_len(); let head_len = req.cipher.nonce_len();
let tag_len = req.cipher.tag_len(); let tag_len = req.cipher.tag_len();
let mut auth_conn = Conn::new( let mut auth_conn = Connection::new(
authinfo.hkdf, authinfo.hkdf,
req.cipher, req.cipher,
connection::Role::Server, connection::Role::Server,
@ -587,20 +590,32 @@ impl Worker {
); );
unreachable!(); unreachable!();
} }
let auth_srv_conn = IDSend(resp_data.id); let auth_id_send = IDSend(resp_data.id);
let mut conn = cci.connection; let mut conn = cci.connection;
conn.id_send = auth_srv_conn; conn.id_send = auth_id_send;
let id_recv = conn.id_recv; let id_recv = conn.id_recv;
let cipher = conn.cipher_recv.kind(); let cipher = conn.cipher_recv.kind();
// track the connection to the authentication server // track the connection to the authentication server
if self.connections.track(conn.into()).is_err() { let track_auth_conn =
::tracing::error!("Could not track new connection"); match self.connections.track(conn.into()) {
Ok(track_auth_conn) => track_auth_conn,
Err(e) => {
::tracing::error!(
"Could not track new auth srv connection"
);
self.connections.remove(id_recv); self.connections.remove(id_recv);
// FIXME: proper connection closing
let _ = cci.answer.send(Err( let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(), handshake::Error::InternalTracking.into(),
)); ));
return; return;
} }
};
let authsrv_conn = AuthSrvConn(connection::Conn {
queue: self.queue_sender.clone(),
conn: track_auth_conn,
});
let mut service_conn = None;
if cci.service_id != auth::SERVICEID_AUTH { if cci.service_id != auth::SERVICEID_AUTH {
// create and track the connection to the service // create and track the connection to the service
// SECURITY: xor with secrets // SECURITY: xor with secrets
@ -611,7 +626,7 @@ impl Worker {
cci.service_id.as_bytes(), cci.service_id.as_bytes(),
resp_data.service_key, resp_data.service_key,
); );
let mut service_connection = Conn::new( let mut service_connection = Connection::new(
hkdf, hkdf,
cipher, cipher,
connection::Role::Client, connection::Role::Client,
@ -620,11 +635,38 @@ impl Worker {
service_connection.id_recv = cci.service_connection_id; service_connection.id_recv = cci.service_connection_id;
service_connection.id_send = service_connection.id_send =
IDSend(resp_data.service_connection_id); IDSend(resp_data.service_connection_id);
let _ = let track_serv_conn = match self
self.connections.track(service_connection.into()); .connections
.track(service_connection.into())
{
Ok(track_serv_conn) => track_serv_conn,
Err(e) => {
::tracing::error!(
"Could not track new service connection"
);
self.connections
.remove(cci.service_connection_id);
// FIXME: proper connection closing
// FIXME: drop auth srv connection if we just
// established it
let _ = cci.answer.send(Err(
handshake::Error::InternalTracking.into(),
));
return;
}
};
service_conn = Some(ServiceConn(connection::Conn {
queue: self.queue_sender.clone(),
conn: track_serv_conn,
}));
} }
let _ = let _ =
cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn))); cci.answer.send(Ok(handshake::tracker::ConnectOk {
auth_key_id: cci.srv_key_id,
auth_id_send,
authsrv_conn,
service_conn,
}));
} }
handshake::Action::Nothing => {} handshake::Action::Nothing => {}
}; };

View File

@ -39,7 +39,7 @@ use crate::{
}, },
}; };
pub use config::Config; pub use config::Config;
pub use connection::Connection; pub use connection::{AuthSrvConn, ServiceConn};
/// Main fenrir library errors /// Main fenrir library errors
#[derive(::thiserror::Error, Debug)] #[derive(::thiserror::Error, Debug)]
@ -382,7 +382,7 @@ impl Fenrir {
&self, &self,
domain: &Domain, domain: &Domain,
service: ServiceID, service: ServiceID,
) -> Result<(), Error> { ) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
let resolved = self.resolv(domain).await?; let resolved = self.resolv(domain).await?;
self.connect_resolved(resolved, domain, service).await self.connect_resolved(resolved, domain, service).await
} }
@ -392,7 +392,7 @@ impl Fenrir {
resolved: dnssec::Record, resolved: dnssec::Record,
domain: &Domain, domain: &Domain,
service: ServiceID, service: ServiceID,
) -> Result<(), Error> { ) -> Result<(AuthSrvConn, Option<ServiceConn>), Error> {
loop { loop {
// check if we already have a connection to that auth. srv // check if we already have a connection to that auth. srv
let is_reserved = { let is_reserved = {
@ -460,29 +460,28 @@ impl Fenrir {
.await; .await;
match recv.await { match recv.await {
Ok(res) => { Ok(res) => match res {
match res {
Err(e) => { Err(e) => {
let mut conn_auth_lock = let mut conn_auth_lock = self.conn_auth_srv.lock().await;
self.conn_auth_srv.lock().await;
conn_auth_lock.remove_reserved(&resolved); conn_auth_lock.remove_reserved(&resolved);
Err(e) Err(e)
} }
Ok((key_id, id_send)) => { Ok(connections) => {
let key = resolved let key = resolved
.public_keys .public_keys
.iter() .iter()
.find(|k| k.0 == key_id) .find(|k| k.0 == connections.auth_key_id)
.unwrap(); .unwrap();
let mut conn_auth_lock = let mut conn_auth_lock = self.conn_auth_srv.lock().await;
self.conn_auth_srv.lock().await; conn_auth_lock.add(
conn_auth_lock.add(&key.1, id_send, &resolved); &key.1,
connections.auth_id_send,
&resolved,
);
//FIXME: user needs to somehow track the connection Ok((connections.authsrv_conn, connections.service_conn))
Ok(())
}
}
} }
},
Err(e) => { Err(e) => {
// Thread dropped the sender. no more thread? // Thread dropped the sender. no more thread?
let mut conn_auth_lock = self.conn_auth_srv.lock().await; let mut conn_auth_lock = self.conn_auth_srv.lock().await;
@ -524,6 +523,7 @@ impl Fenrir {
self.token_check.clone(), self.token_check.clone(),
socks, socks,
work_recv, work_recv,
work_send.clone(),
) )
.await?; .await?;
// don't keep around private keys too much // don't keep around private keys too much
@ -547,7 +547,6 @@ impl Fenrir {
} }
Ok(worker) Ok(worker)
} }
// needs to be called before add_sockets // needs to be called before add_sockets
/// Start one working thread for each physical cpu /// Start one working thread for each physical cpu
/// threads are pinned to each cpu core. /// threads are pinned to each cpu core.
@ -589,6 +588,7 @@ impl Fenrir {
let th_tokio_rt = tokio_rt.clone(); let th_tokio_rt = tokio_rt.clone();
let th_config = self.cfg.clone(); let th_config = self.cfg.clone();
let (work_send, work_recv) = ::async_channel::unbounded::<Work>(); let (work_send, work_recv) = ::async_channel::unbounded::<Work>();
let th_work_send = work_send.clone();
let th_stop_working = self.stop_working.subscribe(); let th_stop_working = self.stop_working.subscribe();
let th_token_check = self.token_check.clone(); let th_token_check = self.token_check.clone();
let th_sockets = sockets.clone(); let th_sockets = sockets.clone();
@ -629,6 +629,7 @@ impl Fenrir {
th_token_check, th_token_check,
th_sockets, th_sockets,
work_recv, work_recv,
th_work_send,
) )
.await .await
{ {

View File

@ -88,7 +88,7 @@ async fn test_connection_dirsync() {
.connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH) .connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH)
.await .await
{ {
Ok(()) => {} Ok((_, _)) => {}
Err(e) => { Err(e) => {
assert!(false, "Err on client connection: {:?}", e); assert!(false, "Err on client connection: {:?}", e);
} }