tungstenite/handshake/machine.rs

//! WebSocket handshake machine.
2
3use bytes::Buf;
4use log::*;
5use std::io::{Cursor, Read, Write};
6
7use crate::{
8    error::{Error, ProtocolError, Result},
9    util::NonBlockingResult,
10    ReadBuffer,
11};
12
13/// A generic handshake state machine.
14#[derive(Debug)]
15pub struct HandshakeMachine<Stream> {
16    stream: Stream,
17    state: HandshakeState,
18}
19
20impl<Stream> HandshakeMachine<Stream> {
21    /// Start reading data from the peer.
22    pub fn start_read(stream: Stream) -> Self {
23        Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
24    }
25    /// Start writing data to the peer.
26    pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
27        HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
28    }
29    /// Returns a shared reference to the inner stream.
30    pub fn get_ref(&self) -> &Stream {
31        &self.stream
32    }
33    /// Returns a mutable reference to the inner stream.
34    pub fn get_mut(&mut self) -> &mut Stream {
35        &mut self.stream
36    }
37}
38
39impl<Stream: Read + Write> HandshakeMachine<Stream> {
40    /// Perform a single handshake round.
41    pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
42        trace!("Doing handshake round.");
43        match self.state {
44            HandshakeState::Reading(mut buf, mut attack_check) => {
45                let read = buf.read_from(&mut self.stream).no_block()?;
46                match read {
47                    Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
48                    Some(count) => {
49                        attack_check.check_incoming_packet_size(count)?;
50                        // TODO: this is slow for big headers with too many small packets.
51                        // The parser has to be reworked in order to work on streams instead
52                        // of buffers.
53                        Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
54                            buf.advance(size);
55                            RoundResult::StageFinished(StageResult::DoneReading {
56                                result: obj,
57                                stream: self.stream,
58                                tail: buf.into_vec(),
59                            })
60                        } else {
61                            RoundResult::Incomplete(HandshakeMachine {
62                                state: HandshakeState::Reading(buf, attack_check),
63                                ..self
64                            })
65                        })
66                    }
67                    None => Ok(RoundResult::WouldBlock(HandshakeMachine {
68                        state: HandshakeState::Reading(buf, attack_check),
69                        ..self
70                    })),
71                }
72            }
73            HandshakeState::Writing(mut buf) => {
74                assert!(buf.has_remaining());
75                if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? {
76                    assert!(size > 0);
77                    buf.advance(size);
78                    Ok(if buf.has_remaining() {
79                        RoundResult::Incomplete(HandshakeMachine {
80                            state: HandshakeState::Writing(buf),
81                            ..self
82                        })
83                    } else {
84                        RoundResult::Incomplete(HandshakeMachine {
85                            state: HandshakeState::Flushing,
86                            ..self
87                        })
88                    })
89                } else {
90                    Ok(RoundResult::WouldBlock(HandshakeMachine {
91                        state: HandshakeState::Writing(buf),
92                        ..self
93                    }))
94                }
95            }
96            HandshakeState::Flushing => Ok(match self.stream.flush().no_block()? {
97                Some(()) => RoundResult::StageFinished(StageResult::DoneWriting(self.stream)),
98                None => RoundResult::WouldBlock(HandshakeMachine {
99                    state: HandshakeState::Flushing,
100                    ..self
101                }),
102            }),
103        }
104    }
105}
106
107/// The result of the round.
108#[derive(Debug)]
109pub enum RoundResult<Obj, Stream> {
110    /// Round not done, I/O would block.
111    WouldBlock(HandshakeMachine<Stream>),
112    /// Round done, state unchanged.
113    Incomplete(HandshakeMachine<Stream>),
114    /// Stage complete.
115    StageFinished(StageResult<Obj, Stream>),
116}
117
/// The outcome of a finished handshake stage.
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
    /// The reading stage finished: a complete object was parsed.
    DoneReading {
        /// The parsed object.
        result: Obj,
        /// The stream, handed back to the caller.
        stream: Stream,
        /// Bytes that were read from the stream but lie past the end of the
        /// parsed object, i.e. not consumed by the parser.
        tail: Vec<u8>,
    },
    /// The writing stage finished: all data was written and the stream flushed.
    DoneWriting(Stream),
}
127
128/// The parseable object.
129pub trait TryParse: Sized {
130    /// Return Ok(None) if incomplete, Err on syntax error.
131    fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
132}
133
134/// The handshake state.
135#[derive(Debug)]
136enum HandshakeState {
137    /// Reading data from the peer.
138    Reading(ReadBuffer, AttackCheck),
139    /// Sending data to the peer.
140    Writing(Cursor<Vec<u8>>),
141    /// Flushing data to ensure that all intermediately buffered contents reach their destination.
142    Flushing,
143}
144
/// Counters used to detect denial-of-service attempts while the HTTP header is
/// being received, and to reject headers that are valid but useless.
#[derive(Debug)]
pub(crate) struct AttackCheck {
    /// How many successful reads (TCP packets) have contributed to the header.
    number_of_packets: usize,
    /// How many header bytes have been received in total.
    number_of_bytes: usize,
}
154
155impl AttackCheck {
156    /// Initialize attack checking for incoming buffer.
157    fn new() -> Self {
158        Self { number_of_packets: 0, number_of_bytes: 0 }
159    }
160
161    /// Check the size of an incoming packet. To be called immediately after `read()`
162    /// passing its returned bytes count as `size`.
163    fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
164        self.number_of_packets += 1;
165        self.number_of_bytes += size;
166
167        // TODO: these values are hardcoded. Instead of making them configurable,
168        // rework the way HTTP header is parsed to remove this check at all.
169        const MAX_BYTES: usize = 65536;
170        const MAX_PACKETS: usize = 512;
171        const MIN_PACKET_SIZE: usize = 128;
172        const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
173
174        if self.number_of_bytes > MAX_BYTES {
175            return Err(Error::AttackAttempt);
176        }
177
178        if self.number_of_packets > MAX_PACKETS {
179            return Err(Error::AttackAttempt);
180        }
181
182        if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD
183            && self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes
184        {
185            return Err(Error::AttackAttempt);
186        }
187
188        Ok(())
189    }
190}