tungstenite/handshake/machine.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
//! WebSocket handshake machine.
use bytes::Buf;
use log::*;
use std::io::{Cursor, Read, Write};
use crate::{
error::{Error, ProtocolError, Result},
util::NonBlockingResult,
ReadBuffer,
};
/// A generic handshake state machine.
///
/// Pairs the underlying I/O stream with the current handshake stage
/// (reading from or writing to the peer). A machine is created with
/// `start_read` or `start_write` and driven forward by `single_round`.
#[derive(Debug)]
pub struct HandshakeMachine<Stream> {
/// The stream the handshake data is read from or written to.
stream: Stream,
/// The current stage (reading or writing) together with its buffer.
state: HandshakeState,
}
impl<Stream> HandshakeMachine<Stream> {
    /// Build a machine whose first stage reads data from the peer.
    ///
    /// The machine starts with a fresh read buffer and fresh attack counters.
    pub fn start_read(stream: Stream) -> Self {
        let state = HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new());
        Self { stream, state }
    }

    /// Build a machine whose first stage writes `data` to the peer.
    ///
    /// `data` is converted into a byte vector and wrapped in a cursor that
    /// tracks how much of it has been sent so far.
    pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
        let pending = Cursor::new(data.into());
        Self { stream, state: HandshakeState::Writing(pending) }
    }

    /// Borrow the inner stream immutably.
    pub fn get_ref(&self) -> &Stream {
        let Self { stream, .. } = self;
        stream
    }

    /// Borrow the inner stream mutably.
    pub fn get_mut(&mut self) -> &mut Stream {
        let Self { stream, .. } = self;
        stream
    }
}
impl<Stream: Read + Write> HandshakeMachine<Stream> {
/// Perform a single handshake round.
pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
trace!("Doing handshake round.");
match self.state {
HandshakeState::Reading(mut buf, mut attack_check) => {
let read = buf.read_from(&mut self.stream).no_block()?;
match read {
Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
Some(count) => {
attack_check.check_incoming_packet_size(count)?;
// TODO: this is slow for big headers with too many small packets.
// The parser has to be reworked in order to work on streams instead
// of buffers.
Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
buf.advance(size);
RoundResult::StageFinished(StageResult::DoneReading {
result: obj,
stream: self.stream,
tail: buf.into_vec(),
})
} else {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Reading(buf, attack_check),
..self
})
})
}
None => Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Reading(buf, attack_check),
..self
})),
}
}
HandshakeState::Writing(mut buf) => {
assert!(buf.has_remaining());
if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? {
assert!(size > 0);
buf.advance(size);
Ok(if buf.has_remaining() {
RoundResult::Incomplete(HandshakeMachine {
state: HandshakeState::Writing(buf),
..self
})
} else {
RoundResult::StageFinished(StageResult::DoneWriting(self.stream))
})
} else {
Ok(RoundResult::WouldBlock(HandshakeMachine {
state: HandshakeState::Writing(buf),
..self
}))
}
}
}
}
}
/// The result of the round.
///
/// Returned by `HandshakeMachine::single_round`. The first two variants hand
/// the machine back to the caller so that another round can be performed.
#[derive(Debug)]
pub enum RoundResult<Obj, Stream> {
/// Round not done, I/O would block. Carries the machine so the round can be retried.
WouldBlock(HandshakeMachine<Stream>),
/// Round done, but the current stage (reading or writing) is not finished yet.
/// Carries the machine, still in the same stage, for the next round.
Incomplete(HandshakeMachine<Stream>),
/// Stage complete. Carries the outcome of the finished stage.
StageFinished(StageResult<Obj, Stream>),
}
/// The result of the stage.
///
/// Both variants give the stream back to the caller.
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
/// Reading round finished.
///
/// `result` is the parsed object, `stream` is the stream that was read from
/// and `tail` is what the read buffer yields after the bytes consumed by the
/// parser have been skipped.
#[allow(missing_docs)]
DoneReading { result: Obj, stream: Stream, tail: Vec<u8> },
/// Writing round finished. Carries the stream that was written to.
DoneWriting(Stream),
}
/// The parseable object.
///
/// Implemented by the types that `HandshakeMachine::single_round` can read
/// (its `Obj` type parameter).
pub trait TryParse: Sized {
/// Return Ok(None) if incomplete, Err on syntax error.
///
/// On success returns `Ok(Some((size, object)))`. `single_round` advances its
/// read buffer by `size` afterwards, so `size` is the number of bytes of
/// `data` that were used to build `object`.
fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
}
/// The handshake state.
///
/// Each variant carries the buffer that belongs to that stage.
#[derive(Debug)]
enum HandshakeState {
/// Reading data from the peer. Holds the buffer the stream is read into
/// and the counters used for attack mitigation.
Reading(ReadBuffer, AttackCheck),
/// Sending data to the peer. The cursor holds the data and its position
/// is advanced by the number of bytes written in each round.
Writing(Cursor<Vec<u8>>),
}
/// Attack mitigation. Contains counters needed to prevent DoS attacks
/// and reject valid but useless headers.
///
/// Both counters start at zero and are updated by
/// `check_incoming_packet_size` after every successful read.
#[derive(Debug)]
pub(crate) struct AttackCheck {
/// Number of successful reads of the HTTP header; each read is counted as
/// one (TCP) packet.
number_of_packets: usize,
/// Total number of bytes read for the HTTP header so far.
number_of_bytes: usize,
}
impl AttackCheck {
/// Initialize attack checking for incoming buffer.
fn new() -> Self {
Self { number_of_packets: 0, number_of_bytes: 0 }
}
/// Check the size of an incoming packet. To be called immediately after `read()`
/// passing its returned bytes count as `size`.
fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
self.number_of_packets += 1;
self.number_of_bytes += size;
// TODO: these values are hardcoded. Instead of making them configurable,
// rework the way HTTP header is parsed to remove this check at all.
const MAX_BYTES: usize = 65536;
const MAX_PACKETS: usize = 512;
const MIN_PACKET_SIZE: usize = 128;
const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
if self.number_of_bytes > MAX_BYTES {
return Err(Error::AttackAttempt);
}
if self.number_of_packets > MAX_PACKETS {
return Err(Error::AttackAttempt);
}
if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD
&& self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes
{
return Err(Error::AttackAttempt);
}
Ok(())
}
}