Test and fix shutdowns
We now have two shutdown paths: the quick but partial shutdown, which lets the async tasks keep working in the background and stop on their own a little later, and the graceful/full shutdown, which waits for everything. Unfortunately `Drop` can't manage async code: it blocks and has no way to yield to the runtime, so with a single-threaded runtime we would deadlock if we tried to stop things gracefully from inside it.

Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
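To illustrate the deadlock described above, a minimal sketch (illustrative names only, not code from this repository):

    // Why Drop + async deadlocks on a single-threaded runtime (sketch).
    struct Endpoint {
        stop: tokio::sync::broadcast::Sender<bool>,
    }

    impl Drop for Endpoint {
        fn drop(&mut self) {
            // Fire-and-forget is all we can do here:
            let _ = self.stop.send(true);
            // We cannot `.await` in Drop, and blocking the only runtime
            // thread (e.g. with futures::executor::block_on) would keep the
            // very tasks we'd be waiting for from ever being polled.
        }
    }

Hence the split below: `Drop` only signals (quick shutdown), while a separate async method waits for acknowledgements (graceful shutdown).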
parent aff1c313f5
commit b682068dca

Cargo.toml | 12
@@ -50,6 +50,7 @@ tokio = { version = "1", features = ["full"] }
 # PERF: todo linux-only, behind "iouring" feature
 #tokio-uring = { version = "0.4" }
 tracing = { version = "0.1" }
+tracing-test = { version = "0.2" }
 trust-dns-resolver = { version = "0.22", features = [ "dnssec-ring" ] }
 trust-dns-client = { version = "0.22", features = [ "dnssec" ] }
 trust-dns-proto = { version = "0.22" }
@@ -72,3 +73,14 @@ incremental = true
 codegen-units = 256
 rpath = false
+
+[profile.test]
+opt-level = 0
+debug = true
+debug-assertions = true
+overflow-checks = true
+lto = false
+panic = 'unwind'
+incremental = true
+codegen-units = 256
+rpath = false
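The new `tracing-test` dev-dependency provides the `#[traced_test]` attribute used in src/tests.rs below; a hypothetical usage sketch (the `logs_contain` helper is injected into scope by the macro):

    #[::tracing_test::traced_test]
    #[::tokio::test]
    async fn shutdown_logs() {
        ::tracing::debug!("Fenrir full shut down");
        assert!(logs_contain("full shut down"));
    }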

@@ -141,7 +141,7 @@ pub(crate) struct ClientConnectInfo {
 #[derive(Debug)]
 pub(crate) enum HandshakeAction {
     /// Parsing finished, all ok, nothing to do
-    Nonthing,
+    Nothing,
     /// Packet parsed, now go perform authentication
     AuthNeeded(AuthNeededInfo),
     /// the client can fully establish a connection with this info
@@ -155,7 +155,7 @@ pub(crate) enum HandshakeAction {
 /// core = (udp_src_sender_port % total_threads) - 1
 pub(crate) struct HandshakeTracker {
     thread_id: ThreadTracker,
-    key_exchanges: Vec<(asym::KeyKind, asym::KeyExchangeKind)>,
+    key_exchanges: Vec<asym::KeyExchangeKind>,
     ciphers: Vec<CipherKind>,
     /// ephemeral keys used server side in key exchange
     keys_srv: Vec<HandshakeServer>,
@@ -164,16 +164,24 @@ pub(crate) struct HandshakeTracker {
 }

 impl HandshakeTracker {
-    pub(crate) fn new(thread_id: ThreadTracker) -> Self {
+    pub(crate) fn new(
+        thread_id: ThreadTracker,
+        ciphers: Vec<CipherKind>,
+        key_exchanges: Vec<asym::KeyExchangeKind>,
+    ) -> Self {
         Self {
             thread_id,
-            ciphers: Vec::new(),
-            key_exchanges: Vec::new(),
+            ciphers,
+            key_exchanges,
             keys_srv: Vec::new(),
             hshake_cli: HandshakeClientList::new(),
         }
     }
-    pub(crate) fn new_client(
+    pub(crate) fn add_server(&mut self, id: KeyID, key: PrivKey) {
+        self.keys_srv.push(HandshakeServer { id, key });
+        self.keys_srv.sort_by(|h_a, h_b| h_a.id.0.cmp(&h_b.id.0));
+    }
+    pub(crate) fn add_client(
         &mut self,
         priv_key: PrivKey,
         pub_key: PubKey,
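`add_server` keeps `keys_srv` sorted by key id on every insert. A sketch of what that buys, with stand-in types (assumption: lookups may later move from the current linear scan to a binary search):

    // Stand-in types, for illustration only.
    #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
    struct KeyID(u64);
    struct HandshakeServer { id: KeyID }

    fn add_server(keys: &mut Vec<HandshakeServer>, id: KeyID) {
        keys.push(HandshakeServer { id });
        // Sorting on insert keeps the vector ready for binary_search.
        keys.sort_by(|h_a, h_b| h_a.id.cmp(&h_b.id));
    }

    fn find(keys: &[HandshakeServer], id: KeyID) -> Option<&HandshakeServer> {
        keys.binary_search_by(|h| h.id.cmp(&id)).ok().map(|i| &keys[i])
    }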
@@ -208,45 +216,34 @@ impl HandshakeTracker {
         match handshake.data {
             HandshakeData::DirSync(ref mut ds) => match ds {
                 DirSync::Req(ref mut req) => {
-                    let ephemeral_key = {
-                        if let Some(h_k) =
-                            self.keys_srv.iter().find(|k| k.id == req.key_id)
-                        {
+                    if !self.key_exchanges.contains(&req.exchange) {
+                        return Err(enc::Error::UnsupportedKeyExchange.into());
+                    }
+                    if !self.ciphers.contains(&req.cipher) {
+                        return Err(enc::Error::UnsupportedCipher.into());
+                    }
+                    let has_key = self.keys_srv.iter().find(|k| {
+                        if k.id == req.key_id {
                             // Directory synchronized can only use keys
                             // for key exchange, not signing keys
-                            if let PrivKey::Exchange(k) = &h_k.key {
-                                Some(k.clone())
-                            } else {
-                                None
+                            if let PrivKey::Exchange(_) = k.key {
+                                return true;
                             }
-                        } else {
-                            None
                         }
-                    };
-                    if ephemeral_key.is_none() {
-                        ::tracing::debug!(
-                            "No such server key id: {:?}",
-                            req.key_id
-                        );
-                        return Err(handshake::Error::UnknownKeyID.into());
-                    }
-                    let ephemeral_key = ephemeral_key.unwrap();
-                    {
-                        if None
-                            == self.key_exchanges.iter().find(|&x| {
-                                *x == (ephemeral_key.kind(), req.exchange)
-                            })
-                        {
-                            return Err(
-                                enc::Error::UnsupportedKeyExchange.into()
-                            );
+                        false
+                    });
+
+                    let ephemeral_key;
+                    match has_key {
+                        Some(s_k) => {
+                            if let PrivKey::Exchange(ref k) = &s_k.key {
+                                ephemeral_key = k;
+                            } else {
+                                unreachable!();
+                            }
                         }
-                    }
-                    {
-                        if None
-                            == self.ciphers.iter().find(|&x| *x == req.cipher)
-                        {
-                            return Err(enc::Error::UnsupportedCipher.into());
+                        None => {
+                            return Err(handshake::Error::UnknownKeyID.into())
                         }
                     }
                     let shared_key = match ephemeral_key
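The rewritten request path above checks the advertised algorithms with `Vec::contains` before touching any key material; a condensed, self-contained sketch of that flow (simplified types):

    #[derive(PartialEq)]
    enum CipherKind { XChaCha20Poly1305 }
    #[derive(PartialEq)]
    enum KeyExchangeKind { X25519DiffieHellman }

    struct Req { cipher: CipherKind, exchange: KeyExchangeKind }

    fn validate(
        req: &Req,
        ciphers: &[CipherKind],
        exchanges: &[KeyExchangeKind],
    ) -> Result<(), &'static str> {
        // Reject unsupported algorithms first: cheapest check, no key lookup.
        if !exchanges.contains(&req.exchange) {
            return Err("unsupported key exchange");
        }
        if !ciphers.contains(&req.cipher) {
            return Err("unsupported cipher");
        }
        Ok(())
    }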

@@ -1,6 +1,5 @@
 //! Socket related types and functions

-use ::arc_swap::ArcSwap;
 use ::std::{net::SocketAddr, sync::Arc, vec::Vec};
 use ::tokio::{net::UdpSocket, task::JoinHandle};

@@ -10,82 +9,31 @@ pub type SocketTracker =

 /// async free socket list
 pub(crate) struct SocketList {
-    pub list: ArcSwap<Vec<SocketTracker>>,
+    pub list: Vec<SocketTracker>,
 }
 impl SocketList {
     pub(crate) fn new() -> Self {
-        Self {
-            list: ArcSwap::new(Arc::new(Vec::new())),
-        }
+        Self { list: Vec::new() }
     }
     // TODO: fn rm_socket()
-    pub(crate) fn rm_all(&self) -> Self {
-        let new_list = Arc::new(Vec::new());
-        let old_list = self.list.swap(new_list);
-        Self {
-            list: old_list.into(),
-        }
+    pub(crate) fn rm_all(&mut self) -> Self {
+        let mut old_list = Vec::new();
+        ::core::mem::swap(&mut self.list, &mut old_list);
+        Self { list: old_list }
     }
     pub(crate) async fn add_socket(
-        &self,
+        &mut self,
         socket: Arc<UdpSocket>,
         handle: JoinHandle<::std::io::Result<()>>,
     ) {
-        // we could simplify this into just a `.swap` instead of `.rcu` but
-        // it is not yet guaranteed that only one thread will call this fn
-        // ...we don't need performance here anyway
         let arc_handle = Arc::new(handle);
-        self.list.rcu(|old_list| {
-            let mut new_list = Arc::new(Vec::with_capacity(old_list.len() + 1));
-            new_list = old_list.to_vec().into();
-            Arc::get_mut(&mut new_list)
-                .unwrap()
-                .push((socket.clone(), arc_handle.clone()));
-            new_list
-        });
+        self.list.push((socket, arc_handle));
     }
     /// This method assumes no other `add_sockets` are being run
     pub(crate) async fn stop_all(self) {
-        let mut arc_list = self.list.into_inner();
-        let list = loop {
-            match Arc::try_unwrap(arc_list) {
-                Ok(list) => break list,
-                Err(arc_retry) => {
-                    arc_list = arc_retry;
-                    ::tokio::time::sleep(::core::time::Duration::from_millis(
-                        50,
-                    ))
-                    .await;
-                }
-            }
-        };
-        for (_socket, mut handle) in list.into_iter() {
+        for (_socket, mut handle) in self.list.into_iter() {
             let _ = Arc::get_mut(&mut handle).unwrap().await;
         }
     }
-    pub(crate) fn lock(&self) -> SocketListRef {
-        SocketListRef {
-            list: self.list.load_full(),
-        }
-    }
 }

-/// Reference to a locked SocketList
-pub(crate) struct SocketListRef {
-    list: Arc<Vec<SocketTracker>>,
-}
-impl SocketListRef {
-    pub(crate) fn find(&self, sock: UdpServer) -> Option<Arc<UdpSocket>> {
-        match self
-            .list
-            .iter()
-            .find(|&(s, _)| s.local_addr().unwrap() == sock.0)
-        {
-            Some((sock_srv, _)) => Some(sock_srv.clone()),
-            None => None,
-        }
-    }
-}
+// TODO: impl Drop for SocketList

 /// Strong typedef for a client socket address
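With `ArcSwap` gone, `rm_all` is just a move; a self-contained sketch of the pattern (placeholder element type):

    struct SocketList { list: Vec<u32> } // u32 stands in for SocketTracker

    impl SocketList {
        fn rm_all(&mut self) -> SocketList {
            let mut old_list = Vec::new();
            ::core::mem::swap(&mut self.list, &mut old_list);
            // The caller now owns every tracker outright and can await the
            // join handles without the Arc::try_unwrap retry loop the old
            // ArcSwap-based version needed.
            SocketList { list: old_list }
        }
    }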

@@ -62,7 +62,7 @@ pub(crate) struct Worker {
     thread_id: ThreadTracker,
     // PERF: rand uses syscalls. how to do that async?
     rand: Random,
-    stop_working: ::tokio::sync::broadcast::Receiver<bool>,
+    stop_working: crate::StopWorkingRecvCh,
     token_check: Option<Arc<Mutex<TokenChecker>>>,
     sockets: Vec<UdpSocket>,
     queue: ::async_channel::Receiver<Work>,
@@ -77,7 +77,7 @@ impl Worker {
     pub(crate) async fn new_and_loop(
         cfg: Config,
         thread_id: ThreadTracker,
-        stop_working: ::tokio::sync::broadcast::Receiver<bool>,
+        stop_working: crate::StopWorkingRecvCh,
         token_check: Option<Arc<Mutex<TokenChecker>>>,
         socket_addrs: Vec<::std::net::SocketAddr>,
         queue: ::async_channel::Receiver<Work>,
@@ -96,9 +96,9 @@ impl Worker {
         Ok(())
     }
     pub(crate) async fn new(
-        cfg: Config,
+        mut cfg: Config,
         thread_id: ThreadTracker,
-        stop_working: ::tokio::sync::broadcast::Receiver<bool>,
+        stop_working: crate::StopWorkingRecvCh,
         token_check: Option<Arc<Mutex<TokenChecker>>>,
         socket_addrs: Vec<::std::net::SocketAddr>,
         queue: ::async_channel::Receiver<Work>,
@@ -108,36 +108,43 @@ impl Worker {
         // in the future we will want to have a thread-local listener too,
         // but before that we need ebpf to pin a connection to a thread
        // directly from the kernel
-        let socket_binding =
-            socket_addrs.into_iter().map(|s_addr| async move {
-                let socket = ::tokio::spawn(connection::socket::bind_udp(
-                    s_addr.clone(),
-                ))
-                .await??;
+        let mut sock_set = ::tokio::task::JoinSet::new();
+        socket_addrs.into_iter().for_each(|s_addr| {
+            sock_set.spawn(async move {
+                let socket =
+                    connection::socket::bind_udp(s_addr.clone()).await?;
                 Ok(socket)
             });
-        let sockets_bind_res =
-            ::futures::future::join_all(socket_binding).await;
-        let sockets: Result<Vec<UdpSocket>, ::std::io::Error> =
-            sockets_bind_res
-                .into_iter()
-                .map(|s_res| match s_res {
-                    Ok(s) => Ok(s),
+        });
+        // make sure we either add all of them, or none
+        let mut sockets = Vec::with_capacity(cfg.listen.len());
+        while let Some(join_res) = sock_set.join_next().await {
+            match join_res {
+                Ok(s_res) => match s_res {
+                    Ok(sock) => sockets.push(sock),
                     Err(e) => {
-                        ::tracing::error!("Worker can't bind on socket: {}", e);
-                        Err(e)
+                        ::tracing::error!("Can't rebind socket");
+                        return Err(e);
                     }
-                })
-                .collect();
-        let sockets = match sockets {
-            Ok(sockets) => sockets,
-            Err(e) => {
-                return Err(e);
-            }
-        };
+                },
+                Err(e) => return Err(e.into()),
+            }
+        }

         let (queue_timeouts_send, queue_timeouts_recv) =
             mpsc::unbounded_channel();
+        let mut handshakes = HandshakeTracker::new(
+            thread_id,
+            cfg.ciphers.clone(),
+            cfg.key_exchanges.clone(),
+        );
+        let mut keys = Vec::new();
+        // make sure the keys are no longer in the config
+        ::core::mem::swap(&mut keys, &mut cfg.keys);
+        for k in keys.into_iter() {
+            handshakes.add_server(k.0, k.1);
+        }
+
         Ok(Self {
             cfg,
             thread_id,
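A runnable sketch of the `JoinSet` all-or-nothing pattern used for the binds above, with plain numbers standing in for sockets:

    use tokio::task::JoinSet;

    #[tokio::main]
    async fn main() -> Result<(), String> {
        let mut sock_set = JoinSet::new();
        for i in 0..4u32 {
            sock_set.spawn(async move { Ok::<u32, String>(i) });
        }
        // make sure we either collect all of them, or none
        let mut done = Vec::new();
        while let Some(join_res) = sock_set.join_next().await {
            match join_res {
                Ok(Ok(v)) => done.push(v),
                Ok(Err(e)) => return Err(e),         // a bind failed
                Err(e) => return Err(e.to_string()), // the task panicked
            }
        }
        println!("bound: {done:?}");
        Ok(())
    }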
@@ -150,13 +157,15 @@ impl Worker {
             queue_timeouts_send,
             thread_channels: Vec::new(),
             connections: ConnList::new(thread_id),
-            handshakes: HandshakeTracker::new(thread_id),
+            handshakes,
         })
     }
     pub(crate) async fn work_loop(&mut self) {
         'mainloop: loop {
             let work = ::tokio::select! {
-                _done = self.stop_working.recv() => {
+                tell_stopped = self.stop_working.recv() => {
+                    let _ = tell_stopped.unwrap().send(
+                        crate::StopWorking::WorkerStopped).await;
                     break;
                 }
                 maybe_timeout = self.queue.recv() => {
@@ -326,7 +335,7 @@ impl Worker {
                 conn.id_recv = auth_recv_id;
                 let (client_key_id, hshake) = match self
                     .handshakes
-                    .new_client(
+                    .add_client(
                         PrivKey::Exchange(priv_key),
                         PubKey::Exchange(pub_key),
                         conn_info.service_id,
@@ -617,7 +626,7 @@ impl Worker {
                     IDSend(resp_data.service_connection_id);
                 let _ = self.connections.track(service_connection.into());
             }
-            HandshakeAction::Nonthing => {}
+            HandshakeAction::Nothing => {}
         };
     }
 }

src/lib.rs | 249
@@ -69,6 +69,17 @@ pub enum Error {
     Encrypt(enc::Error),
 }

+pub(crate) enum StopWorking {
+    WorkerStopped,
+    ListenerStopped,
+}
+
+pub(crate) type StopWorkingSendCh =
+    ::tokio::sync::broadcast::Sender<::tokio::sync::mpsc::Sender<StopWorking>>;
+pub(crate) type StopWorkingRecvCh = ::tokio::sync::broadcast::Receiver<
+    ::tokio::sync::mpsc::Sender<StopWorking>,
+>;
+
 /// Instance of a fenrir endpoint
 #[allow(missing_copy_implementations, missing_debug_implementations)]
 pub struct Fenrir {
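These two type aliases encode the whole shutdown protocol: the broadcast payload is itself an mpsc `Sender`, so every listener and worker can acknowledge that it actually stopped. A runnable, single-worker sketch:

    use tokio::sync::{broadcast, mpsc};

    #[derive(Debug)]
    enum StopWorking { WorkerStopped }

    #[tokio::main]
    async fn main() {
        let (stop_tx, _) = broadcast::channel::<mpsc::Sender<StopWorking>>(1);
        let mut stop_rx = stop_tx.subscribe();
        let worker = tokio::spawn(async move {
            // a real worker would select! over this and its work queue
            let ack = stop_rx.recv().await.unwrap();
            let _ = ack.send(StopWorking::WorkerStopped).await;
        });
        let (ack_tx, mut ack_rx) = mpsc::channel(4);
        let _ = stop_tx.send(ack_tx);
        assert!(matches!(ack_rx.recv().await, Some(StopWorking::WorkerStopped)));
        worker.await.unwrap();
    }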
@@ -79,7 +90,7 @@ pub struct Fenrir {
     /// DNSSEC resolver, with failovers
     dnssec: dnssec::Dnssec,
     /// Broadcast channel to tell workers to stop working
-    stop_working: ::tokio::sync::broadcast::Sender<bool>,
+    stop_working: StopWorkingSendCh,
     /// where to ask for token check
     token_check: Option<Arc<::tokio::sync::Mutex<TokenChecker>>>,
     /// tracks the connections to authentication servers
@@ -89,22 +100,74 @@ pub struct Fenrir {
     // manner
     _thread_pool: Vec<::std::thread::JoinHandle<()>>,
     _thread_work: Arc<Vec<::async_channel::Sender<Work>>>,
+    // This can be different from cfg.listen since using port 0 will result
+    // in a random port assigned by the operative system
+    _listen_addrs: Vec<::std::net::SocketAddr>,
 }

-// TODO: graceful vs immediate stop
-
 impl Drop for Fenrir {
     fn drop(&mut self) {
-        self.stop_sync()
+        ::tracing::debug!(
+            "Fenrir fast shutdown.\
+            Some threads might remain a bit longer"
+        );
+        let _ = self.stop_sync();
     }
 }

 impl Fenrir {
+    /// Gracefully stop all listeners and workers
+    /// only return when all resources have been deallocated
+    pub async fn graceful_stop(mut self) {
+        ::tracing::debug!("Fenrir full shut down");
+        if let Some((ch, listeners, workers)) = self.stop_sync() {
+            self.stop_wait(ch, listeners, workers).await;
+        }
+    }
+    fn stop_sync(
+        &mut self,
+    ) -> Option<(::tokio::sync::mpsc::Receiver<StopWorking>, usize, usize)>
+    {
+        let listeners_num = self.sockets.list.len();
+        let workers_num = self._thread_work.len();
+        if self.sockets.list.len() > 0 || self._thread_work.len() > 0 {
+            let (ch_send, ch_recv) = ::tokio::sync::mpsc::channel(4);
+            let _ = self.stop_working.send(ch_send);
+            let _ = self.sockets.rm_all();
+            self._thread_pool.clear();
+            Some((ch_recv, listeners_num, workers_num))
+        } else {
+            None
+        }
+    }
+    async fn stop_wait(
+        &mut self,
+        mut ch: ::tokio::sync::mpsc::Receiver<StopWorking>,
+        mut listeners_num: usize,
+        mut workers_num: usize,
+    ) {
+        while listeners_num > 0 && workers_num > 0 {
+            match ch.recv().await {
+                Some(stopped) => match stopped {
+                    StopWorking::WorkerStopped => workers_num = workers_num - 1,
+                    StopWorking::ListenerStopped => {
+                        listeners_num = listeners_num - 1
+                    }
+                },
+                _ => break,
+            }
+        }
+    }
     /// Create a new Fenrir endpoint
-    pub fn new(config: &Config) -> Result<Self, Error> {
+    /// spawn threads pinned to cpus in our own way with tokio's runtime
+    pub async fn with_threads(
+        config: &Config,
+        tokio_rt: Arc<::tokio::runtime::Runtime>,
+    ) -> Result<Self, Error> {
         let (sender, _) = ::tokio::sync::broadcast::channel(1);
         let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
-        let endpoint = Fenrir {
+        let mut endpoint = Self {
             cfg: config.clone(),
             sockets: SocketList::new(),
             dnssec,
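The two shutdown paths side by side, condensed (hypothetical simplification of the code above):

    use tokio::sync::{broadcast, mpsc};

    struct Endpoint {
        stop_working: broadcast::Sender<mpsc::Sender<()>>,
    }

    impl Drop for Endpoint {
        // quick shutdown: signal and return, tasks finish in the background
        fn drop(&mut self) {
            let (ack_tx, _ack_rx) = mpsc::channel(4);
            let _ = self.stop_working.send(ack_tx);
        }
    }

    impl Endpoint {
        // full shutdown: consume self, then wait for every acknowledgement
        async fn graceful_stop(self, mut expected: usize) {
            let (ack_tx, mut ack_rx) = mpsc::channel(4);
            let _ = self.stop_working.send(ack_tx);
            while expected > 0 {
                if ack_rx.recv().await.is_none() { break; }
                expected -= 1;
            }
            // Drop still runs when `self` goes out of scope here, but its
            // send is harmless: the tasks have already stopped.
        }
    }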
@@ -113,86 +176,120 @@ impl Fenrir {
             conn_auth_srv: Mutex::new(AuthServerConnections::new()),
             _thread_pool: Vec::new(),
             _thread_work: Arc::new(Vec::new()),
+            _listen_addrs: Vec::with_capacity(config.listen.len()),
         };
+        endpoint.start_work_threads_pinned(tokio_rt).await?;
+        match endpoint.add_sockets().await {
+            Ok(addrs) => endpoint._listen_addrs = addrs,
+            Err(e) => return Err(e.into()),
+        }
         Ok(endpoint)
     }

-    ///FIXME: remove this, move into new()
-    /// Start all workers, listeners
-    pub async fn start(
-        &mut self,
-        tokio_rt: Arc<::tokio::runtime::Runtime>,
-    ) -> Result<(), Error> {
-        self.start_work_threads_pinned(tokio_rt).await?;
-        if let Err(e) = self.add_sockets().await {
-            self.stop().await;
-            return Err(e.into());
+    /// Create a new Fenrir endpoint
+    /// Get the workers that you can use in a tokio LocalSet
+    /// You should:
+    /// * move these workers each in its own thread
+    /// * make sure that the threads are pinned on the cpu
+    pub async fn with_workers(
+        config: &Config,
+    ) -> Result<
+        (
+            Self,
+            Vec<impl futures::Future<Output = Result<(), std::io::Error>>>,
+        ),
+        Error,
+    > {
+        let (stop_working, _) = ::tokio::sync::broadcast::channel(1);
+        let dnssec = dnssec::Dnssec::new(&config.resolvers)?;
+        let cfg = config.clone();
+        let sockets = SocketList::new();
+        let conn_auth_srv = Mutex::new(AuthServerConnections::new());
+        let thread_pool = Vec::new();
+        let thread_work = Arc::new(Vec::new());
+        let listen_addrs = Vec::with_capacity(config.listen.len());
+        let mut endpoint = Self {
+            cfg,
+            sockets,
+            dnssec,
+            stop_working: stop_working.clone(),
+            token_check: None,
+            conn_auth_srv,
+            _thread_pool: thread_pool,
+            _thread_work: thread_work,
+            _listen_addrs: listen_addrs,
+        };
+        let worker_num = config.threads.unwrap().get();
+        let mut workers = Vec::with_capacity(worker_num);
+        for _ in 0..worker_num {
+            workers.push(endpoint.start_single_worker().await?);
         }
-        Ok(())
-    }
-    ///FIXME: remove this, move into new()
-    pub async fn setup_no_workers(&mut self) -> Result<(), Error> {
-        if let Err(e) = self.add_sockets().await {
-            self.stop().await;
-            return Err(e.into());
+        match endpoint.add_sockets().await {
+            Ok(addrs) => endpoint._listen_addrs = addrs,
+            Err(e) => return Err(e.into()),
         }
-        Ok(())
+        Ok((endpoint, workers))
     }
+    /// Returns the list of the actual addresses we are listening on
+    /// Note that this can be different from what was configured:
+    /// if you specified UDP port 0 a random one has been assigned to you
+    /// by the operating system.
+    pub fn addresses(&self) -> Vec<::std::net::SocketAddr> {
+        self._listen_addrs.clone()
+    }

-    /// Stop all workers, listeners
-    /// asyncronous version for Drop
-    fn stop_sync(&mut self) {
-        let _ = self.stop_working.send(true);
-        let toempty_sockets = self.sockets.rm_all();
-        let task = ::tokio::task::spawn(toempty_sockets.stop_all());
-        let _ = ::futures::executor::block_on(task);
-        let mut old_thread_pool = Vec::new();
-        ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
-        let _ = old_thread_pool.into_iter().map(|th| th.join());
-    }
-
-    /// Stop all workers, listeners
-    pub async fn stop(&mut self) {
-        let _ = self.stop_working.send(true);
-        let toempty_sockets = self.sockets.rm_all();
-        toempty_sockets.stop_all().await;
-        let mut old_thread_pool = Vec::new();
-        ::std::mem::swap(&mut self._thread_pool, &mut old_thread_pool);
-        let _ = old_thread_pool.into_iter().map(|th| th.join());
-    }
+    // only call **after** starting all threads
     /// Add all UDP sockets found in config
     /// and start listening for packets
-    async fn add_sockets(&self) -> ::std::io::Result<()> {
-        let sockets = self.cfg.listen.iter().map(|s_addr| async {
-            let socket =
-                ::tokio::spawn(connection::socket::bind_udp(s_addr.clone()))
-                    .await??;
-            Ok(socket)
+    async fn add_sockets(
+        &mut self,
+    ) -> ::std::io::Result<Vec<::std::net::SocketAddr>> {
+        // try to bind multiple sockets in parallel
+        let mut sock_set = ::tokio::task::JoinSet::new();
+        self.cfg.listen.iter().for_each(|s_addr| {
+            let socket_address = s_addr.clone();
+            let stop_working = self.stop_working.subscribe();
+            let th_work = self._thread_work.clone();
+            sock_set.spawn(async move {
+                let s = connection::socket::bind_udp(socket_address).await?;
+                let arc_s = Arc::new(s);
+                let join = ::tokio::spawn(Self::listen_udp(
+                    stop_working,
+                    th_work,
+                    arc_s.clone(),
+                ));
+                Ok((arc_s, join))
+            });
         });
-        let sockets = ::futures::future::join_all(sockets).await;
-        for s_res in sockets.into_iter() {
-            match s_res {
-                Ok(s) => {
-                    let stop_working = self.stop_working.subscribe();
-                    let arc_s = Arc::new(s);
-                    let join = ::tokio::spawn(Self::listen_udp(
-                        stop_working,
-                        self._thread_work.clone(),
-                        arc_s.clone(),
-                    ));
-                    self.sockets.add_socket(arc_s, join).await;
-                }
+
+        // make sure we either add all of them, or none
+        let mut all_socks = Vec::with_capacity(self.cfg.listen.len());
+        while let Some(join_res) = sock_set.join_next().await {
+            match join_res {
+                Ok(s_res) => match s_res {
+                    Ok(s) => {
+                        all_socks.push(s);
+                    }
+                    Err(e) => {
+                        return Err(e);
+                    }
+                },
                 Err(e) => {
-                    return Err(e);
+                    return Err(e.into());
                 }
             }
         }
-        Ok(())
+
+        let mut ret = Vec::with_capacity(self.cfg.listen.len());
+        for (arc_s, join) in all_socks.into_iter() {
+            ret.push(arc_s.local_addr().unwrap());
+            self.sockets.add_socket(arc_s, join).await;
+        }
+        Ok(ret)
     }

     /// Run a dedicated loop to read packets on the listening socket
     async fn listen_udp(
-        mut stop_working: ::tokio::sync::broadcast::Receiver<bool>,
+        mut stop_working: StopWorkingRecvCh,
         work_queues: Arc<Vec<::async_channel::Sender<Work>>>,
         socket: Arc<UdpSocket>,
     ) -> ::std::io::Result<()> {
@@ -202,8 +299,11 @@ impl Fenrir {
         let queues_num = work_queues.len() as u64;
         loop {
             let (bytes, sock_sender) = ::tokio::select! {
-                _done = stop_working.recv() => {
-                    break;
+                tell_stopped = stop_working.recv() => {
+                    drop(socket);
+                    let _ = tell_stopped.unwrap()
+                        .send(StopWorking::ListenerStopped).await;
+                    return Ok(());
                 }
                 result = socket.recv_from(&mut buffer) => {
                     result?
@@ -241,7 +341,6 @@ impl Fenrir {
             }))
             .await;
         }
-        Ok(())
     }
     /// Get the raw TXT record of a Fenrir domain
     pub async fn resolv_txt(&self, domain: &Domain) -> Result<String, Error> {
@@ -373,6 +472,7 @@ impl Fenrir {
         }
     }

+    // needs to be called before add_sockets
     async fn start_single_worker(
         &mut self,
     ) -> ::std::result::Result<
@@ -404,6 +504,10 @@ impl Fenrir {
             self.cfg.listen.clone(),
             work_recv,
         );
+        // don't keep around private keys too much
+        if (thread_idx + 1) == max_threads {
+            self.cfg.keys.clear();
+        }
         loop {
             let queues_lock = match Arc::get_mut(&mut self._thread_work) {
                 Some(queues_lock) => queues_lock,
@@ -421,7 +525,8 @@ impl Fenrir {
         }
         Ok(worker)
     }
-    // TODO: start work on a LocalSet provided by the user
+
+    // needs to be called before add_sockets
     /// Start one working thread for each physical cpu
     /// threads are pinned to each cpu core.
     /// Work will be divided and rerouted so that there is no need to lock
@@ -521,6 +626,8 @@ impl Fenrir {
             }
             self._thread_pool.push(join_handle);
         }
+        // don't keep around private keys too much
+        self.cfg.keys.clear();
         Ok(())
     }
 }
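On clearing `cfg.keys`: `Vec::clear` drops the values, keeping private keys out of the long-lived endpoint struct, but it does not zero the freed memory; if that matters, zeroization would be a separate step. A sketch of the idea (illustrative types):

    struct Config { keys: Vec<[u8; 32]> }

    fn start_threads(cfg: &mut Config, workers: usize) {
        for idx in 0..workers {
            let _keys_for_worker = cfg.keys.clone(); // each worker gets a copy
            // ... spawn the worker thread here ...
            // don't keep around private keys too much
            if idx + 1 == workers {
                cfg.keys.clear();
            }
        }
    }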

src/tests.rs | 119
@@ -1,8 +1,8 @@
 use crate::*;

+#[::tracing_test::traced_test]
 #[::tokio::test]
 async fn test_connection_dirsync() {
-    return;
     use enc::asym::{KeyID, PrivKey, PubKey};
     let rand = enc::Random::new();
     let (priv_exchange_key, pub_exchange_key) =
@@ -16,22 +16,6 @@ async fn test_connection_dirsync() {
             return;
         }
     };
-    let dnssec_record = Record {
-        public_keys: [(KeyID(42), pub_exchange_key)].to_vec(),
-        addresses: [record::Address {
-            ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)),
-            port: Some(::core::num::NonZeroU16::new(31337).unwrap()),
-            priority: record::AddressPriority::P1,
-            weight: record::AddressWeight::W1,
-            handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(),
-            public_key_idx: [record::PubKeyIdx(0)].to_vec(),
-        }]
-        .to_vec(),
-        key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman]
-            .to_vec(),
-        hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(),
-        ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
-    };
     let cfg_client = {
         let mut cfg = config::Config::default();
         cfg.threads = Some(::core::num::NonZeroUsize::new(1).unwrap());
@@ -43,21 +27,46 @@ async fn test_connection_dirsync() {
         cfg
     };

-    let mut server = Fenrir::new(&cfg_server).unwrap();
-    let _ = server.setup_no_workers().await;
-    let srv_worker = server.start_single_worker().await;
-    ::tokio::task::spawn_local(async move { srv_worker });
-    let mut client = Fenrir::new(&cfg_client).unwrap();
-    let _ = client.setup_no_workers().await;
-    let cli_worker = server.start_single_worker().await;
-    ::tokio::task::spawn_local(async move { cli_worker });
+    let (server, mut srv_workers) =
+        Fenrir::with_workers(&cfg_server).await.unwrap();
+
+    let srv_worker = srv_workers.pop().unwrap();
+    let local_thread = ::tokio::task::LocalSet::new();
+    local_thread.spawn_local(async move { srv_worker.await });
+
+    let (client, mut cli_workers) =
+        Fenrir::with_workers(&cfg_client).await.unwrap();
+    let cli_worker = cli_workers.pop().unwrap();
+    local_thread.spawn_local(async move { cli_worker.await });
+
+    use crate::{
+        connection::handshake::HandshakeID,
+        dnssec::{record, Record},
+    };
+
+    let port: u16 = server.addresses()[0].port();
+
+    let dnssec_record = Record {
+        public_keys: [(KeyID(42), pub_exchange_key)].to_vec(),
+        addresses: [record::Address {
+            ip: ::std::net::IpAddr::V4(::std::net::Ipv4Addr::new(127, 0, 0, 1)),
+            port: Some(::core::num::NonZeroU16::new(port).unwrap()),
+            priority: record::AddressPriority::P1,
+            weight: record::AddressWeight::W1,
+            handshake_ids: [HandshakeID::DirectorySynchronized].to_vec(),
+            public_key_idx: [record::PubKeyIdx(0)].to_vec(),
+        }]
+        .to_vec(),
+        key_exchanges: [enc::asym::KeyExchangeKind::X25519DiffieHellman]
+            .to_vec(),
+        hkdfs: [enc::hkdf::HkdfKind::Sha3].to_vec(),
+        ciphers: [enc::sym::CipherKind::XChaCha20Poly1305].to_vec(),
+    };
+
+    server.graceful_stop().await;
+    client.graceful_stop().await;
+    return;

     let _ = client
         .connect_resolved(
             dnssec_record,
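One caveat with the test setup above: `spawn_local` only queues the future. The `LocalSet` itself still has to be driven, e.g. by awaiting it or with `run_until`; a minimal sketch:

    use tokio::task::LocalSet;

    #[tokio::main]
    async fn main() {
        let local_thread = LocalSet::new();
        local_thread.spawn_local(async {
            println!("worker runs on this same thread");
        });
        // Drive the LocalSet; without this the spawned task never executes.
        local_thread.await;
    }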
@@ -65,62 +74,6 @@ async fn test_connection_dirsync() {
             auth::SERVICEID_AUTH,
         )
         .await;
-
-    /*
-    let thread_id = ThreadTracker { total: 1, id: 0 };
-
-    let (stop_sender, _) = ::tokio::sync::broadcast::channel::<bool>(1);
-
-    use ::std::net;
-    let cli_socket_addr = [net::SocketAddr::new(
-        net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)),
-        0,
-    )]
-    .to_vec();
-    let srv_socket_addr = [net::SocketAddr::new(
-        net::IpAddr::V4(net::Ipv4Addr::new(127, 0, 0, 1)),
-        0,
-    )]
-    .to_vec();
-
-    let srv_sock = Arc::new(connection::socket::bind_udp(srv_socket_addr[0])
-        .await
-        .unwrap());
-    let cli_sock = Arc::new(connection::socket::bind_udp(cli_socket_addr[0])
-        .await
-        .unwrap());
-
-    use crate::inner::worker::Work;
-    let (srv_work_send, srv_work_recv) = ::async_channel::unbounded::<Work>();
-    let (cli_work_send, cli_work_recv) = ::async_channel::unbounded::<Work>();
-
-    let srv_queue = Arc::new([srv_work_recv.clone()].to_vec());
-    let cli_queue = Arc::new([cli_work_recv.clone()].to_vec());
-
-    let listen_work_srv =
-        ::tokio::spawn(Fenrir::listen_udp(
-            stop_sender.subscribe(),
-
-    let _server = crate::inner::worker::Worker::new(
-        cfg.clone(),
-        thread_id,
-        stop_sender.subscribe(),
-        None,
-        srv_socket_addr,
-        srv_work_recv,
-    );
-    let _client = crate::inner::worker::Worker::new(
-        cfg,
-        thread_id,
-        stop_sender.subscribe(),
-        None,
-        cli_socket_addr,
-        cli_work_recv,
-    );
-
-    todo!()
-    */
+    server.graceful_stop().await;
+    client.graceful_stop().await;
 }