SocketList: arcswap the list of SocketList
faster socket add/remove, so that we can search this list to find with wich socket we should send Signed-off-by: Luca Fulchir <luca.fulchir@runesauth.com>
This commit is contained in:
parent
c3aff3e8df
commit
0d33033c0b
|
@ -25,4 +25,7 @@ pub enum Error {
|
||||||
/// Can not decrypt. Either corrupted or malicious data
|
/// Can not decrypt. Either corrupted or malicious data
|
||||||
#[error("decrypt: corrupted data")]
|
#[error("decrypt: corrupted data")]
|
||||||
Decrypt,
|
Decrypt,
|
||||||
|
/// Can not encrypt. library failure
|
||||||
|
#[error("can't encrypt")]
|
||||||
|
Encrypt,
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,28 +151,44 @@ impl Cipher {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn overhead(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
Cipher::XChaCha20Poly1305(cipher) => {
|
||||||
|
let cipher = CipherKind::XChaCha20Poly1305;
|
||||||
|
cipher.nonce_len() + cipher.tag_len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
fn encrypt(
|
fn encrypt(
|
||||||
&self,
|
&self,
|
||||||
|
nonce: &Nonce,
|
||||||
aad: AAD,
|
aad: AAD,
|
||||||
nonce: Nonce,
|
data: &mut Data,
|
||||||
data: &mut [u8],
|
|
||||||
) -> Result<(), Error> {
|
) -> Result<(), Error> {
|
||||||
|
// No need to check for minimum buffer size since `Data` assures we
|
||||||
|
// already went through that
|
||||||
match self {
|
match self {
|
||||||
Cipher::XChaCha20Poly1305(cipher) => {
|
Cipher::XChaCha20Poly1305(cipher) => {
|
||||||
use ::chacha20poly1305::{
|
use ::chacha20poly1305::{
|
||||||
aead::generic_array::GenericArray, AeadInPlace,
|
aead::generic_array::GenericArray, AeadInPlace,
|
||||||
};
|
};
|
||||||
let min_len: usize = CipherKind::XChaCha20Poly1305.nonce_len()
|
// write nonce
|
||||||
+ CipherKind::XChaCha20Poly1305.tag_len()
|
data.get_slice_full()[..Nonce::len()]
|
||||||
+ 1;
|
.copy_from_slice(nonce.as_bytes());
|
||||||
if data.len() < min_len {
|
|
||||||
return Err(Error::InsufficientBuffer);
|
|
||||||
}
|
|
||||||
// write Nonce, then advance it
|
|
||||||
|
|
||||||
// encrypt data
|
// encrypt data
|
||||||
|
match cipher.cipher.encrypt_in_place_detached(
|
||||||
// add tag
|
nonce.as_bytes().into(),
|
||||||
|
aad.0,
|
||||||
|
data.get_slice(),
|
||||||
|
) {
|
||||||
|
Ok(tag) => {
|
||||||
|
// add tag
|
||||||
|
data.get_tag_slice().copy_from_slice(tag.as_slice());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(_) => Err(Error::Encrypt),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
todo!()
|
todo!()
|
||||||
|
@ -203,6 +219,35 @@ impl CipherRecv {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Allocate some data, with additional indexes to track
|
||||||
|
/// where nonce and tags are
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Data {
|
||||||
|
data: Vec<u8>,
|
||||||
|
skip_start: usize,
|
||||||
|
skip_end: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Data {
|
||||||
|
/// Get the slice where you will write the actual data
|
||||||
|
/// this will skip the actual nonce and AEAD tag and give you
|
||||||
|
/// only the space for the data
|
||||||
|
pub fn get_slice(&mut self) -> &mut [u8] {
|
||||||
|
&mut self.data[self.skip_start..self.skip_end]
|
||||||
|
}
|
||||||
|
fn get_tag_slice(&mut self) -> &mut [u8] {
|
||||||
|
let start = self.data.len() - self.skip_end;
|
||||||
|
&mut self.data[start..]
|
||||||
|
}
|
||||||
|
fn get_slice_full(&mut self) -> &mut [u8] {
|
||||||
|
&mut self.data
|
||||||
|
}
|
||||||
|
/// Consume the data and return the whole raw vector
|
||||||
|
pub fn get_raw(self) -> Vec<u8> {
|
||||||
|
self.data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Send only cipher
|
/// Send only cipher
|
||||||
#[allow(missing_debug_implementations)]
|
#[allow(missing_debug_implementations)]
|
||||||
pub struct CipherSend {
|
pub struct CipherSend {
|
||||||
|
@ -218,9 +263,19 @@ impl CipherSend {
|
||||||
cipher: Cipher::new(kind, secret),
|
cipher: Cipher::new(kind, secret),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/// Get the current nonce as &[u8]
|
/// Allocate the memory for the data that will be encrypted
|
||||||
pub fn nonce_as_bytes(&self) -> &[u8] {
|
pub fn make_data(&self, length: usize) -> Data {
|
||||||
self.nonce.as_bytes()
|
Data {
|
||||||
|
data: Vec::with_capacity(length + self.cipher.overhead()),
|
||||||
|
skip_start: self.cipher.nonce_len(),
|
||||||
|
skip_end: self.cipher.tag_len(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/// Encrypt the given data
|
||||||
|
pub fn encrypt(&mut self, aad: AAD, data: &mut Data) -> Result<(), Error> {
|
||||||
|
self.cipher.encrypt(&self.nonce, aad, data)?;
|
||||||
|
self.nonce.advance();
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
188
src/lib.rs
188
src/lib.rs
|
@ -179,13 +179,85 @@ type TokenChecker =
|
||||||
domain: auth::Domain,
|
domain: auth::Domain,
|
||||||
) -> ::futures::future::BoxFuture<'static, Result<bool, ()>>;
|
) -> ::futures::future::BoxFuture<'static, Result<bool, ()>>;
|
||||||
|
|
||||||
|
// PERF: move to arcswap:
|
||||||
|
// We use multiple UDP sockets.
|
||||||
|
// We need a list of them all and to scan this list.
|
||||||
|
// But we want to be able to update this list without locking everyone
|
||||||
|
// So we should wrap into ArcSwap, and use unsafe `as_raf_fd()` and
|
||||||
|
// `from_raw_fd` to update the list.
|
||||||
|
// This means that we will have multiple `UdpSocket` per actual socket
|
||||||
|
// so we have to handle `drop()` manually, and garbage-collect the ones we
|
||||||
|
// are no longer using in the background. sigh.
|
||||||
|
// Just go with a ArcSwapAny<Arc<Vec<Arc< ...
|
||||||
|
struct SocketList {
|
||||||
|
sockets:
|
||||||
|
ArcSwap<Vec<(Arc<UdpSocket>, Arc<JoinHandle<::std::io::Result<()>>>)>>,
|
||||||
|
}
|
||||||
|
impl SocketList {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sockets: ArcSwap::new(Arc::new(Vec::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: fn rm_socket()
|
||||||
|
fn rm_all(&self) -> Self {
|
||||||
|
let new_list = Arc::new(Vec::new());
|
||||||
|
let old_list = self.sockets.swap(new_list);
|
||||||
|
Self {
|
||||||
|
sockets: old_list.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn add_socket(
|
||||||
|
&self,
|
||||||
|
socket: Arc<UdpSocket>,
|
||||||
|
handle: JoinHandle<::std::io::Result<()>>,
|
||||||
|
) {
|
||||||
|
let mut new_list;
|
||||||
|
{
|
||||||
|
let old_list = self.sockets.load();
|
||||||
|
new_list = Arc::new(Vec::with_capacity(old_list.len() + 1));
|
||||||
|
new_list = old_list.to_vec().into();
|
||||||
|
}
|
||||||
|
Arc::get_mut(&mut new_list)
|
||||||
|
.unwrap()
|
||||||
|
.push((socket, Arc::new(handle)));
|
||||||
|
self.sockets.swap(new_list);
|
||||||
|
}
|
||||||
|
async fn find(&self, sock: SocketAddr) -> Option<Arc<UdpSocket>> {
|
||||||
|
let list = self.sockets.load();
|
||||||
|
match list.iter().find(|&(s, _)| s.local_addr().unwrap() == sock) {
|
||||||
|
Some((sock, _)) => Some(sock.clone()),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
async fn stop_all(mut self) {
|
||||||
|
let mut arc_list = self.sockets.into_inner();
|
||||||
|
let list = loop {
|
||||||
|
match Arc::try_unwrap(arc_list) {
|
||||||
|
Ok(list) => break list,
|
||||||
|
Err(arc_retry) => {
|
||||||
|
arc_list = arc_retry;
|
||||||
|
::tokio::time::sleep(::core::time::Duration::from_millis(
|
||||||
|
50,
|
||||||
|
))
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (_socket, mut handle) in list.into_iter() {
|
||||||
|
Arc::get_mut(&mut handle).unwrap().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Instance of a fenrir endpoint
|
/// Instance of a fenrir endpoint
|
||||||
#[allow(missing_copy_implementations, missing_debug_implementations)]
|
#[allow(missing_copy_implementations, missing_debug_implementations)]
|
||||||
pub struct Fenrir {
|
pub struct Fenrir {
|
||||||
/// library Configuration
|
/// library Configuration
|
||||||
cfg: Config,
|
cfg: Config,
|
||||||
/// listening udp sockets
|
/// listening udp sockets
|
||||||
sockets: Vec<(Arc<UdpSocket>, JoinHandle<::std::io::Result<()>>)>,
|
//sockets: Vec<(Arc<UdpSocket>, JoinHandle<::std::io::Result<()>>)>,
|
||||||
|
sockets: SocketList,
|
||||||
/// DNSSEC resolver, with failovers
|
/// DNSSEC resolver, with failovers
|
||||||
dnssec: Option<dnssec::Dnssec>,
|
dnssec: Option<dnssec::Dnssec>,
|
||||||
/// Broadcast channel to tell workers to stop working
|
/// Broadcast channel to tell workers to stop working
|
||||||
|
@ -211,7 +283,7 @@ impl Fenrir {
|
||||||
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
let (sender, _) = ::tokio::sync::broadcast::channel(1);
|
||||||
let endpoint = Fenrir {
|
let endpoint = Fenrir {
|
||||||
cfg: config.clone(),
|
cfg: config.clone(),
|
||||||
sockets: Vec::with_capacity(listen_num),
|
sockets: SocketList::new(),
|
||||||
dnssec: None,
|
dnssec: None,
|
||||||
stop_working: sender,
|
stop_working: sender,
|
||||||
_inner: Arc::new(FenrirInner {
|
_inner: Arc::new(FenrirInner {
|
||||||
|
@ -239,9 +311,11 @@ impl Fenrir {
|
||||||
/// asyncronous version for Drop
|
/// asyncronous version for Drop
|
||||||
fn stop_sync(&mut self) {
|
fn stop_sync(&mut self) {
|
||||||
let _ = self.stop_working.send(true);
|
let _ = self.stop_working.send(true);
|
||||||
let mut toempty_socket = Vec::new();
|
let mut toempty_sockets = self.sockets.rm_all();
|
||||||
::std::mem::swap(&mut self.sockets, &mut toempty_socket);
|
let task = ::tokio::task::spawn(Self::stop_sockets(toempty_sockets));
|
||||||
let task = ::tokio::task::spawn(Self::stop_sockets(toempty_socket));
|
//let mut toempty_socket = Vec::new();
|
||||||
|
//::std::mem::swap(&mut self.sockets, &mut toempty_socket);
|
||||||
|
//let task = ::tokio::task::spawn(Self::stop_sockets(toempty_socket));
|
||||||
let _ = ::futures::executor::block_on(task);
|
let _ = ::futures::executor::block_on(task);
|
||||||
self.dnssec = None;
|
self.dnssec = None;
|
||||||
}
|
}
|
||||||
|
@ -249,19 +323,19 @@ impl Fenrir {
|
||||||
/// Stop all workers, listeners
|
/// Stop all workers, listeners
|
||||||
pub async fn stop(&mut self) {
|
pub async fn stop(&mut self) {
|
||||||
let _ = self.stop_working.send(true);
|
let _ = self.stop_working.send(true);
|
||||||
let mut toempty_socket = Vec::new();
|
let mut toempty_sockets = self.sockets.rm_all();
|
||||||
::std::mem::swap(&mut self.sockets, &mut toempty_socket);
|
Self::stop_sockets(toempty_sockets).await;
|
||||||
Self::stop_sockets(toempty_socket).await;
|
|
||||||
self.dnssec = None;
|
self.dnssec = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// actually do the work of stopping resolvers and listeners
|
/// actually do the work of stopping resolvers and listeners
|
||||||
async fn stop_sockets(
|
async fn stop_sockets(sockets: SocketList) {
|
||||||
sockets: Vec<(Arc<UdpSocket>, JoinHandle<::std::io::Result<()>>)>,
|
sockets.stop_all().await;
|
||||||
) {
|
/*
|
||||||
for s in sockets.into_iter() {
|
for s in sockets.into_iter() {
|
||||||
let _ = s.1.await;
|
let _ = s.1.await;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable some common socket options. This is just the unsafe part
|
/// Enable some common socket options. This is just the unsafe part
|
||||||
|
@ -285,6 +359,37 @@ impl Fenrir {
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add all UDP sockets found in config
|
||||||
|
/// and start listening for packets
|
||||||
|
async fn add_sockets(&mut self) -> ::std::io::Result<()> {
|
||||||
|
let sockets = self.cfg.listen.iter().map(|s_addr| async {
|
||||||
|
let socket =
|
||||||
|
::tokio::spawn(Self::bind_udp(s_addr.clone())).await??;
|
||||||
|
Ok(socket)
|
||||||
|
});
|
||||||
|
let sockets = ::futures::future::join_all(sockets).await;
|
||||||
|
for s_res in sockets.into_iter() {
|
||||||
|
match s_res {
|
||||||
|
Ok(s) => {
|
||||||
|
let stop_working = self.stop_working.subscribe();
|
||||||
|
let arc_s = Arc::new(s);
|
||||||
|
let join = ::tokio::spawn(Self::listen_udp(
|
||||||
|
stop_working,
|
||||||
|
self._inner.clone(),
|
||||||
|
self.token_check.clone(),
|
||||||
|
arc_s.clone(),
|
||||||
|
));
|
||||||
|
self.sockets.add_socket(arc_s, join);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Add an async udp listener
|
/// Add an async udp listener
|
||||||
async fn bind_udp(sock: SocketAddr) -> ::std::io::Result<UdpSocket> {
|
async fn bind_udp(sock: SocketAddr) -> ::std::io::Result<UdpSocket> {
|
||||||
let socket = UdpSocket::bind(sock).await?;
|
let socket = UdpSocket::bind(sock).await?;
|
||||||
|
@ -319,9 +424,10 @@ impl Fenrir {
|
||||||
socket: Arc<UdpSocket>,
|
socket: Arc<UdpSocket>,
|
||||||
) -> ::std::io::Result<()> {
|
) -> ::std::io::Result<()> {
|
||||||
// jumbo frames are 9K max
|
// jumbo frames are 9K max
|
||||||
|
let sock_receiver = socket.local_addr()?;
|
||||||
let mut buffer: [u8; 9000] = [0; 9000];
|
let mut buffer: [u8; 9000] = [0; 9000];
|
||||||
loop {
|
loop {
|
||||||
let (bytes, sock_from) = ::tokio::select! {
|
let (bytes, sock_sender) = ::tokio::select! {
|
||||||
_done = stop_working.recv() => {
|
_done = stop_working.recv() => {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -333,42 +439,14 @@ impl Fenrir {
|
||||||
fenrir.clone(),
|
fenrir.clone(),
|
||||||
token_check.clone(),
|
token_check.clone(),
|
||||||
&buffer[0..bytes],
|
&buffer[0..bytes],
|
||||||
sock_from,
|
sock_receiver,
|
||||||
|
sock_sender,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add all UDP sockets found in config
|
|
||||||
/// and start listening for packets
|
|
||||||
async fn add_sockets(&mut self) -> ::std::io::Result<()> {
|
|
||||||
let sockets = self.cfg.listen.iter().map(|s_addr| async {
|
|
||||||
let socket =
|
|
||||||
::tokio::spawn(Self::bind_udp(s_addr.clone())).await??;
|
|
||||||
Ok(Arc::new(socket))
|
|
||||||
});
|
|
||||||
let sockets = ::futures::future::join_all(sockets).await;
|
|
||||||
for s_res in sockets.into_iter() {
|
|
||||||
match s_res {
|
|
||||||
Ok(s) => {
|
|
||||||
let stop_working = self.stop_working.subscribe();
|
|
||||||
let join = ::tokio::spawn(Self::listen_udp(
|
|
||||||
stop_working,
|
|
||||||
self._inner.clone(),
|
|
||||||
self.token_check.clone(),
|
|
||||||
s.clone(),
|
|
||||||
));
|
|
||||||
self.sockets.push((s, join));
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
return Err(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the raw TXT record of a Fenrir domain
|
/// Get the raw TXT record of a Fenrir domain
|
||||||
pub async fn resolv_str(&self, domain: &str) -> Result<String, Error> {
|
pub async fn resolv_str(&self, domain: &str) -> Result<String, Error> {
|
||||||
match &self.dnssec {
|
match &self.dnssec {
|
||||||
|
@ -389,7 +467,8 @@ impl Fenrir {
|
||||||
fenrir: Arc<FenrirInner>,
|
fenrir: Arc<FenrirInner>,
|
||||||
token_check: Arc<ArcSwapOption<TokenChecker>>,
|
token_check: Arc<ArcSwapOption<TokenChecker>>,
|
||||||
buffer: &[u8],
|
buffer: &[u8],
|
||||||
_sock_from: SocketAddr,
|
_sock_receiver: SocketAddr,
|
||||||
|
_sock_sender: SocketAddr,
|
||||||
) {
|
) {
|
||||||
if buffer.len() < Self::MIN_PACKET_BYTES {
|
if buffer.len() < Self::MIN_PACKET_BYTES {
|
||||||
return;
|
return;
|
||||||
|
@ -484,20 +563,25 @@ impl Fenrir {
|
||||||
// build response
|
// build response
|
||||||
let secret_send =
|
let secret_send =
|
||||||
authinfo.hkdf.get_secret(b"to_client");
|
authinfo.hkdf.get_secret(b"to_client");
|
||||||
let cipher_send = CipherRecv::new(
|
let mut cipher_send = CipherSend::new(
|
||||||
authinfo.cipher,
|
authinfo.cipher,
|
||||||
secret_send,
|
secret_send,
|
||||||
);
|
);
|
||||||
use crate::enc::sym::AAD;
|
use crate::enc::sym::AAD;
|
||||||
let aad = AAD(&mut []); // no aad for now
|
let aad = AAD(&mut []); // no aad for now
|
||||||
/*
|
let mut data = cipher_send
|
||||||
match cipher_send.encrypt(aad, &mut req.data.ciphertext()) {
|
.make_data(dirsync::RespData::len());
|
||||||
Ok(()) => req.data.mark_as_cleartext(),
|
|
||||||
Err(e) => {
|
if let Err(e) =
|
||||||
return Err(handshake::Error::Key(e).into());
|
cipher_send.encrypt(aad, &mut data)
|
||||||
}
|
{
|
||||||
}
|
::tracing::error!("can't encrypt: {:?}", e);
|
||||||
*/
|
return;
|
||||||
|
}
|
||||||
|
let resp = dirsync::Resp {
|
||||||
|
client_key_id: req_data.client_key_id,
|
||||||
|
enc: data.get_raw(),
|
||||||
|
};
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
|
Loading…
Reference in New Issue