diff --git a/src/connection/handshake/tracker.rs b/src/connection/handshake/tracker.rs index 30a6a8f..c40158e 100644 --- a/src/connection/handshake/tracker.rs +++ b/src/connection/handshake/tracker.rs @@ -3,8 +3,9 @@ use crate::{ auth::{Domain, ServiceID}, connection::{ + self, handshake::{self, Error, Handshake}, - Conn, IDRecv, IDSend, + Connection, IDRecv, IDSend, }, enc::{ self, @@ -18,20 +19,27 @@ use crate::{ use ::tokio::sync::oneshot; pub(crate) struct Server { - pub id: KeyID, - pub key: PrivKey, - pub domains: Vec, + pub(crate) id: KeyID, + pub(crate) key: PrivKey, + pub(crate) domains: Vec, } -pub(crate) type ConnectAnswer = Result<(KeyID, IDSend), crate::Error>; +pub(crate) type ConnectAnswer = Result; +#[derive(Debug)] +pub(crate) struct ConnectOk { + pub(crate) auth_key_id: KeyID, + pub(crate) auth_id_send: IDSend, + pub(crate) authsrv_conn: connection::AuthSrvConn, + pub(crate) service_conn: Option, +} pub(crate) struct Client { - pub service_id: ServiceID, - pub service_conn_id: IDRecv, - pub connection: Conn, - pub timeout: Option<::tokio::task::JoinHandle<()>>, - pub answer: oneshot::Sender, - pub srv_key_id: KeyID, + pub(crate) service_id: ServiceID, + pub(crate) service_conn_id: IDRecv, + pub(crate) connection: Connection, + pub(crate) timeout: Option<::tokio::task::JoinHandle<()>>, + pub(crate) answer: oneshot::Sender, + pub(crate) srv_key_id: KeyID, } /// Tracks the keys used by the client and the handshake @@ -78,7 +86,7 @@ impl ClientList { pub_key: PubKey, service_id: ServiceID, service_conn_id: IDRecv, - connection: Conn, + connection: Connection, answer: oneshot::Sender, srv_key_id: KeyID, ) -> Result<(KeyID, &mut Client), oneshot::Sender> { @@ -128,26 +136,26 @@ impl ClientList { #[derive(Debug, Clone)] pub(crate) struct AuthNeededInfo { /// Parsed handshake packet - pub handshake: Handshake, + pub(crate) handshake: Handshake, /// hkdf generated from the handshake - pub hkdf: Hkdf, + pub(crate) hkdf: Hkdf, } /// Client information needed to fully establish the conenction #[derive(Debug)] pub(crate) struct ClientConnectInfo { /// The service ID that we are connecting to - pub service_id: ServiceID, + pub(crate) service_id: ServiceID, /// The service ID that we are connecting to - pub service_connection_id: IDRecv, + pub(crate) service_connection_id: IDRecv, /// Parsed handshake packet - pub handshake: Handshake, - /// Conn - pub connection: Conn, + pub(crate) handshake: Handshake, + /// Connection + pub(crate) connection: Connection, /// where to wake up the waiting client - pub answer: oneshot::Sender, - /// server public key id that we used on the handshake - pub srv_key_id: KeyID, + pub(crate) answer: oneshot::Sender, + /// server pub(crate)lic key id that we used on the handshake + pub(crate) srv_key_id: KeyID, } /// Intermediate actions to be taken while parsing the handshake #[derive(Debug)] @@ -231,7 +239,7 @@ impl Tracker { pub_key: PubKey, service_id: ServiceID, service_conn_id: IDRecv, - connection: Conn, + connection: Connection, answer: oneshot::Sender, srv_key_id: KeyID, ) -> Result<(KeyID, &mut Client), oneshot::Sender> { diff --git a/src/connection/mod.rs b/src/connection/mod.rs index f6c0a86..58853c7 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -5,7 +5,7 @@ pub mod packet; pub mod socket; pub mod stream; -use ::std::{rc::Rc, vec::Vec}; +use ::std::{collections::HashMap, rc::Rc, vec::Vec}; pub use crate::connection::{handshake::Handshake, packet::Packet}; @@ -17,9 +17,8 @@ use crate::{ sym::{self, CipherRecv, CipherSend}, Random, }, - inner::ThreadTracker, + inner::{worker, ThreadTracker}, }; -use ::std::rc; /// Fenrir Connection ID /// @@ -126,13 +125,40 @@ impl ProtocolVersion { } } +#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)] +pub(crate) struct UserConnTracker(usize); +impl UserConnTracker { + fn advance(&mut self) -> Self { + let old = self.0; + self.0 = self.0 + 1; + UserConnTracker(old) + } +} + +/// Connection to an Authentication Server +#[derive(Debug)] +pub struct AuthSrvConn(pub(crate) Conn); +/// Connection to a service +#[derive(Debug)] +pub struct ServiceConn(pub(crate) Conn); + /// The connection, as seen from a user of libFenrir #[derive(Debug)] -pub struct Connection(rc::Weak); +pub struct Conn { + pub(crate) queue: ::async_channel::Sender, + pub(crate) conn: UserConnTracker, +} + +impl Conn { + /// Queue some data to be sent in this connection + pub fn send(&mut self, stream: stream::ID, _data: Vec) { + todo!() + } +} /// A single connection and its data #[derive(Debug)] -pub(crate) struct Conn { +pub(crate) struct Connection { /// Receiving Conn ID pub id_recv: IDRecv, /// Sending Conn ID @@ -160,7 +186,7 @@ pub enum Role { Client, } -impl Conn { +impl Connection { pub(crate) fn new( hkdf: Hkdf, cipher: sym::Kind, @@ -190,7 +216,9 @@ impl Conn { pub(crate) struct ConnList { thread_id: ThreadTracker, - connections: Vec>>, + connections: Vec>>, + user_tracker: HashMap, + last_tracked: UserConnTracker, /// Bitmap to track which connection ids are used or free ids_used: Vec<::bitmaps::Bitmap<1024>>, } @@ -206,6 +234,8 @@ impl ConnList { let mut ret = Self { thread_id, connections: Vec::with_capacity(INITIAL_CAP), + user_tracker: HashMap::with_capacity(INITIAL_CAP), + last_tracked: UserConnTracker(0), ids_used: vec![bitmap_id], }; ret.connections.resize_with(INITIAL_CAP, || None); @@ -261,7 +291,10 @@ impl ConnList { new_id } /// NOTE: does NOT check if the connection has been previously reserved! - pub(crate) fn track(&mut self, conn: Rc) -> Result<(), ()> { + pub(crate) fn track( + &mut self, + conn: Rc, + ) -> Result { let conn_id = match conn.id_recv { IDRecv(ID::Handshake) => { return Err(()); @@ -271,7 +304,9 @@ impl ConnList { let id_in_thread: usize = (conn_id.get() / (self.thread_id.total as u64)) as usize; self.connections[id_in_thread] = Some(conn); - Ok(()) + let tracked = self.last_tracked.advance(); + let _ = self.user_tracker.insert(tracked, id_in_thread); + Ok(tracked) } pub(crate) fn remove(&mut self, id: IDRecv) { if let IDRecv(ID::ID(raw_id)) = id { @@ -303,7 +338,6 @@ enum MapEntry { Present(IDSend), Reserved, } -use ::std::collections::HashMap; /// Link the public key of the authentication server to a connection id /// so that we can reuse that connection to ask for more authentications diff --git a/src/inner/worker.rs b/src/inner/worker.rs index 27dd92a..29ad7f7 100644 --- a/src/inner/worker.rs +++ b/src/inner/worker.rs @@ -11,7 +11,7 @@ use crate::{ }, packet::{self, Packet}, socket::{UdpClient, UdpServer}, - Conn, ConnList, IDSend, + AuthSrvConn, ConnList, Connection, IDSend, ServiceConn, }, dnssec, enc::{ @@ -64,6 +64,7 @@ pub struct Worker { token_check: Option>>, sockets: Vec>, queue: ::async_channel::Receiver, + queue_sender: ::async_channel::Sender, queue_timeouts_recv: mpsc::UnboundedReceiver, queue_timeouts_send: mpsc::UnboundedSender, thread_channels: Vec<::async_channel::Sender>, @@ -82,6 +83,7 @@ impl Worker { token_check: Option>>, sockets: Vec>, queue: ::async_channel::Receiver, + queue_sender: ::async_channel::Sender, ) -> ::std::io::Result { let (queue_timeouts_send, queue_timeouts_recv) = mpsc::unbounded_channel(); @@ -118,6 +120,7 @@ impl Worker { token_check, sockets, queue, + queue_sender, queue_timeouts_recv, queue_timeouts_send, thread_channels: Vec::new(), @@ -293,7 +296,7 @@ impl Worker { // are PubKey::Exchange unreachable!() } - let mut conn = Conn::new( + let mut conn = Connection::new( hkdf, cipher_selected, connection::Role::Client, @@ -515,7 +518,7 @@ impl Worker { let head_len = req.cipher.nonce_len(); let tag_len = req.cipher.tag_len(); - let mut auth_conn = Conn::new( + let mut auth_conn = Connection::new( authinfo.hkdf, req.cipher, connection::Role::Server, @@ -587,20 +590,32 @@ impl Worker { ); unreachable!(); } - let auth_srv_conn = IDSend(resp_data.id); + let auth_id_send = IDSend(resp_data.id); let mut conn = cci.connection; - conn.id_send = auth_srv_conn; + conn.id_send = auth_id_send; let id_recv = conn.id_recv; let cipher = conn.cipher_recv.kind(); // track the connection to the authentication server - if self.connections.track(conn.into()).is_err() { - ::tracing::error!("Could not track new connection"); - self.connections.remove(id_recv); - let _ = cci.answer.send(Err( - handshake::Error::InternalTracking.into(), - )); - return; - } + let track_auth_conn = + match self.connections.track(conn.into()) { + Ok(track_auth_conn) => track_auth_conn, + Err(e) => { + ::tracing::error!( + "Could not track new auth srv connection" + ); + self.connections.remove(id_recv); + // FIXME: proper connection closing + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + let authsrv_conn = AuthSrvConn(connection::Conn { + queue: self.queue_sender.clone(), + conn: track_auth_conn, + }); + let mut service_conn = None; if cci.service_id != auth::SERVICEID_AUTH { // create and track the connection to the service // SECURITY: xor with secrets @@ -611,7 +626,7 @@ impl Worker { cci.service_id.as_bytes(), resp_data.service_key, ); - let mut service_connection = Conn::new( + let mut service_connection = Connection::new( hkdf, cipher, connection::Role::Client, @@ -620,11 +635,38 @@ impl Worker { service_connection.id_recv = cci.service_connection_id; service_connection.id_send = IDSend(resp_data.service_connection_id); - let _ = - self.connections.track(service_connection.into()); + let track_serv_conn = match self + .connections + .track(service_connection.into()) + { + Ok(track_serv_conn) => track_serv_conn, + Err(e) => { + ::tracing::error!( + "Could not track new service connection" + ); + self.connections + .remove(cci.service_connection_id); + // FIXME: proper connection closing + // FIXME: drop auth srv connection if we just + // established it + let _ = cci.answer.send(Err( + handshake::Error::InternalTracking.into(), + )); + return; + } + }; + service_conn = Some(ServiceConn(connection::Conn { + queue: self.queue_sender.clone(), + conn: track_serv_conn, + })); } let _ = - cci.answer.send(Ok((cci.srv_key_id, auth_srv_conn))); + cci.answer.send(Ok(handshake::tracker::ConnectOk { + auth_key_id: cci.srv_key_id, + auth_id_send, + authsrv_conn, + service_conn, + })); } handshake::Action::Nothing => {} }; diff --git a/src/lib.rs b/src/lib.rs index 57da614..a0ee7ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,7 +39,7 @@ use crate::{ }, }; pub use config::Config; -pub use connection::Connection; +pub use connection::{AuthSrvConn, ServiceConn}; /// Main fenrir library errors #[derive(::thiserror::Error, Debug)] @@ -382,7 +382,7 @@ impl Fenrir { &self, domain: &Domain, service: ServiceID, - ) -> Result<(), Error> { + ) -> Result<(AuthSrvConn, Option), Error> { let resolved = self.resolv(domain).await?; self.connect_resolved(resolved, domain, service).await } @@ -392,7 +392,7 @@ impl Fenrir { resolved: dnssec::Record, domain: &Domain, service: ServiceID, - ) -> Result<(), Error> { + ) -> Result<(AuthSrvConn, Option), Error> { loop { // check if we already have a connection to that auth. srv let is_reserved = { @@ -460,29 +460,28 @@ impl Fenrir { .await; match recv.await { - Ok(res) => { - match res { - Err(e) => { - let mut conn_auth_lock = - self.conn_auth_srv.lock().await; - conn_auth_lock.remove_reserved(&resolved); - Err(e) - } - Ok((key_id, id_send)) => { - let key = resolved - .public_keys - .iter() - .find(|k| k.0 == key_id) - .unwrap(); - let mut conn_auth_lock = - self.conn_auth_srv.lock().await; - conn_auth_lock.add(&key.1, id_send, &resolved); - - //FIXME: user needs to somehow track the connection - Ok(()) - } + Ok(res) => match res { + Err(e) => { + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.remove_reserved(&resolved); + Err(e) } - } + Ok(connections) => { + let key = resolved + .public_keys + .iter() + .find(|k| k.0 == connections.auth_key_id) + .unwrap(); + let mut conn_auth_lock = self.conn_auth_srv.lock().await; + conn_auth_lock.add( + &key.1, + connections.auth_id_send, + &resolved, + ); + + Ok((connections.authsrv_conn, connections.service_conn)) + } + }, Err(e) => { // Thread dropped the sender. no more thread? let mut conn_auth_lock = self.conn_auth_srv.lock().await; @@ -524,6 +523,7 @@ impl Fenrir { self.token_check.clone(), socks, work_recv, + work_send.clone(), ) .await?; // don't keep around private keys too much @@ -547,7 +547,6 @@ impl Fenrir { } Ok(worker) } - // needs to be called before add_sockets /// Start one working thread for each physical cpu /// threads are pinned to each cpu core. @@ -589,6 +588,7 @@ impl Fenrir { let th_tokio_rt = tokio_rt.clone(); let th_config = self.cfg.clone(); let (work_send, work_recv) = ::async_channel::unbounded::(); + let th_work_send = work_send.clone(); let th_stop_working = self.stop_working.subscribe(); let th_token_check = self.token_check.clone(); let th_sockets = sockets.clone(); @@ -629,6 +629,7 @@ impl Fenrir { th_token_check, th_sockets, work_recv, + th_work_send, ) .await { diff --git a/src/tests.rs b/src/tests.rs index 4e2a0b6..ff28ae6 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -88,7 +88,7 @@ async fn test_connection_dirsync() { .connect_resolved(dnssec_record, &test_domain, auth::SERVICEID_AUTH) .await { - Ok(()) => {} + Ok((_, _)) => {} Err(e) => { assert!(false, "Err on client connection: {:?}", e); }