use crate::bindings::{BindingReadExact, BindingWriteExact};
use crate::error::BootstrapError;
use crate::messages::{
BootstrapClientMessage, BootstrapClientMessageSerializer, BootstrapServerMessage,
BootstrapServerMessageDeserializer,
};
use crate::settings::BootstrapClientConfig;
use massa_hash::Hash;
use massa_models::config::{
MAX_BOOTSTRAP_MESSAGE_SIZE, MAX_BOOTSTRAP_MESSAGE_SIZE_BYTES, SIGNATURE_DESER_SIZE,
};
use massa_models::serialization::{DeserializeMinBEInt, SerializeMinBEInt};
use massa_models::version::{Version, VersionSerializer};
use massa_serialization::{DeserializeError, Deserializer, Serializer};
use massa_signature::{PublicKey, Signature};
use rand::{rngs::StdRng, RngCore, SeedableRng};
use std::time::Instant;
use std::{net::TcpStream, time::Duration};
use stream_limiter::{Limiter, LimiterOptions};
/// Client-side binder for the bootstrap exchange with a remote server.
///
/// It owns the (optionally rate-limited) TCP stream, sends client messages,
/// and verifies the signature of every server message it receives.
pub struct BootstrapClientBinder {
/// Public key of the remote server, used to verify the signature carried by each received message.
remote_pubkey: PublicKey,
/// The TCP stream wrapped in a `Limiter`; all reads and writes of this binder go through it.
duplex: Limiter<TcpStream>,
/// Hash chaining consecutive messages together: `None` until the handshake, then updated
/// on every successful handshake, send and receive.
prev_message: Option<Hash>,
/// Serializer used for the `Version` sent during the handshake.
version_serializer: VersionSerializer,
/// Client configuration; provides the handshake randomness size and the
/// arguments of the server-message deserializer.
cfg: BootstrapClientConfig,
}
/// Size in bytes of the fixed-length leader read at the start of every server message:
/// `SIGNATURE_DESER_SIZE` bytes for the signature followed by
/// `MAX_BOOTSTRAP_MESSAGE_SIZE_BYTES` bytes for the encoded message length.
const KNOWN_PREFIX_LEN: usize = SIGNATURE_DESER_SIZE + MAX_BOOTSTRAP_MESSAGE_SIZE_BYTES;
/// Decoded form of the fixed-size leader that precedes the body of a server message.
struct ServerMessageLeader {
/// Signature of the message, parsed from the start of the leader.
sig: Signature,
/// Length in bytes of the message body that follows the leader on the stream.
msg_len: u32,
}
impl BootstrapClientBinder {
#[allow(clippy::too_many_arguments)]
pub fn new(
duplex: TcpStream,
remote_pubkey: PublicKey,
cfg: BootstrapClientConfig,
limit: Option<u64>,
) -> Self {
let limit_opts =
limit.map(|limit| LimiterOptions::new(limit, Duration::from_millis(1000), limit));
let duplex = Limiter::new(duplex, limit_opts.clone(), limit_opts);
BootstrapClientBinder {
remote_pubkey,
duplex,
prev_message: None,
version_serializer: VersionSerializer::new(),
cfg,
}
}
pub fn handshake(&mut self, version: Version) -> Result<(), BootstrapError> {
let msg_hash = {
let mut version_ser = Vec::new();
self.version_serializer
.serialize(&version, &mut version_ser)?;
let mut version_random_bytes =
vec![0u8; version_ser.len() + self.cfg.randomness_size_bytes];
version_random_bytes[..version_ser.len()].clone_from_slice(&version_ser);
StdRng::from_entropy().fill_bytes(&mut version_random_bytes[version_ser.len()..]);
self.write_all_timeout(&version_random_bytes, None)
.map_err(|(e, _)| e)?;
Hash::compute_from(&version_random_bytes)
};
self.prev_message = Some(msg_hash);
Ok(())
}
pub fn next_timeout(
&mut self,
duration: Option<Duration>,
) -> Result<BootstrapServerMessage, BootstrapError> {
let deadline = duration.map(|d| Instant::now() + d);
let mut known_len_buff = [0u8; KNOWN_PREFIX_LEN];
self.read_exact_timeout(&mut known_len_buff, deadline)
.map_err(|(err, _consumed)| err)?;
let ServerMessageLeader { sig, msg_len } = self.decode_msg_leader(&known_len_buff)?;
let message_deserializer = BootstrapServerMessageDeserializer::new((&self.cfg).into());
let prev_msg = self
.prev_message
.replace(Hash::compute_from(&sig.to_bytes()));
let message = {
if let Some(prev_msg) = prev_msg {
let mut stream_bytes =
vec![0u8; msg_len.try_into().expect("Overflow on msg_len to usize")];
self.read_exact_timeout(&mut stream_bytes[..], deadline)
.map_err(|(e, _consumed)| e)?;
let msg_bytes = &mut stream_bytes[..];
let rehash_seed = &[prev_msg.to_bytes().as_slice(), msg_bytes].concat();
let msg_hash = Hash::compute_from(rehash_seed);
self.remote_pubkey.verify_signature(&msg_hash, &sig)?;
let (_, msg) = message_deserializer
.deserialize::<DeserializeError>(msg_bytes)
.map_err(|err| BootstrapError::DeserializeError(format!("{}", err)))?;
msg
} else {
let mut stream_bytes =
vec![0u8; msg_len.try_into().expect("Overflow on msg_len to usize")];
self.read_exact_timeout(&mut stream_bytes[..], deadline)
.map_err(|(e, _)| e)?;
let sig_msg_bytes = &mut stream_bytes[..];
let msg_hash = Hash::compute_from(sig_msg_bytes);
self.remote_pubkey.verify_signature(&msg_hash, &sig)?;
let (_, msg) = message_deserializer
.deserialize::<DeserializeError>(sig_msg_bytes)
.map_err(|err| BootstrapError::DeserializeError(format!("{}", err)))?;
msg
}
};
Ok(message)
}
pub fn send_timeout(
&mut self,
msg: &BootstrapClientMessage,
duration: Option<Duration>,
) -> Result<(), BootstrapError> {
let deadline = duration.map(|d| Instant::now() + d);
let mut msg_bytes = Vec::new();
let message_serializer = BootstrapClientMessageSerializer::new();
message_serializer.serialize(msg, &mut msg_bytes)?;
let msg_len: u32 = msg_bytes.len().try_into().map_err(|e| {
BootstrapError::GeneralError(format!("bootstrap message too large to encode: {}", e))
})?;
let mut write_buf = Vec::new();
if let Some(prev_message) = self.prev_message {
let prev_message = prev_message.to_bytes();
let mut hash_data =
Vec::with_capacity(prev_message.len().saturating_add(msg_bytes.len()));
hash_data.extend(prev_message);
hash_data.extend(&msg_bytes);
self.prev_message = Some(Hash::compute_from(&hash_data));
write_buf.extend(prev_message);
} else {
self.prev_message = Some(Hash::compute_from(&msg_bytes));
}
let msg_len_bytes = msg_len.to_be_bytes_min(MAX_BOOTSTRAP_MESSAGE_SIZE)?;
write_buf.extend(&msg_len_bytes);
write_buf.extend(&msg_bytes);
self.write_all_timeout(&write_buf, deadline)
.map_err(|(e, _)| e)?;
Ok(())
}
fn decode_msg_leader(
&self,
leader_buff: &[u8; SIGNATURE_DESER_SIZE + MAX_BOOTSTRAP_MESSAGE_SIZE_BYTES],
) -> Result<ServerMessageLeader, BootstrapError> {
let sig = Signature::from_bytes(leader_buff)?;
let msg_len = u32::from_be_bytes_min(
&leader_buff[SIGNATURE_DESER_SIZE..],
MAX_BOOTSTRAP_MESSAGE_SIZE,
)?
.0;
Ok(ServerMessageLeader { sig, msg_len })
}
}
impl crate::bindings::BindingReadExact for BootstrapClientBinder {
    /// Sets the read timeout to `duration`, both on the limiter's read options
    /// (when there are any) and on the underlying stream.
    ///
    /// Returns the result of setting the timeout on the underlying stream.
    fn set_read_timeout(&mut self, duration: Option<Duration>) -> std::io::Result<()> {
        // Keep the limiter's own read options in step with the socket timeout.
        if let Some(read_opts) = self.duplex.read_opt.as_mut() {
            read_opts.timeout = duration;
        }
        self.duplex.stream.set_read_timeout(duration)
    }
}
impl std::io::Read for BootstrapClientBinder {
    /// Reads into `buf` by delegating to the wrapped `duplex` stream.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        self.duplex.read(buf)
    }
}
impl crate::bindings::BindingWriteExact for BootstrapClientBinder {
    /// Sets the write timeout to `duration`, both on the limiter's write options
    /// (when there are any) and on the underlying stream.
    ///
    /// Returns the result of setting the timeout on the underlying stream.
    fn set_write_timeout(&mut self, duration: Option<Duration>) -> std::io::Result<()> {
        // Keep the limiter's own write options in step with the socket timeout.
        if let Some(write_opts) = self.duplex.write_opt.as_mut() {
            write_opts.timeout = duration;
        }
        self.duplex.stream.set_write_timeout(duration)
    }
}
impl std::io::Write for BootstrapClientBinder {
    /// Writes `buf` by delegating to the wrapped `duplex` stream.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.duplex.write(buf)
    }

    /// Flushes the wrapped `duplex` stream.
    fn flush(&mut self) -> std::io::Result<()> {
        self.duplex.flush()
    }
}