1use log::*;
2use std::{
3 default::Default,
4 fmt,
5 io::{Cursor, ErrorKind, Read, Write},
6 mem,
7 result::Result as StdResult,
8 str::Utf8Error,
9 string::String,
10};
11
12use super::{
13 coding::{CloseCode, Control, Data, OpCode},
14 mask::{apply_mask, generate_mask},
15};
16use crate::{
17 error::{Error, ProtocolError, Result},
18 protocol::frame::Utf8Bytes,
19};
20use bytes::{Bytes, BytesMut};
21
/// Contents of a close frame: a status code and a textual reason.
///
/// `Frame::close` serialises this as the big-endian `u16` form of `code` followed by the bytes
/// of `reason`; `Frame::into_close` performs the reverse conversion.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The close reason expressed as a code.
    pub code: CloseCode,
    /// The close reason expressed as text.
    pub reason: Utf8Bytes,
}
30
31impl fmt::Display for CloseFrame {
32 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
33 write!(f, "{} ({})", self.reason, self.code)
34 }
35}
36
/// Header of a single frame, excluding the payload length.
///
/// The payload length is not stored here; it is passed to `len`/`format` and returned next to
/// the header by `parse`.
#[allow(missing_copy_implementations)]
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct FrameHeader {
    /// The "final" flag: bit `0x80` of the first header byte.
    pub is_final: bool,
    /// First reserved flag: bit `0x40` of the first header byte.
    pub rsv1: bool,
    /// Second reserved flag: bit `0x20` of the first header byte.
    pub rsv2: bool,
    /// Third reserved flag: bit `0x10` of the first header byte.
    pub rsv3: bool,
    /// Frame opcode, carried in the low four bits of the first header byte.
    pub opcode: OpCode,
    /// Four-byte masking key, if the frame is masked (bit `0x80` of the second header byte).
    pub mask: Option<[u8; 4]>,
}
54
55impl Default for FrameHeader {
56 fn default() -> Self {
57 FrameHeader {
58 is_final: true,
59 rsv1: false,
60 rsv2: false,
61 rsv3: false,
62 opcode: OpCode::Control(Control::Close),
63 mask: None,
64 }
65 }
66}
67
68impl FrameHeader {
69 pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
73 let initial = cursor.position();
74 match Self::parse_internal(cursor) {
75 ret @ Ok(None) => {
76 cursor.set_position(initial);
77 ret
78 }
79 ret => ret,
80 }
81 }
82
83 #[allow(clippy::len_without_is_empty)]
85 pub fn len(&self, length: u64) -> usize {
86 2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
87 }
88
89 pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
91 let code: u8 = self.opcode.into();
92
93 let one = {
94 code | if self.is_final { 0x80 } else { 0 }
95 | if self.rsv1 { 0x40 } else { 0 }
96 | if self.rsv2 { 0x20 } else { 0 }
97 | if self.rsv3 { 0x10 } else { 0 }
98 };
99
100 let lenfmt = LengthFormat::for_length(length);
101
102 let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
103
104 output.write_all(&[one, two])?;
105 match lenfmt {
106 LengthFormat::U8(_) => (),
107 LengthFormat::U16 => {
108 output.write_all(&(length as u16).to_be_bytes())?;
109 }
110 LengthFormat::U64 => {
111 output.write_all(&length.to_be_bytes())?;
112 }
113 }
114
115 if let Some(ref mask) = self.mask {
116 output.write_all(mask)?;
117 }
118
119 Ok(())
120 }
121
122 pub(crate) fn set_random_mask(&mut self) {
126 self.mask = Some(generate_mask());
127 }
128}
129
130impl FrameHeader {
131 fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
135 let (first, second) = {
136 let mut head = [0u8; 2];
137 if cursor.read(&mut head)? != 2 {
138 return Ok(None);
139 }
140 trace!("Parsed headers {:?}", head);
141 (head[0], head[1])
142 };
143
144 trace!("First: {:b}", first);
145 trace!("Second: {:b}", second);
146
147 let is_final = first & 0x80 != 0;
148
149 let rsv1 = first & 0x40 != 0;
150 let rsv2 = first & 0x20 != 0;
151 let rsv3 = first & 0x10 != 0;
152
153 let opcode = OpCode::from(first & 0x0F);
154 trace!("Opcode: {:?}", opcode);
155
156 let masked = second & 0x80 != 0;
157 trace!("Masked: {:?}", masked);
158
159 let length = {
160 let length_byte = second & 0x7F;
161 let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
162 if length_length > 0 {
163 const SIZE: usize = mem::size_of::<u64>();
164 assert!(length_length <= SIZE, "length exceeded size of u64");
165 let start = SIZE - length_length;
166 let mut buffer = [0; SIZE];
167 match cursor.read_exact(&mut buffer[start..]) {
168 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => return Ok(None),
169 Err(err) => return Err(err.into()),
170 Ok(()) => u64::from_be_bytes(buffer),
171 }
172 } else {
173 u64::from(length_byte)
174 }
175 };
176
177 let mask = if masked {
178 let mut mask_bytes = [0u8; 4];
179 if cursor.read(&mut mask_bytes)? != 4 {
180 return Ok(None);
181 } else {
182 Some(mask_bytes)
183 }
184 } else {
185 None
186 };
187
188 match opcode {
190 OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
191 return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
192 }
193 _ => (),
194 }
195
196 let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
197
198 Ok(Some((hdr, length)))
199 }
200}
201
/// A single frame: a header together with its payload bytes.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Frame {
    /// Flags, opcode and optional mask of the frame.
    header: FrameHeader,
    /// Payload bytes, stored unmasked; `format`/`format_into_buf` apply the mask on output.
    payload: Bytes,
}
208
209impl Frame {
210 #[inline]
213 pub fn len(&self) -> usize {
214 let length = self.payload.len();
215 self.header.len(length as u64) + length
216 }
217
218 #[inline]
220 pub fn is_empty(&self) -> bool {
221 self.len() == 0
222 }
223
224 #[inline]
226 pub fn header(&self) -> &FrameHeader {
227 &self.header
228 }
229
230 #[inline]
232 pub fn header_mut(&mut self) -> &mut FrameHeader {
233 &mut self.header
234 }
235
236 #[inline]
238 pub fn payload(&self) -> &[u8] {
239 &self.payload
240 }
241
242 #[inline]
244 pub(crate) fn is_masked(&self) -> bool {
245 self.header.mask.is_some()
246 }
247
248 #[inline]
253 pub(crate) fn set_random_mask(&mut self) {
254 self.header.set_random_mask();
255 }
256
257 #[inline]
259 pub fn into_text(self) -> StdResult<Utf8Bytes, Utf8Error> {
260 self.payload.try_into()
261 }
262
263 #[inline]
265 pub fn into_payload(self) -> Bytes {
266 self.payload
267 }
268
269 #[inline]
271 pub fn to_text(&self) -> Result<&str, Utf8Error> {
272 std::str::from_utf8(&self.payload)
273 }
274
275 #[inline]
277 pub(crate) fn into_close(self) -> Result<Option<CloseFrame>> {
278 match self.payload.len() {
279 0 => Ok(None),
280 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
281 _ => {
282 let code = u16::from_be_bytes([self.payload[0], self.payload[1]]).into();
283 let reason = Utf8Bytes::try_from(self.payload.slice(2..))?;
284 Ok(Some(CloseFrame { code, reason }))
285 }
286 }
287 }
288
289 #[inline]
291 pub fn message(data: impl Into<Bytes>, opcode: OpCode, is_final: bool) -> Frame {
292 debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
293 Frame {
294 header: FrameHeader { is_final, opcode, ..FrameHeader::default() },
295 payload: data.into(),
296 }
297 }
298
299 #[inline]
301 pub fn pong(data: impl Into<Bytes>) -> Frame {
302 Frame {
303 header: FrameHeader {
304 opcode: OpCode::Control(Control::Pong),
305 ..FrameHeader::default()
306 },
307 payload: data.into(),
308 }
309 }
310
311 #[inline]
313 pub fn ping(data: impl Into<Bytes>) -> Frame {
314 Frame {
315 header: FrameHeader {
316 opcode: OpCode::Control(Control::Ping),
317 ..FrameHeader::default()
318 },
319 payload: data.into(),
320 }
321 }
322
323 #[inline]
325 pub fn close(msg: Option<CloseFrame>) -> Frame {
326 let payload = if let Some(CloseFrame { code, reason }) = msg {
327 let mut p = BytesMut::with_capacity(reason.len() + 2);
328 p.extend(u16::from(code).to_be_bytes());
329 p.extend_from_slice(reason.as_bytes());
330 p
331 } else {
332 <_>::default()
333 };
334
335 Frame { header: FrameHeader::default(), payload: payload.into() }
336 }
337
338 pub fn from_payload(header: FrameHeader, payload: Bytes) -> Self {
340 Frame { header, payload }
341 }
342
343 pub fn format(mut self, output: &mut impl Write) -> Result<()> {
345 self.header.format(self.payload.len() as u64, output)?;
346
347 if let Some(mask) = self.header.mask.take() {
348 let mut data = Vec::from(mem::take(&mut self.payload));
349 apply_mask(&mut data, mask);
350 output.write_all(&data)?;
351 } else {
352 output.write_all(&self.payload)?;
353 }
354
355 Ok(())
356 }
357
358 pub(crate) fn format_into_buf(mut self, buf: &mut Vec<u8>) -> Result<()> {
359 self.header.format(self.payload.len() as u64, buf)?;
360
361 let len = buf.len();
362 buf.extend_from_slice(&self.payload);
363
364 if let Some(mask) = self.header.mask.take() {
365 apply_mask(&mut buf[len..], mask);
366 }
367
368 Ok(())
369 }
370}
371
/// Multi-line, human-readable dump of the frame: flags, opcode, lengths and hex payload.
impl fmt::Display for Frame {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Brings `write!` into scope for the `String` accumulator used for the hex dump below.
        use std::fmt::Write;

        // The template is a multi-line literal: every line break and space inside it is part of
        // the output, so its lines must not be re-indented.
        write!(
            f,
            "
<FRAME>
final: {}
reserved: {} {} {}
opcode: {}
length: {}
payload length: {}
payload: 0x{}
 ",
            self.header.is_final,
            self.header.rsv1,
            self.header.rsv2,
            self.header.rsv3,
            self.header.opcode,
            // Total encoded length (header + payload), as opposed to the payload length below.
            self.len(),
            self.payload.len(),
            // Hex-encode the payload, two lowercase digits per byte.
            self.payload.iter().fold(String::new(), |mut output, byte| {
                // The result of writing into the `String` is deliberately discarded.
                _ = write!(output, "{byte:02x}");
                output
            })
        )
    }
}
402
/// How a payload length is encoded in a frame header.
enum LengthFormat {
    /// The length fits in the 7-bit length field itself.
    U8(u8),
    /// The length field holds the marker 126 and two extra length bytes follow.
    U16,
    /// The length field holds the marker 127 and eight extra length bytes follow.
    U64,
}

impl LengthFormat {
    /// Picks the smallest encoding able to hold `length`.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => Self::U8(length as u8),
            126..=65535 => Self::U16,
            _ => Self::U64,
        }
    }

    /// Number of length bytes that follow the two fixed header bytes.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            Self::U8(_) => 0,
            Self::U16 => 2,
            Self::U64 => 8,
        }
    }

    /// Value to place in the 7-bit length field of the second header byte.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            Self::U8(len) => *len,
            Self::U16 => 126,
            Self::U64 => 127,
        }
    }

    /// Determines the encoding from the second header byte; the top (mask) bit is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let marker = byte & 0x7F;
        if marker == 126 {
            Self::U16
        } else if marker == 127 {
            Self::U64
        } else {
            Self::U8(marker)
        }
    }
}
453
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    /// Encoded form of an unmasked ping frame carrying the payload `[0x01, 0x02]`.
    const PING_WIRE: [u8; 4] = [0x89, 0x02, 0x01, 0x02];

    /// A header followed by seven payload bytes parses into the header plus length 7, and
    /// leaves the cursor positioned at the payload.
    #[test]
    fn parse() {
        let wire: Vec<u8> = vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
        let mut raw = Cursor::new(wire);

        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 7);

        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();

        let frame = Frame::from_payload(header, payload.into());
        assert_eq!(frame.into_payload(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07][..]);
    }

    /// `Frame::format` writes the header followed by the unmasked payload.
    #[test]
    fn format() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut encoded: Vec<u8> = Vec::with_capacity(frame.len());
        frame.format(&mut encoded).unwrap();
        assert_eq!(encoded, PING_WIRE.to_vec());
    }

    /// `Frame::format_into_buf` produces the same bytes as `Frame::format`.
    #[test]
    fn format_into_buf() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut encoded: Vec<u8> = Vec::with_capacity(frame.len());
        frame.format_into_buf(&mut encoded).unwrap();
        assert_eq!(encoded, PING_WIRE.to_vec());
    }

    /// The `Display` output contains the payload line.
    #[test]
    fn display() {
        let frame =
            Frame::message(Bytes::from_static(b"hi there"), OpCode::Data(Data::Text), true);
        let view = frame.to_string();
        assert!(view.contains("payload:"));
    }
}
495}