// tungstenite/protocol/frame/frame.rs

1use log::*;
2use std::{
3    default::Default,
4    fmt,
5    io::{Cursor, ErrorKind, Read, Write},
6    mem,
7    result::Result as StdResult,
8    str::Utf8Error,
9    string::String,
10};
11
12use super::{
13    coding::{CloseCode, Control, Data, OpCode},
14    mask::{apply_mask, generate_mask},
15};
16use crate::{
17    error::{Error, ProtocolError, Result},
18    protocol::frame::Utf8Bytes,
19};
20use bytes::{Bytes, BytesMut};
21
/// The decoded contents of a WebSocket Close frame's payload.
///
/// Produced by `Frame::into_close` from a payload of at least two bytes, and encoded back
/// into a payload (status code followed by reason) by `Frame::close`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The close status code, carried as the first two payload bytes in big-endian order.
    pub code: CloseCode,
    /// The UTF-8 reason text that follows the status code; may be empty.
    pub reason: Utf8Bytes,
}
30
31impl fmt::Display for CloseFrame {
32    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33        write!(f, "{} ({})", self.reason, self.code)
34    }
35}
36
/// A struct representing a WebSocket frame header.
///
/// Holds the flag bits, the opcode and the optional masking key. The payload length is not
/// stored here: `parse` returns it alongside the header, and `len` / `format` take it as an
/// argument.
// NOTE(review): the reason for opting out of `missing_copy_implementations` is not visible in
// this file — confirm before deciding whether to derive `Copy`.
#[allow(missing_copy_implementations)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct FrameHeader {
    /// FIN bit (0x80 of the first byte): set on the last frame of a possibly fragmented message.
    pub is_final: bool,
    /// RSV1 bit (0x40 of the first byte), reserved for protocol extensions.
    pub rsv1: bool,
    /// RSV2 bit (0x20 of the first byte), reserved for protocol extensions.
    pub rsv2: bool,
    /// RSV3 bit (0x10 of the first byte), reserved for protocol extensions.
    pub rsv3: bool,
    /// WebSocket protocol opcode (low four bits of the first byte).
    pub opcode: OpCode,
    /// The 4-byte masking key, present when the MASK bit (0x80 of the second byte) is set.
    pub mask: Option<[u8; 4]>,
}
54
55impl Default for FrameHeader {
56    fn default() -> Self {
57        FrameHeader {
58            is_final: true,
59            rsv1: false,
60            rsv2: false,
61            rsv3: false,
62            opcode: OpCode::Control(Control::Close),
63            mask: None,
64        }
65    }
66}
67
68impl FrameHeader {
69    /// Parse a header from an input stream.
70    /// Returns `None` if insufficient data and does not consume anything in this case.
71    /// Payload size is returned along with the header.
72    pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
73        let initial = cursor.position();
74        match Self::parse_internal(cursor) {
75            ret @ Ok(None) => {
76                cursor.set_position(initial);
77                ret
78            }
79            ret => ret,
80        }
81    }
82
83    /// Get the size of the header formatted with given payload length.
84    #[allow(clippy::len_without_is_empty)]
85    pub fn len(&self, length: u64) -> usize {
86        2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
87    }
88
89    /// Format a header for given payload size.
90    pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
91        let code: u8 = self.opcode.into();
92
93        let one = {
94            code | if self.is_final { 0x80 } else { 0 }
95                | if self.rsv1 { 0x40 } else { 0 }
96                | if self.rsv2 { 0x20 } else { 0 }
97                | if self.rsv3 { 0x10 } else { 0 }
98        };
99
100        let lenfmt = LengthFormat::for_length(length);
101
102        let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
103
104        output.write_all(&[one, two])?;
105        match lenfmt {
106            LengthFormat::U8(_) => (),
107            LengthFormat::U16 => {
108                output.write_all(&(length as u16).to_be_bytes())?;
109            }
110            LengthFormat::U64 => {
111                output.write_all(&length.to_be_bytes())?;
112            }
113        }
114
115        if let Some(ref mask) = self.mask {
116            output.write_all(mask)?;
117        }
118
119        Ok(())
120    }
121
122    /// Generate a random frame mask and store this in the header.
123    ///
124    /// Of course this does not change frame contents. It just generates a mask.
125    pub(crate) fn set_random_mask(&mut self) {
126        self.mask = Some(generate_mask());
127    }
128}
129
130impl FrameHeader {
131    /// Internal parse engine.
132    /// Returns `None` if insufficient data.
133    /// Payload size is returned along with the header.
134    fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
135        let (first, second) = {
136            let mut head = [0u8; 2];
137            if cursor.read(&mut head)? != 2 {
138                return Ok(None);
139            }
140            trace!("Parsed headers {:?}", head);
141            (head[0], head[1])
142        };
143
144        trace!("First: {:b}", first);
145        trace!("Second: {:b}", second);
146
147        let is_final = first & 0x80 != 0;
148
149        let rsv1 = first & 0x40 != 0;
150        let rsv2 = first & 0x20 != 0;
151        let rsv3 = first & 0x10 != 0;
152
153        let opcode = OpCode::from(first & 0x0F);
154        trace!("Opcode: {:?}", opcode);
155
156        let masked = second & 0x80 != 0;
157        trace!("Masked: {:?}", masked);
158
159        let length = {
160            let length_byte = second & 0x7F;
161            let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
162            if length_length > 0 {
163                const SIZE: usize = mem::size_of::<u64>();
164                assert!(length_length <= SIZE, "length exceeded size of u64");
165                let start = SIZE - length_length;
166                let mut buffer = [0; SIZE];
167                match cursor.read_exact(&mut buffer[start..]) {
168                    Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => return Ok(None),
169                    Err(err) => return Err(err.into()),
170                    Ok(()) => u64::from_be_bytes(buffer),
171                }
172            } else {
173                u64::from(length_byte)
174            }
175        };
176
177        let mask = if masked {
178            let mut mask_bytes = [0u8; 4];
179            if cursor.read(&mut mask_bytes)? != 4 {
180                return Ok(None);
181            } else {
182                Some(mask_bytes)
183            }
184        } else {
185            None
186        };
187
188        // Disallow bad opcode
189        match opcode {
190            OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
191                return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
192            }
193            _ => (),
194        }
195
196        let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
197
198        Ok(Some((hdr, length)))
199    }
200}
201
/// A struct representing a WebSocket frame: a [`FrameHeader`] plus its payload.
///
/// The payload length is not stored in the header; it is taken from `payload` when the frame
/// is measured (`len`) or encoded (`format` / `format_into_buf`). If the header carries a
/// masking key, those encoding methods apply it to the bytes they write out.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
    /// The frame header: flags, opcode and optional masking key.
    header: FrameHeader,
    /// The frame payload.
    payload: Bytes,
}
208
209impl Frame {
210    /// Get the length of the frame.
211    /// This is the length of the header + the length of the payload.
212    #[inline]
213    pub fn len(&self) -> usize {
214        let length = self.payload.len();
215        self.header.len(length as u64) + length
216    }
217
218    /// Check if the frame is empty.
219    #[inline]
220    pub fn is_empty(&self) -> bool {
221        self.len() == 0
222    }
223
224    /// Get a reference to the frame's header.
225    #[inline]
226    pub fn header(&self) -> &FrameHeader {
227        &self.header
228    }
229
230    /// Get a mutable reference to the frame's header.
231    #[inline]
232    pub fn header_mut(&mut self) -> &mut FrameHeader {
233        &mut self.header
234    }
235
236    /// Get a reference to the frame's payload.
237    #[inline]
238    pub fn payload(&self) -> &[u8] {
239        &self.payload
240    }
241
242    /// Test whether the frame is masked.
243    #[inline]
244    pub(crate) fn is_masked(&self) -> bool {
245        self.header.mask.is_some()
246    }
247
248    /// Generate a random mask for the frame.
249    ///
250    /// This just generates a mask, payload is not changed. The actual masking is performed
251    /// either on `format()` or on `apply_mask()` call.
252    #[inline]
253    pub(crate) fn set_random_mask(&mut self) {
254        self.header.set_random_mask();
255    }
256
257    /// Consume the frame into its payload as string.
258    #[inline]
259    pub fn into_text(self) -> StdResult<Utf8Bytes, Utf8Error> {
260        self.payload.try_into()
261    }
262
263    /// Consume the frame into its payload.
264    #[inline]
265    pub fn into_payload(self) -> Bytes {
266        self.payload
267    }
268
269    /// Get frame payload as `&str`.
270    #[inline]
271    pub fn to_text(&self) -> Result<&str, Utf8Error> {
272        std::str::from_utf8(&self.payload)
273    }
274
275    /// Consume the frame into a closing frame.
276    #[inline]
277    pub(crate) fn into_close(self) -> Result<Option<CloseFrame>> {
278        match self.payload.len() {
279            0 => Ok(None),
280            1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
281            _ => {
282                let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into();
283                let reason = Utf8Bytes::try_from(self.payload.slice(2..))?;
284                Ok(Some(CloseFrame { code, reason }))
285            }
286        }
287    }
288
289    /// Create a new data frame.
290    #[inline]
291    pub fn message(data: impl Into<Bytes>, opcode: OpCode, is_final: bool) -> Frame {
292        debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
293        Frame {
294            header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
295            payload: data.into(),
296        }
297    }
298
299    /// Create a new Pong control frame.
300    #[inline]
301    pub fn pong(data: impl Into<Bytes>) -> Frame {
302        Frame {
303            header: FrameHeader {
304                opcode: OpCode::Control(Control::Pong),
305                ..FrameHeader::default()
306            },
307            payload: data.into(),
308        }
309    }
310
311    /// Create a new Ping control frame.
312    #[inline]
313    pub fn ping(data: impl Into<Bytes>) -> Frame {
314        Frame {
315            header: FrameHeader {
316                opcode: OpCode::Control(Control::Ping),
317                ..FrameHeader::default()
318            },
319            payload: data.into(),
320        }
321    }
322
323    /// Create a new Close control frame.
324    #[inline]
325    pub fn close(msg: Option<CloseFrame>) -> Frame {
326        let payload = if let Some(CloseFrame { code, reason }) = msg {
327            let mut p = BytesMut::with_capacity(reason.len() + 2);
328            p.extend(u16::from(code).to_be_bytes());
329            p.extend_from_slice(reason.as_bytes());
330            p
331        } else {
332            <_>::default()
333        };
334
335        Frame { header: FrameHeader::default(), payload: payload.into() }
336    }
337
338    /// Create a frame from given header and data.
339    pub fn from_payload(header: FrameHeader, payload: Bytes) -> Self {
340        Frame { header, payload }
341    }
342
343    /// Write a frame out to a buffer
344    pub fn format(mut self, output: &mut impl Write) -> Result<()> {
345        self.header.format(self.payload.len() as u64, output)?;
346
347        if let Some(mask) = self.header.mask.take() {
348            let mut data = Vec::from(mem::take(&mut self.payload));
349            apply_mask(&mut data, mask);
350            output.write_all(&data)?;
351        } else {
352            output.write_all(&self.payload)?;
353        }
354
355        Ok(())
356    }
357
358    pub(crate) fn format_into_buf(mut self, buf: &mut Vec<u8>) -> Result<()> {
359        self.header.format(self.payload.len() as u64, buf)?;
360
361        let len = buf.len();
362        buf.extend_from_slice(&self.payload);
363
364        if let Some(mask) = self.header.mask.take() {
365            apply_mask(&mut buf[len..], mask);
366        }
367
368        Ok(())
369    }
370}
371
372impl fmt::Display for Frame {
373    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
374        use std::fmt::Write;
375
376        write!(
377            f,
378            "
379<FRAME>
380final: {}
381reserved: {} {} {}
382opcode: {}
383length: {}
384payload length: {}
385payload: 0x{}
386            ",
387            self.header.is_final,
388            self.header.rsv1,
389            self.header.rsv2,
390            self.header.rsv3,
391            self.header.opcode,
392            // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
393            self.len(),
394            self.payload.len(),
395            self.payload.iter().fold(String::new(), |mut output, byte| {
396                _ = write!(output, "{byte:02x}");
397                output
398            })
399        )
400    }
401}
402
/// Encoding chosen for the 7-bit payload-length field of a frame header.
///
/// Lengths below 126 fit directly in the field (`U8`). The field values 126 and 127 instead
/// announce a 2-byte (`U16`) or 8-byte (`U64`) big-endian extended length that follows.
enum LengthFormat {
    /// Length stored inline in the 7-bit field.
    U8(u8),
    /// Length stored in the following 2 bytes (field value 126).
    U16,
    /// Length stored in the following 8 bytes (field value 127).
    U64,
}

impl LengthFormat {
    /// Pick the smallest encoding that can represent a payload of `length` bytes.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => Self::U8(length as u8),
            126..=65535 => Self::U16,
            _ => Self::U64,
        }
    }

    /// Number of extended-length bytes that follow the 7-bit field.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            Self::U8(_) => 0,
            Self::U16 => 2,
            Self::U64 => 8,
        }
    }

    /// Value to place in the 7-bit length field.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            Self::U8(inline) => *inline,
            Self::U16 => 126,
            Self::U64 => 127,
        }
    }

    /// Decode the encoding from a header's second byte; the MASK bit (0x80) is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let field = byte & 0x7F;
        match field {
            126 => Self::U16,
            127 => Self::U64,
            _ => Self::U8(field),
        }
    }
}
453
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    #[test]
    fn parse() {
        let mut raw: Cursor<Vec<u8>> =
            Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 7);
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let frame = Frame::from_payload(header, payload.into());
        assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]);
    }

    /// An incomplete header must yield `None` and leave the cursor where it started.
    #[test]
    fn parse_incomplete_header_rewinds() {
        // Only the first header byte is available.
        let mut raw = Cursor::new(vec![0x82u8]);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);

        // A 16-bit extended length is announced, but only one of its two bytes is present.
        let mut raw = Cursor::new(vec![0x82u8, 0x7E, 0x01]);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);

        // A masked header whose 4-byte key is truncated.
        let mut raw = Cursor::new(vec![0x82u8, 0x80, 0x01, 0x02]);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);
    }

    #[test]
    fn parse_extended_lengths() {
        let mut raw = Cursor::new(vec![0x82u8, 0x7E, 0x01, 0x00]);
        let (_, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 256);
        assert_eq!(raw.position(), 4);

        let mut raw = Cursor::new(vec![0x82u8, 0x7F, 0, 0, 0, 0, 0, 1, 0, 0]);
        let (_, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 65536);
        assert_eq!(raw.position(), 10);
    }

    #[test]
    fn parse_masked_header() {
        let mut raw = Cursor::new(vec![0x81u8, 0x85, 0x01, 0x02, 0x03, 0x04]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 5);
        assert!(header.is_final);
        assert_eq!(header.opcode, OpCode::Data(Data::Text));
        assert_eq!(header.mask, Some([0x01, 0x02, 0x03, 0x04]));
    }

    #[test]
    fn parse_rejects_reserved_opcode() {
        let mut raw = Cursor::new(vec![0x83u8, 0x00]);
        assert!(FrameHeader::parse(&mut raw).is_err());
    }

    #[test]
    fn format() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    #[test]
    fn format_into_buf() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format_into_buf(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    /// Both encoders must write the key and mask the payload identically.
    #[test]
    fn format_masked_matches_format_into_buf() {
        let mut frame = Frame::ping(vec![0x01, 0x02]);
        frame.header_mut().mask = Some([0x01, 0x02, 0x03, 0x04]);
        let expected = vec![0x89, 0x82, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00];

        let mut streamed = Vec::new();
        frame.clone().format(&mut streamed).unwrap();
        assert_eq!(streamed, expected);

        let mut buffered = Vec::new();
        frame.format_into_buf(&mut buffered).unwrap();
        assert_eq!(buffered, expected);
    }

    #[test]
    fn into_close_decodes_payload() {
        let empty = Frame::from_payload(FrameHeader::default(), Bytes::new());
        assert!(matches!(empty.into_close(), Ok(None)));

        let truncated = Frame::from_payload(FrameHeader::default(), Bytes::from_static(&[0x03]));
        assert!(truncated.into_close().is_err());

        let full = Frame::from_payload(FrameHeader::default(), Bytes::from_static(b"\x03\xe8ok"));
        let close = full.into_close().unwrap().unwrap();
        assert_eq!(u16::from(close.code), 1000);
        assert_eq!(close.reason.to_string(), "ok");

        // Re-encoding the decoded close frame reproduces the original payload.
        assert_eq!(Frame::close(Some(close)).into_payload(), &b"\x03\xe8ok"[..]);
    }

    #[test]
    fn display() {
        let f = Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true);
        let view = format!("{f}");
        assert!(view.contains("payload:"));
        assert!(view.starts_with("\n<FRAME>\nfinal: true\n"));
        assert!(view.contains("\nlength: 10\n"));
        assert!(view.contains("\npayload length: 8\n"));
        assert!(view.contains("payload: 0x6869207468657265\n"));
    }
}
495}