Test and fix shutdowns

we have a Quick but partial shutdown, which lets the async "threads"
work in the background and shutdown after a bit more time

and the graceful/full shutdown, which waits for everything.

Unfortunately `Drop` can't manage async and blocks everything,
no way to yeld either, so if we only have a thread
we would deadlock if we tried to stop things gracefully

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
Luca Fulchir 2023-06-11 22:45:40 +02:00
parent aff1c313f5
commit b682068dca
Signed by: luca.fulchir
GPG Key ID: 8F6440603D13A78E
6 changed files with 311 additions and 285 deletions

View File

@ -50,6 +50,7 @@ tokio = { version = "1", features = ["full"] }
# PERF: todo linux-only, behind "iouring" feature # PERF: todo linux-only, behind "iouring" feature
#tokio-uring = { version = "0.4" } #tokio-uring = { version = "0.4" }
tracing = { version = "0.1" } tracing = { version = "0.1" }
tracing-test = { version = "0.2" }
trust-dns-resolver = { version = "0.22", features = [ "dnssec-ring" ] } trust-dns-resolver = { version = "0.22", features = [ "dnssec-ring" ] }
trust-dns-client = { version = "0.22", features = [ "dnssec" ] } trust-dns-client = { version = "0.22", features = [ "dnssec" ] }
trust-dns-proto = { version = "0.22" } trust-dns-proto = { version = "0.22" }
@ -72,3 +73,14 @@ incremental = true
codegen-units = 256 codegen-units = 256
rpath = false rpath = false
[profile.test]
opt-level = 0
debug = true
debug-assertions = true
overflow-checks = true
lto = false
panic = 'unwind'
incremental = true
codegen-units = 256
rpath = false

View File

@ -141,7 +141,7 @@ pub(crate) struct ClientConnectInfo {
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum HandshakeAction { pub(crate) enum HandshakeAction {
/// Parsing finished, all ok, nothing to do /// Parsing finished, all ok, nothing to do
Nonthing, Nothing,
/// Packet parsed, now go perform authentication /// Packet parsed, now go perform authentication
AuthNeeded(AuthNeededInfo), AuthNeeded(AuthNeededInfo),
/// the client can fully establish a connection with this info /// the client can fully establish a connection with this info
@ -155,7 +155,7 @@ pub(crate) enum HandshakeAction {
/// core = (udp_src_sender_port % total_threads) - 1 /// core = (udp_src_sender_port % total_threads) - 1
pub(crate) struct HandshakeTracker { pub(crate) struct HandshakeTracker {
thread_id: ThreadTracker, thread_id: ThreadTracker,
key_exchanges: Vec<(asym::KeyKind, asym::KeyExchangeKind)>, key_exchanges: Vec<asym::KeyExchangeKind>,
ciphers: Vec<CipherKind>, ciphers: Vec<CipherKind>,
/// ephemeral keys used server side in key exchange /// ephemeral keys used server side in key exchange
keys_srv: Vec<HandshakeServer>, keys_srv: Vec<HandshakeServer>,
@ -164,16 +164,24 @@ pub(crate) struct HandshakeTracker {
} }
impl HandshakeTracker { impl HandshakeTracker {
pub(crate) fn new(thread_id: ThreadTracker) -> Self { pub(crate) fn new(
thread_id: ThreadTracker,
ciphers: Vec<CipherKind>,
key_exchanges: Vec<asym::KeyExchangeKind>,
) -> Self {
Self { Self {
thread_id, thread_id,
ciphers: Vec::new(), ciphers,
key_exchanges: Vec::new(), key_exchanges,
keys_srv: Vec::new(), keys_srv: Vec::new(),
hshake_cli: HandshakeClientList::new(), hshake_cli: HandshakeClientList::new(),
} }
} }
pub(crate) fn new_client( pub(crate) fn add_server(&mut self, id: KeyID, key: PrivKey) {
self.keys_srv.push(HandshakeServer { id, key });
self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0));
}
pub(crate) fn add_client(
&mut self, &mut self,
priv_key: PrivKey, priv_key: PrivKey,
pub_key: PubKey, pub_key: PubKey,
@ -208,45 +216,34 @@ impl HandshakeTracker {
match handshake.data { match handshake.data {
HandshakeData::DirSync(ref mut ds) => match ds { HandshakeData::DirSync(ref mut ds) => match ds {
DirSync::Req(ref mut req) => { DirSync::Req(ref mut req) => {
let ephemeral_key = { if !self.key_exchanges.contains(&req.exchange) {
if let Some(h_k) = return Err(enc::Error::UnsupportedKeyExchange.into());
self.keys_srv.iter().find(|k| k.id == req.key_id) }
{ if !self.ciphers.contains(&req.cipher) {
return Err(enc::Error::UnsupportedCipher.into());
}
let has_key = self.keys_srv.iter().find(|k| {
if k.id == req.key_id {
// Directory synchronized can only use keys // Directory synchronized can only use keys
// for key exchange, not signing keys // for key exchange, not signing keys
if let PrivKey::Exchange(k) = &h_k.key { if let PrivKey::Exchange(_) = k.key {
Some(k.clone()) return true;
}
}
false
});
let ephemeral_key;
match has_key {
Some(s_k) => {
if let PrivKey::Exchange(ref k) = &s_k.key {
ephemeral_key = k;
} else { } else {
None unreachable!();
}
} else {
None
}
};
if ephemeral_key.is_none() {
::tracing::debug!(
"No such server key id: {:?}",
req.key_id
);
return Err(handshake::Error::UnknownKeyID.into());
}
let ephemeral_key = ephemeral_key.unwrap();
{
if None
== self.key_exchanges.iter().find(|&x| {
*x == (ephemeral_key.kind(), req.exchange)
})
{
return Err(
enc::Error::UnsupportedKeyExchange.into()
);
} }
} }
{ None => {
if None return Err(handshake::Error::UnknownKeyID.into())
== self.ciphers.iter().find(|&x| *x == req.cipher)
{
return Err(enc::Error::UnsupportedCipher.into());
} }
} }
let shared_key = match ephemeral_key let shared_key = match ephemeral_key

View File

@ -1,6 +1,5 @@
//! Socket related types and functions //! Socket related types and functions
use ::arc_swap::ArcSwap;
use ::std::{net::SocketAddr, sync::Arc, vec::Vec}; use ::std::{net::SocketAddr, sync::Arc, vec::Vec};
use ::tokio::{net::UdpSocket, task::JoinHandle}; use ::tokio::{net::UdpSocket, task::JoinHandle};
@ -10,82 +9,31 @@ pub type SocketTracker =
/// async free socket list /// async free socket list
pub(crate) struct SocketList { pub(crate) struct SocketList {
pub list: ArcSwap<Vec<SocketTracker>>, pub list: Vec<SocketTracker>,
} }
impl SocketList { impl SocketList {
pub(crate) fn new() -> Self { pub(crate) fn new() -> Self {
Self { Self { list: Vec::new() }
list: ArcSwap::new(Arc::new(Vec::new())),
}
}
// TODO: fn rm_socket()
pub(crate) fn rm_all(&self) -> Self {
let new_list = Arc::new(Vec::new());
let old_list = self.list.swap(new_list);
Self {
list: old_list.into(),
} }
pub(crate) fn rm_all(&mut self) -> Self {
let mut old_list = Vec::new();
::core::mem::swap(&mut self.list, &mut old_list);
Self { list: old_list }
} }
pub(crate) async fn add_socket( pub(crate) async fn add_socket(
&self, &mut self,
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
handle: JoinHandle<::std::io::Result<()>>, handle: JoinHandle<::std::io::Result<()>>,
) { ) {
// we could simplify this into just a `.swap` instead of `.rcu` but
// it is not yet guaranteed that only one thread will call this fn
// ...we don't need performance here anyway
let arc_handle = Arc::new(handle); let arc_handle = Arc::new(handle);
self.list.rcu(|old_list| { self.list.push((socket, arc_handle));
let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1));
new_list = old_list.to_vec().into();
Arc::get_mut(&mut new_list)
.unwrap()
.push((socket.clone(), arc_handle.clone()));
new_list
});
} }
/// This method assumes no other `add_sockets` are being run /// This method assumes no other `add_sockets` are being run
pub(crate) async fn stop_all(self) { pub(crate) async fn stop_all(self) {
let mut arc_list = self.list.into_inner(); for (_socket, mut handle) in self.list.into_iter() {
let list = loop {
match Arc::try_unwrap(arc_list) {
Ok(list) => break list,
Err(arc_retry) => {
arc_list = arc_retry;
::tokio::time::sleep(::core::time::Duration::from_millis(
50,
))
.await;
}
}
};
for (_socket, mut handle) in list.into_iter() {
let _ = Arc::get_mut(&mut handle).unwrap().await; let _ = Arc::get_mut(&mut handle).unwrap().await;
} }
} }
pub(crate) fn lock(&self) -> SocketListRef {
SocketListRef {
list: self.list.load_full(),
}
}
}
/// Reference to a locked SocketList
// TODO: impl Drop for SocketList
pub(crate) struct SocketListRef {
list: Arc<Vec<SocketTracker>>,
}
impl SocketListRef {
pub(crate) fn find(&self, sock: UdpServer) -> Option<Arc<UdpSocket>> {
match self
.list
.iter()
.find(|&(s, _)| s.local_addr().unwrap() == sock.0)
{
Some((sock_srv, _)) => Some(sock_srv.clone()),
None => None,
}
}
} }
/// Strong typedef for a client socket address /// Strong typedef for a client socket address

View File

@ -62,7 +62,7 @@ pub(crate) struct Worker {
thread_id: ThreadTracker, thread_id: ThreadTracker,
// PERF: rand uses syscalls. how to do that async? // PERF: rand uses syscalls. how to do that async?
rand: Random, rand: Random,
stop_working: ::tokio::sync::broadcast::Receiver<bool>, stop_working: crate::StopWorkingRecvCh,
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
sockets: Vec<UdpSocket>, sockets: Vec<UdpSocket>,
queue: ::async_channel::Receiver<Work>, queue: ::async_channel::Receiver<Work>,
@ -77,7 +77,7 @@ impl Worker {
pub(crate) async fn new_and_loop( pub(crate) async fn new_and_loop(
cfg: Config, cfg: Config,
thread_id: ThreadTracker, thread_id: ThreadTracker,
stop_working: ::tokio::sync::broadcast::Receiver<bool>, stop_working: crate::StopWorkingRecvCh,
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>, socket_addrs: Vec<::std::net::SocketAddr>,
queue: ::async_channel::Receiver<Work>, queue: ::async_channel::Receiver<Work>,
@ -96,9 +96,9 @@ impl Worker {
Ok(()) Ok(())
} }
pub(crate) async fn new( pub(crate) async fn new(
cfg: Config, mut cfg: Config,
thread_id: ThreadTracker, thread_id: ThreadTracker,
stop_working: ::tokio::sync::broadcast::Receiver<bool>, stop_working: crate::StopWorkingRecvCh,
token_check: Option<Arc<Mutex<TokenChecker>>>, token_check: Option<Arc<Mutex<TokenChecker>>>,
socket_addrs: Vec<::std::net::SocketAddr>, socket_addrs: Vec<::std::net::SocketAddr>,
queue: ::async_channel::Receiver<Work>, queue: ::async_channel::Receiver<Work>,
@ -108,36 +108,43 @@ impl Worker {
// in the future we will want to have a thread-local listener too, // in the future we will want to have a thread-local listener too,
// but before that we need ebpf to pin a connection to a thread // but before that we need ebpf to pin a connection to a thread
// directly from the kernel // directly from the kernel
let socket_binding = let mut sock_set = ::tokio::task::JoinSet::new();
socket_addrs.into_iter().map(|s_addr| async move { socket_addrs.into_iter().for_each(|s_addr| {
let socket = ::tokio::spawn(connection::socket::bind_udp( sock_set.spawn(async move {
s_addr.clone(), let socket =
)) connection::socket::bind_udp(s_addr.clone()).await?;
.await??;
Ok(socket) Ok(socket)
}); });
let sockets_bind_res = });
::futures::future::join_all(socket_binding).await; // make sure we either add all of them, or none
let sockets: Result<Vec<UdpSocket>, ::std::io::Error> = let mut sockets = Vec::with_capacity(cfg.listen.len());
sockets_bind_res while let Some(join_res) = sock_set.join_next().await {
.into_iter() match join_res {
.map(|s_res| match s_res { Ok(s_res) => match s_res {
Ok(s) => Ok(s), Ok(sock) => sockets.push(sock),
Err(e) => {
::tracing::error!("Worker can't bind on socket: {}", e);
Err(e)
}
})
.collect();
let sockets = match sockets {
Ok(sockets) => sockets,
Err(e) => { Err(e) => {
::tracing::error!("Can't rebind socket");
return Err(e); return Err(e);
} }
}; },
Err(e) => return Err(e.into()),
}
}
let (queue_timeouts_send, queue_timeouts_recv) = let (queue_timeouts_send, queue_timeouts_recv) =
mpsc::unbounded_channel(); mpsc::unbounded_channel();
let mut handshakes = HandshakeTracker::new(
thread_id,
cfg.ciphers.clone(),
cfg.key_exchanges.clone(),
);
let mut keys = Vec::new();
// make sure the keys are no longer in the config
::core::mem::swap(&mut keys, &mut cfg.keys);
for k in keys.into_iter() {
handshakes.add_server(k.0, k.1);
}
Ok(Self { Ok(Self {
cfg, cfg,
thread_id, thread_id,
@ -150,13 +157,15 @@ impl Worker {
queue_timeouts_send, queue_timeouts_send,
thread_channels: Vec::new(), thread_channels: Vec::new(),
connections: ConnList::new(thread_id), connections: ConnList::new(thread_id),
handshakes: HandshakeTracker::new(thread_id), handshakes,
}) })
} }
pub(crate) async fn work_loop(&mut self) { pub(crate) async fn work_loop(&mut self) {
'mainloop: loop { 'mainloop: loop {
let work = ::tokio::select! { let work = ::tokio::select! {
_done = self.stop_working.recv() => { tell_stopped = self.stop_working.recv() => {
let _ = tell_stopped.unwrap().send(
crate::StopWorking::WorkerStopped).await;
break; break;
} }
maybe_timeout = self.queue.recv() => { maybe_timeout = self.queue.recv() => {
@ -326,7 +335,7 @@ impl Worker {
conn.id_recv = auth_recv_id; conn.id_recv = auth_recv_id;
let (client_key_id, hshake) = match self let (client_key_id, hshake) = match self
.handshakes .handshakes
.new_client( .add_client(
PrivKey::Exchange(priv_key), PrivKey::Exchange(priv_key),
PubKey::Exchange(pub_key), PubKey::Exchange(pub_key),
conn_info.service_id, conn_info.service_id,
@ -617,7 +626,7 @@ impl Worker {
IDSend(resp_data.service_connection_id); IDSend(resp_data.service_connection_id);
let _ = self.connections.track(service_connection.into()); let _ = self.connections.track(service_connection.into());
} }
HandshakeAction::Nonthing => {} HandshakeAction::Nothing => {}
}; };
} }
} }

View File

@ -69,6 +69,17 @@ pub enum Error {
Encrypt(enc::Error), Encrypt(enc::Error),
} }
pub(crate) enum StopWorking {
WorkerStopped,
ListenerStopped,
}
pub(crate) type StopWorkingSendCh =
::tokio::sync::broadcast::Sender<::tokio::sync::mpsc::Sender<StopWorking>>;
pub(crate) type StopWorkingRecvCh = ::tokio::sync::broadcast::Receiver<
::tokio::sync::mpsc::Sender<StopWorking>,
>;
/// Instance of a fenrir endpoint /// Instance of a fenrir endpoint
#[allow(missing_copy_implementations, missing_debug_implementations)] #[allow(missing_copy_implementations, missing_debug_implementations)]
pub struct Fenrir { pub struct Fenrir {
@ -79,7 +90,7 @@ pub struct Fenrir {
/// DNSSEC resolver, with failovers /// DNSSEC resolver, with failovers
dnssec: dnssec::Dnssec, dnssec: dnssec::Dnssec,
/// Broadcast channel to tell workers to stop working /// Broadcast channel to tell workers to stop working
stop_working: ::tokio::sync::broadcast::Sender<bool>, stop_working: StopWorkingSendCh,
/// where to ask for token check /// where to ask for token check
token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>, token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>,
/// tracks the connections to authentication servers /// tracks the connections to authentication servers
@ -89,22 +100,74 @@ pub struct Fenrir {
// manner // manner
_thread_pool: Vec<::std::thread::JoinHandle<()>>, _thread_pool: Vec<::std::thread::JoinHandle<()>>,
_thread_work: Arc<Vec<::async_channel::Sender<Work>>>, _thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
// This can be different from cfg.listen since using port 0 will result
// in a random port assigned by the operative system
_listen_addrs: Vec<::std::net::SocketAddr>,
} }
// TODO: graceful vs immediate stop // TODO: graceful vs immediate stop
impl Drop for Fenrir { impl Drop for Fenrir {
fn drop(&mut self) { fn drop(&mut self) {
self.stop_sync() ::tracing::debug!(
"Fenrir fast shutdown.\
Some threads might remain a bit longer"
);
let _ = self.stop_sync();
} }
} }
impl Fenrir { impl Fenrir {
/// Gracefully stop all listeners and workers
/// only return when all resources have been deallocated
pub async fn graceful_stop(mut self) {
::tracing::debug!("Fenrir full shut down");
if let Some((ch, listeners, workers)) = self.stop_sync() {
self.stop_wait(ch, listeners, workers).await;
}
}
fn stop_sync(
&mut self,
) -> Option<(::tokio::sync::mpsc::Receiver<StopWorking>, usize, usize)>
{
let listeners_num = self.sockets.list.len();
let workers_num = self._thread_work.len();
if self.sockets.list.len() > 0 || self._thread_work.len() > 0 {
let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4);
let _ = self.stop_working.send(ch_send);
let _ = self.sockets.rm_all();
self._thread_pool.clear();
Some((ch_recv, listeners_num, workers_num))
} else {
None
}
}
async fn stop_wait(
&mut self,
mut ch: ::tokio::sync::mpsc::Receiver<StopWorking>,
mut listeners_num: usize,
mut workers_num: usize,
) {
while listeners_num > 0 && workers_num > 0 {
match ch.recv().await {
Some(stopped) => match stopped {
StopWorking::WorkerStopped => workers_num = workers_num - 1,
StopWorking::ListenerStopped => {
listeners_num = listeners_num - 1
}
},
_ => break,
}
}
}
/// Create a new Fenrir endpoint /// Create a new Fenrir endpoint
pub fn new(config: &Config) -> Result<Self, Error> { /// spawn threads pinned to cpus in our own way with tokio's runtime
pub async fn with_threads(
config: &Config,
tokio_rt: Arc<::tokio::runtime::Runtime>,
) -> Result<Self, Error> {
let (sender, _) = ::tokio::sync::broadcast::channel(1); let (sender, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?; let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
let endpoint = Fenrir { let mut endpoint = Self {
cfg: config.clone(), cfg: config.clone(),
sockets: SocketList::new(), sockets: SocketList::new(),
dnssec, dnssec,
@ -113,86 +176,120 @@ impl Fenrir {
conn_auth_srv: Mutex::new(AuthServerConnections::new()), conn_auth_srv: Mutex::new(AuthServerConnections::new()),
_thread_pool: Vec::new(), _thread_pool: Vec::new(),
_thread_work: Arc::new(Vec::new()), _thread_work: Arc::new(Vec::new()),
_listen_addrs: Vec::with_capacity(config.listen.len()),
}; };
endpoint.start_work_threads_pinned(tokio_rt).await?;
match endpoint.add_sockets().await {
Ok(addrs) => endpoint._listen_addrs = addrs,
Err(e) => return Err(e.into()),
}
Ok(endpoint) Ok(endpoint)
} }
/// Create a new Fenrir endpoint
///FIXME: remove this, move into new() /// Get the workers that you can use in a tokio LocalSet
/// Start all workers, listeners /// You should:
pub async fn start( /// * move these workers each in its own thread
&mut self, /// * make sure that the threads are pinned on the cpu
tokio_rt: Arc<::tokio::runtime::Runtime>, pub async fn with_workers(
) -> Result<(), Error> { config: &Config,
self.start_work_threads_pinned(tokio_rt).await?; ) -> Result<
if let Err(e) = self.add_sockets().await { (
self.stop().await; Self,
return Err(e.into()); Vec<impl futures::Future<Output = Result<(), std::io::Error>>>,
),
Error,
> {
let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
let cfg = config.clone();
let sockets = SocketList::new();
let conn_auth_srv = Mutex::new(AuthServerConnections::new());
let thread_pool = Vec::new();
let thread_work = Arc::new(Vec::new());
let listen_addrs = Vec::with_capacity(config.listen.len());
let mut endpoint = Self {
cfg,
sockets,
dnssec,
stop_working: stop_working.clone(),
token_check: None,
conn_auth_srv,
_thread_pool: thread_pool,
_thread_work: thread_work,
_listen_addrs: listen_addrs,
};
let worker_num = config.threads.unwrap().get();
let mut workers = Vec::with_capacity(worker_num);
for _ in 0..worker_num {
workers.push(endpoint.start_single_worker().await?);
} }
Ok(()) match endpoint.add_sockets().await {
Ok(addrs) => endpoint._listen_addrs = addrs,
Err(e) => return Err(e.into()),
} }
///FIXME: remove this, move into new() Ok((endpoint, workers))
pub async fn setup_no_workers(&mut self) -> Result<(), Error> {
if let Err(e) = self.add_sockets().await {
self.stop().await;
return Err(e.into());
} }
Ok(()) /// Returns the list of the actual addresses we are listening on
/// Note that this can be different from what was configured:
/// if you specified UDP port 0 a random one has been assigned to you
/// by the operating system.
pub fn addresses(&self) -> Vec<::std::net::SocketAddr> {
self._listen_addrs.clone()
} }
/// Stop all workers, listeners // only call **after** starting all threads
/// asyncronous version for Drop
fn stop_sync(&mut self) {
let _ = self.stop_working.send(true);
let toempty_sockets = self.sockets.rm_all();
let task = ::tokio::task::spawn(toempty_sockets.stop_all());
let _ = ::futures::executor::block_on(task);
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
let _ = old_thread_pool.into_iter().map(|th| th.join());
}
/// Stop all workers, listeners
pub async fn stop(&mut self) {
let _ = self.stop_working.send(true);
let toempty_sockets = self.sockets.rm_all();
toempty_sockets.stop_all().await;
let mut old_thread_pool = Vec::new();
::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
let _ = old_thread_pool.into_iter().map(|th| th.join());
}
/// Add all UDP sockets found in config /// Add all UDP sockets found in config
/// and start listening for packets /// and start listening for packets
async fn add_sockets(&self) -> ::std::io::Result<()> { async fn add_sockets(
let sockets = self.cfg.listen.iter().map(|s_addr| async { &mut self,
let socket = ) -> ::std::io::Result<Vec<::std::net::SocketAddr>> {
::tokio::spawn(connection::socket::bind_udp(s_addr.clone())) // try to bind multiple sockets in parallel
.await??; let mut sock_set = ::tokio::task::JoinSet::new();
Ok(socket) self.cfg.listen.iter().for_each(|s_addr| {
}); let socket_address = s_addr.clone();
let sockets = ::futures::future::join_all(sockets).await;
for s_res in sockets.into_iter() {
match s_res {
Ok(s) => {
let stop_working = self.stop_working.subscribe(); let stop_working = self.stop_working.subscribe();
let th_work = self._thread_work.clone();
sock_set.spawn(async move {
let s = connection::socket::bind_udp(socket_address).await?;
let arc_s = Arc::new(s); let arc_s = Arc::new(s);
let join = ::tokio::spawn(Self::listen_udp( let join = ::tokio::spawn(Self::listen_udp(
stop_working, stop_working,
self._thread_work.clone(), th_work,
arc_s.clone(), arc_s.clone(),
)); ));
self.sockets.add_socket(arc_s, join).await; Ok((arc_s, join))
});
});
// make sure we either add all of them, or none
let mut all_socks = Vec::with_capacity(self.cfg.listen.len());
while let Some(join_res) = sock_set.join_next().await {
match join_res {
Ok(s_res) => match s_res {
Ok(s) => {
all_socks.push(s);
} }
Err(e) => { Err(e) => {
return Err(e); return Err(e);
} }
},
Err(e) => {
return Err(e.into());
} }
} }
Ok(()) }
let mut ret = Vec::with_capacity(self.cfg.listen.len());
for (arc_s, join) in all_socks.into_iter() {
ret.push(arc_s.local_addr().unwrap());
self.sockets.add_socket(arc_s, join).await;
}
Ok(ret)
} }
/// Run a dedicated loop to read packets on the listening socket /// Run a dedicated loop to read packets on the listening socket
async fn listen_udp( async fn listen_udp(
mut stop_working: ::tokio::sync::broadcast::Receiver<bool>, mut stop_working: StopWorkingRecvCh,
work_queues: Arc<Vec<::async_channel::Sender<Work>>>, work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
) -> ::std::io::Result<()> { ) -> ::std::io::Result<()> {
@ -202,8 +299,11 @@ impl Fenrir {
let queues_num = work_queues.len() as u64; let queues_num = work_queues.len() as u64;
loop { loop {
let (bytes, sock_sender) = ::tokio::select! { let (bytes, sock_sender) = ::tokio::select! {
_done = stop_working.recv() => { tell_stopped = stop_working.recv() => {
break; drop(socket);
let _ = tell_stopped.unwrap()
.send(StopWorking::ListenerStopped).await;
return Ok(());
} }
result = socket.recv_from(&mut buffer) => { result = socket.recv_from(&mut buffer) => {
result? result?
@ -241,7 +341,6 @@ impl Fenrir {
})) }))
.await; .await;
} }
Ok(())
} }
/// Get the raw TXT record of a Fenrir domain /// Get the raw TXT record of a Fenrir domain
pub async fn resolv_txt(&self, domain: &Domain) -> Result<String, Error> { pub async fn resolv_txt(&self, domain: &Domain) -> Result<String, Error> {
@ -373,6 +472,7 @@ impl Fenrir {
} }
} }
// needs to be called before add_sockets
async fn start_single_worker( async fn start_single_worker(
&mut self, &mut self,
) -> ::std::result::Result< ) -> ::std::result::Result<
@ -404,6 +504,10 @@ impl Fenrir {
self.cfg.listen.clone(), self.cfg.listen.clone(),
work_recv, work_recv,
); );
// don't keep around private keys too much
if (thread_idx + 1) == max_threads {
self.cfg.keys.clear();
}
loop { loop {
let queues_lock = match Arc::get_mut(&mut self._thread_work) { let queues_lock = match Arc::get_mut(&mut self._thread_work) {
Some(queues_lock) => queues_lock, Some(queues_lock) => queues_lock,
@ -421,7 +525,8 @@ impl Fenrir {
} }
Ok(worker) Ok(worker)
} }
// TODO: start work on a LocalSet provided by the user
// needs to be called before add_sockets
/// Start one working thread for each physical cpu /// Start one working thread for each physical cpu
/// threads are pinned to each cpu core. /// threads are pinned to each cpu core.
/// Work will be divided and rerouted so that there is no need to lock /// Work will be divided and rerouted so that there is no need to lock
@ -521,6 +626,8 @@ impl Fenrir {
} }
self._thread_pool.push(join_handle); self._thread_pool.push(join_handle);
} }
// don't keep around private keys too much
self.cfg.keys.clear();
Ok(()) Ok(())
} }
} }

View File

@ -1,8 +1,8 @@
use crate::*; use crate::*;
#[::tracing_test::traced_test]
#[::tokio::test] #[::tokio::test]
async fn test_connection_dirsync() { async fn test_connection_dirsync() {
return;
use enc::asym::{KeyID, PrivKey, PubKey}; use enc::asym::{KeyID, PrivKey, PubKey};
let rand = enc::Random::new(); let rand = enc::Random::new();
let (priv_exchange_key, pub_exchange_key) = let (priv_exchange_key, pub_exchange_key) =
@ -16,22 +16,6 @@ async fn test_connection_dirsync() {
return; return;
} }
}; };
let dnssec_record = Record {
public_keys: [(KeyID(42), pub_exchange_key)].to_vec(),
addresses: [record::Address {
ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)),
port: Some(::core::num::NonZeroU16::new(31337).unwrap()),
priority: record::AddressPriority::P1,
weight: record::AddressWeight::W1,
handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(),
public_key_idx: [record::PubKeyIdx(0)].to_vec(),
}]
.to_vec(),
key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman]
.to_vec(),
hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(),
ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
};
let cfg_client = { let cfg_client = {
let mut cfg = config::Config::default(); let mut cfg = config::Config::default();
cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap()); cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap());
@ -43,21 +27,46 @@ async fn test_connection_dirsync() {
cfg cfg
}; };
let mut server = Fenrir::new(&cfg_server).unwrap(); let (server, mut srv_workers) =
let _ = server.setup_no_workers().await; Fenrir::with_workers(&cfg_server).await.unwrap();
let srv_worker = server.start_single_worker().await;
::tokio::task::spawn_local(async move { srv_worker }); let srv_worker = srv_workers.pop().unwrap();
let mut client = Fenrir::new(&cfg_client).unwrap(); let local_thread = ::tokio::task::LocalSet::new();
let _ = client.setup_no_workers().await; local_thread.spawn_local(async move { srv_worker.await });
let cli_worker = server.start_single_worker().await;
::tokio::task::spawn_local(async move { cli_worker }); let (client, mut cli_workers) =
Fenrir::with_workers(&cfg_client).await.unwrap();
let cli_worker = cli_workers.pop().unwrap();
local_thread.spawn_local(async move { cli_worker.await });
use crate::{ use crate::{
connection::handshake::HandshakeID, connection::handshake::HandshakeID,
dnssec::{record, Record}, dnssec::{record, Record},
}; };
let port: u16 = server.addresses()[0].port();
let dnssec_record = Record {
public_keys: [(KeyID(42), pub_exchange_key)].to_vec(),
addresses: [record::Address {
ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)),
port: Some(::core::num::NonZeroU16::new(port).unwrap()),
priority: record::AddressPriority::P1,
weight: record::AddressWeight::W1,
handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(),
public_key_idx: [record::PubKeyIdx(0)].to_vec(),
}]
.to_vec(),
key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman]
.to_vec(),
hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(),
ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
};
server.graceful_stop().await;
client.graceful_stop().await;
return;
let _ = client let _ = client
.connect_resolved( .connect_resolved(
dnssec_record, dnssec_record,
@ -65,62 +74,6 @@ async fn test_connection_dirsync() {
auth::SERVICEID_AUTH, auth::SERVICEID_AUTH,
) )
.await; .await;
server.graceful_stop().await;
/* client.graceful_stop().await;
let thread_id = ThreadTracker { total: 1, id: 0 };
let (stop_sender, _) = ::tokio::sync::broadcast::channel::<bool>(1);
use ::std::net;
let cli_socket_addr = [net::SocketAddr::new(
net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)),
0,
)]
.to_vec();
let srv_socket_addr = [net::SocketAddr::new(
net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)),
0,
)]
.to_vec();
let srv_sock = Arc::new(connection::socket::bind_udp(srv_socket_addr[0])
.await
.unwrap());
let cli_sock = Arc::new(connection::socket::bind_udp(cli_socket_addr[0])
.await
.unwrap());
use crate::inner::worker::Work;
let (srv_work_send, srv_work_recv) = ::async_channel::unbounded::<Work>();
let (cli_work_send, cli_work_recv) = ::async_channel::unbounded::<Work>();
let srv_queue = Arc::new([srv_work_recv.clone()].to_vec());
let cli_queue = Arc::new([cli_work_recv.clone()].to_vec());
let listen_work_srv =
::tokio::spawn(Fenrir::listen_udp(
stop_sender.subscribe(),
let _server = crate::inner::worker::Worker::new(
cfg.clone(),
thread_id,
stop_sender.subscribe(),
None,
srv_socket_addr,
srv_work_recv,
);
let _client = crate::inner::worker::Worker::new(
cfg,
thread_id,
stop_sender.subscribe(),
None,
cli_socket_addr,
cli_work_recv,
);
todo!()
*/
} }