tungstenite/protocol/
message.rs

1use super::frame::{CloseFrame, Frame};
2use crate::{
3    error::{CapacityError, Error, Result},
4    protocol::frame::Utf8Bytes,
5};
6use std::{fmt, result::Result as StdResult, str};
7
8mod string_collect {
9    use utf8::DecodeError;
10
11    use crate::error::{Error, Result};
12
13    #[derive(Debug)]
14    pub struct StringCollector {
15        data: String,
16        incomplete: Option<utf8::Incomplete>,
17    }
18
19    impl StringCollector {
20        pub fn new() -> Self {
21            StringCollector { data: String::new(), incomplete: None }
22        }
23
24        pub fn len(&self) -> usize {
25            self.data
26                .len()
27                .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
28        }
29
30        pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
31            let mut input: &[u8] = tail.as_ref();
32
33            if let Some(mut incomplete) = self.incomplete.take() {
34                if let Some((result, rest)) = incomplete.try_complete(input) {
35                    input = rest;
36                    if let Ok(text) = result {
37                        self.data.push_str(text);
38                    } else {
39                        return Err(Error::Utf8);
40                    }
41                } else {
42                    input = &[];
43                    self.incomplete = Some(incomplete);
44                }
45            }
46
47            if !input.is_empty() {
48                match utf8::decode(input) {
49                    Ok(text) => {
50                        self.data.push_str(text);
51                        Ok(())
52                    }
53                    Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
54                        self.data.push_str(valid_prefix);
55                        self.incomplete = Some(incomplete_suffix);
56                        Ok(())
57                    }
58                    Err(DecodeError::Invalid { valid_prefix, .. }) => {
59                        self.data.push_str(valid_prefix);
60                        Err(Error::Utf8)
61                    }
62                }
63            } else {
64                Ok(())
65            }
66        }
67
68        pub fn into_string(self) -> Result<String> {
69            if self.incomplete.is_some() {
70                Err(Error::Utf8)
71            } else {
72                Ok(self.data)
73            }
74        }
75    }
76}
77
78use self::string_collect::StringCollector;
79use bytes::Bytes;
80
81/// A struct representing the incomplete message.
82#[derive(Debug)]
83pub struct IncompleteMessage {
84    collector: IncompleteMessageCollector,
85}
86
87#[derive(Debug)]
88enum IncompleteMessageCollector {
89    Text(StringCollector),
90    Binary(Vec<u8>),
91}
92
93impl IncompleteMessage {
94    /// Create new.
95    pub fn new(message_type: IncompleteMessageType) -> Self {
96        IncompleteMessage {
97            collector: match message_type {
98                IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
99                IncompleteMessageType::Text => {
100                    IncompleteMessageCollector::Text(StringCollector::new())
101                }
102            },
103        }
104    }
105
106    /// Get the current filled size of the buffer.
107    pub fn len(&self) -> usize {
108        match self.collector {
109            IncompleteMessageCollector::Text(ref t) => t.len(),
110            IncompleteMessageCollector::Binary(ref b) => b.len(),
111        }
112    }
113
114    /// Add more data to an existing message.
115    pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T, size_limit: Option<usize>) -> Result<()> {
116        // Always have a max size. This ensures an error in case of concatenating two buffers
117        // of more than `usize::max_value()` bytes in total.
118        let max_size = size_limit.unwrap_or_else(usize::max_value);
119        let my_size = self.len();
120        let portion_size = tail.as_ref().len();
121        // Be careful about integer overflows here.
122        if my_size > max_size || portion_size > max_size - my_size {
123            return Err(Error::Capacity(CapacityError::MessageTooLong {
124                size: my_size + portion_size,
125                max_size,
126            }));
127        }
128
129        match self.collector {
130            IncompleteMessageCollector::Binary(ref mut v) => {
131                v.extend(tail.as_ref());
132                Ok(())
133            }
134            IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
135        }
136    }
137
138    /// Convert an incomplete message into a complete one.
139    pub fn complete(self) -> Result<Message> {
140        match self.collector {
141            IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v.into())),
142            IncompleteMessageCollector::Text(t) => {
143                let text = t.into_string()?;
144                Ok(Message::text(text))
145            }
146        }
147    }
148}
149
/// The kind of payload an [`IncompleteMessage`] collects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IncompleteMessageType {
    /// The payload is collected as UTF-8 text.
    Text,
    /// The payload is collected as raw bytes.
    Binary,
}
155
156/// An enum representing the various forms of a WebSocket message.
157#[derive(Debug, Eq, PartialEq, Clone)]
158pub enum Message {
159    /// A text WebSocket message
160    Text(Utf8Bytes),
161    /// A binary WebSocket message
162    Binary(Bytes),
163    /// A ping message with the specified payload
164    ///
165    /// The payload here must have a length less than 125 bytes
166    Ping(Bytes),
167    /// A pong message with the specified payload
168    ///
169    /// The payload here must have a length less than 125 bytes
170    Pong(Bytes),
171    /// A close message with the optional close frame.
172    Close(Option<CloseFrame>),
173    /// Raw frame. Note, that you're not going to get this value while reading the message.
174    Frame(Frame),
175}
176
177impl Message {
178    /// Create a new text WebSocket message from a stringable.
179    pub fn text<S>(string: S) -> Message
180    where
181        S: Into<Utf8Bytes>,
182    {
183        Message::Text(string.into())
184    }
185
186    /// Create a new binary WebSocket message by converting to `Bytes`.
187    pub fn binary<B>(bin: B) -> Message
188    where
189        B: Into<Bytes>,
190    {
191        Message::Binary(bin.into())
192    }
193
194    /// Indicates whether a message is a text message.
195    pub fn is_text(&self) -> bool {
196        matches!(*self, Message::Text(_))
197    }
198
199    /// Indicates whether a message is a binary message.
200    pub fn is_binary(&self) -> bool {
201        matches!(*self, Message::Binary(_))
202    }
203
204    /// Indicates whether a message is a ping message.
205    pub fn is_ping(&self) -> bool {
206        matches!(*self, Message::Ping(_))
207    }
208
209    /// Indicates whether a message is a pong message.
210    pub fn is_pong(&self) -> bool {
211        matches!(*self, Message::Pong(_))
212    }
213
214    /// Indicates whether a message is a close message.
215    pub fn is_close(&self) -> bool {
216        matches!(*self, Message::Close(_))
217    }
218
219    /// Get the length of the WebSocket message.
220    pub fn len(&self) -> usize {
221        match *self {
222            Message::Text(ref string) => string.len(),
223            Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
224                data.len()
225            }
226            Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
227            Message::Frame(ref frame) => frame.len(),
228        }
229    }
230
231    /// Returns true if the WebSocket message has no content.
232    /// For example, if the other side of the connection sent an empty string.
233    pub fn is_empty(&self) -> bool {
234        self.len() == 0
235    }
236
237    /// Consume the WebSocket and return it as binary data.
238    pub fn into_data(self) -> Bytes {
239        match self {
240            Message::Text(utf8) => utf8.into(),
241            Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
242            Message::Close(None) => <_>::default(),
243            Message::Close(Some(frame)) => frame.reason.into(),
244            Message::Frame(frame) => frame.into_payload(),
245        }
246    }
247
248    /// Attempt to consume the WebSocket message and convert it to a String.
249    pub fn into_text(self) -> Result<Utf8Bytes> {
250        match self {
251            Message::Text(txt) => Ok(txt),
252            Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
253                Ok(data.try_into()?)
254            }
255            Message::Close(None) => Ok(<_>::default()),
256            Message::Close(Some(frame)) => Ok(frame.reason),
257            Message::Frame(frame) => Ok(frame.into_text()?),
258        }
259    }
260
261    /// Attempt to get a &str from the WebSocket message,
262    /// this will try to convert binary data to utf8.
263    pub fn to_text(&self) -> Result<&str> {
264        match *self {
265            Message::Text(ref string) => Ok(string.as_str()),
266            Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
267                Ok(str::from_utf8(data)?)
268            }
269            Message::Close(None) => Ok(""),
270            Message::Close(Some(ref frame)) => Ok(&frame.reason),
271            Message::Frame(ref frame) => Ok(frame.to_text()?),
272        }
273    }
274}
275
276impl From<String> for Message {
277    #[inline]
278    fn from(string: String) -> Self {
279        Message::text(string)
280    }
281}
282
283impl<'s> From<&'s str> for Message {
284    #[inline]
285    fn from(string: &'s str) -> Self {
286        Message::text(string)
287    }
288}
289
290impl<'b> From<&'b [u8]> for Message {
291    #[inline]
292    fn from(data: &'b [u8]) -> Self {
293        Message::binary(Bytes::copy_from_slice(data))
294    }
295}
296
297impl From<Vec<u8>> for Message {
298    #[inline]
299    fn from(data: Vec<u8>) -> Self {
300        Message::binary(data)
301    }
302}
303
304impl From<Message> for Bytes {
305    #[inline]
306    fn from(message: Message) -> Self {
307        message.into_data()
308    }
309}
310
311impl fmt::Display for Message {
312    fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
313        if let Ok(string) = self.to_text() {
314            write!(f, "{string}")
315        } else {
316            write!(f, "Binary Data<length={}>", self.len())
317        }
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let text_msg = Message::text("test".to_owned());
        assert_eq!(text_msg.to_string(), "test");

        let binary_msg = Message::binary(vec![0, 1, 3, 4, 241]);
        assert_eq!(binary_msg.to_string(), "Binary Data<length=5>");
    }

    #[test]
    fn binary_convert() {
        let payload: [u8; 6] = [6, 7, 8, 9, 10, 241];
        let message = Message::from(&payload[..]);
        assert!(message.is_binary());
        assert!(message.into_text().is_err());
    }

    #[test]
    fn binary_convert_vec() {
        let payload = vec![6u8, 7, 8, 9, 10, 241];
        let message: Message = payload.into();
        assert!(message.is_binary());
        assert!(message.into_text().is_err());
    }

    #[test]
    fn binary_convert_into_bytes() {
        let payload = vec![6u8, 7, 8, 9, 10, 241];
        let expected = payload.clone();
        let bytes = Bytes::from(Message::from(payload));
        assert_eq!(expected, bytes);
    }

    #[test]
    fn text_convert() {
        let message = Message::from("kiwotsukete");
        assert!(message.is_text());
    }
}