tungstenite/handshake/
machine.rs
1use bytes::Buf;
4use log::*;
5use std::io::{Cursor, Read, Write};
6
7use crate::{
8 error::{Error, ProtocolError, Result},
9 util::NonBlockingResult,
10 ReadBuffer,
11};
12
13#[derive(Debug)]
15pub struct HandshakeMachine<Stream> {
16 stream: Stream,
17 state: HandshakeState,
18}
19
20impl<Stream> HandshakeMachine<Stream> {
21 pub fn start_read(stream: Stream) -> Self {
23 Self { stream, state: HandshakeState::Reading(ReadBuffer::new(), AttackCheck::new()) }
24 }
25 pub fn start_write<D: Into<Vec<u8>>>(stream: Stream, data: D) -> Self {
27 HandshakeMachine { stream, state: HandshakeState::Writing(Cursor::new(data.into())) }
28 }
29 pub fn get_ref(&self) -> &Stream {
31 &self.stream
32 }
33 pub fn get_mut(&mut self) -> &mut Stream {
35 &mut self.stream
36 }
37}
38
39impl<Stream: Read + Write> HandshakeMachine<Stream> {
40 pub fn single_round<Obj: TryParse>(mut self) -> Result<RoundResult<Obj, Stream>> {
42 trace!("Doing handshake round.");
43 match self.state {
44 HandshakeState::Reading(mut buf, mut attack_check) => {
45 let read = buf.read_from(&mut self.stream).no_block()?;
46 match read {
47 Some(0) => Err(Error::Protocol(ProtocolError::HandshakeIncomplete)),
48 Some(count) => {
49 attack_check.check_incoming_packet_size(count)?;
50 Ok(if let Some((size, obj)) = Obj::try_parse(Buf::chunk(&buf))? {
54 buf.advance(size);
55 RoundResult::StageFinished(StageResult::DoneReading {
56 result: obj,
57 stream: self.stream,
58 tail: buf.into_vec(),
59 })
60 } else {
61 RoundResult::Incomplete(HandshakeMachine {
62 state: HandshakeState::Reading(buf, attack_check),
63 ..self
64 })
65 })
66 }
67 None => Ok(RoundResult::WouldBlock(HandshakeMachine {
68 state: HandshakeState::Reading(buf, attack_check),
69 ..self
70 })),
71 }
72 }
73 HandshakeState::Writing(mut buf) => {
74 assert!(buf.has_remaining());
75 if let Some(size) = self.stream.write(Buf::chunk(&buf)).no_block()? {
76 assert!(size > 0);
77 buf.advance(size);
78 Ok(if buf.has_remaining() {
79 RoundResult::Incomplete(HandshakeMachine {
80 state: HandshakeState::Writing(buf),
81 ..self
82 })
83 } else {
84 RoundResult::Incomplete(HandshakeMachine {
85 state: HandshakeState::Flushing,
86 ..self
87 })
88 })
89 } else {
90 Ok(RoundResult::WouldBlock(HandshakeMachine {
91 state: HandshakeState::Writing(buf),
92 ..self
93 }))
94 }
95 }
96 HandshakeState::Flushing => Ok(match self.stream.flush().no_block()? {
97 Some(()) => RoundResult::StageFinished(StageResult::DoneWriting(self.stream)),
98 None => RoundResult::WouldBlock(HandshakeMachine {
99 state: HandshakeState::Flushing,
100 ..self
101 }),
102 }),
103 }
104 }
105}
106
107#[derive(Debug)]
109pub enum RoundResult<Obj, Stream> {
110 WouldBlock(HandshakeMachine<Stream>),
112 Incomplete(HandshakeMachine<Stream>),
114 StageFinished(StageResult<Obj, Stream>),
116}
117
/// The result of a finished handshake stage.
#[derive(Debug)]
pub enum StageResult<Obj, Stream> {
    /// The reading stage finished: an object was parsed from the stream.
    DoneReading {
        /// The object that was parsed.
        result: Obj,
        /// The stream, handed back to the caller.
        stream: Stream,
        /// Bytes that were already read but lie after the parsed object.
        tail: Vec<u8>,
    },
    /// The writing stage finished: all data was written and flushed. Carries
    /// the stream back to the caller.
    DoneWriting(Stream),
}
127
128pub trait TryParse: Sized {
130 fn try_parse(data: &[u8]) -> Result<Option<(usize, Self)>>;
132}
133
134#[derive(Debug)]
136enum HandshakeState {
137 Reading(ReadBuffer, AttackCheck),
139 Writing(Cursor<Vec<u8>>),
141 Flushing,
143}
144
145#[derive(Debug)]
148pub(crate) struct AttackCheck {
149 number_of_packets: usize,
151 number_of_bytes: usize,
153}
154
155impl AttackCheck {
156 fn new() -> Self {
158 Self { number_of_packets: 0, number_of_bytes: 0 }
159 }
160
161 fn check_incoming_packet_size(&mut self, size: usize) -> Result<()> {
164 self.number_of_packets += 1;
165 self.number_of_bytes += size;
166
167 const MAX_BYTES: usize = 65536;
170 const MAX_PACKETS: usize = 512;
171 const MIN_PACKET_SIZE: usize = 128;
172 const MIN_PACKET_CHECK_THRESHOLD: usize = 64;
173
174 if self.number_of_bytes > MAX_BYTES {
175 return Err(Error::AttackAttempt);
176 }
177
178 if self.number_of_packets > MAX_PACKETS {
179 return Err(Error::AttackAttempt);
180 }
181
182 if self.number_of_packets > MIN_PACKET_CHECK_THRESHOLD
183 && self.number_of_packets * MIN_PACKET_SIZE > self.number_of_bytes
184 {
185 return Err(Error::AttackAttempt);
186 }
187
188 Ok(())
189 }
190}