csv_async/async_writers/
mwtr_serde.rs

1use std::io;
2use std::io::Write;
3
4use csv_core::{
5    self, WriteResult, Writer as CoreWriter
6};
7use serde::Serialize;
8
9use crate::error::{Error, ErrorKind, Result};
10use crate::serializer::{serialize, serialize_header};
11use crate::AsyncWriterBuilder;
12
13/// A helper struct to synchronously perform serialization of structures to bytes stored in memory
14/// according to interface provided by serde::Serialize.
15/// Those bytes are being then asynchronously sent to writer.
16/// 
17//  TODO: The `buf` here is present to ease using csv_core interface, 
18//  but is redundant, degrade performance and should be eliminated.
19#[derive(Debug)]
20pub struct MemWriter {
21    core: CoreWriter,
22    wtr: io::Cursor<Vec<u8>>,
23    buf: Buffer,
24    state: WriterState,
25}
26
27#[derive(Debug)]
28struct WriterState {
29    /// Whether the Serde serializer should attempt to write a header row.
30    header: HeaderState,
31    /// Whether inconsistent record lengths are allowed.
32    flexible: bool,
33    /// The number of fields writtein in the first record. This is compared
34    /// with `fields_written` on all subsequent records to check for
35    /// inconsistent record lengths.
36    first_field_count: Option<u64>,
37    /// The number of fields written in this record. This is used to report
38    /// errors for inconsistent record lengths if `flexible` is disabled.
39    fields_written: u64,
40    /// This is set immediately before flushing the buffer and then unset
41    /// immediately after flushing the buffer. This avoids flushing the buffer
42    /// twice if the inner writer panics.
43    panicked: bool,
44}
45
/// Tiny state machine describing what has happened with the header row.
#[derive(Debug)]
enum HeaderState {
    /// A header should be attempted before the next serialized record. A header
    /// attempt that returns an error leaves the writer in this state.
    Write,
    /// A header was attempted and a header row was written.
    DidWrite,
    /// A header was attempted, but no header row was produced.
    DidNotWrite,
    /// Headers are disabled; this state never transitions to any other.
    None,
}
60
/// Minimal fixed-size write buffer.
///
/// `csv_core` encodes into a plain `&mut [u8]`, which `std::io::BufWriter` does not
/// expose, so the writer keeps its own byte array plus a fill mark.
#[derive(Debug)]
struct Buffer {
    /// Backing storage; only the first `len` bytes hold data.
    buf: Vec<u8>,
    /// How many bytes at the front of `buf` have been filled.
    len: usize,
}
72
73impl Drop for MemWriter {
74    fn drop(&mut self) {
75        if !self.state.panicked {
76            let _ = self.flush();
77        }
78    }
79}
80
81impl MemWriter {
82    /// Create MemWriter using configuration stored in AsyncWriterBuilder.
83    /// 
84    pub fn new(builder: &AsyncWriterBuilder) -> Self {
85        let header_state = if builder.has_headers {
86            HeaderState::Write
87        } else {
88            HeaderState::None
89        };
90        MemWriter {
91            core: builder.builder.build(),
92            wtr: io::Cursor::new(Vec::new()),
93            buf: Buffer { buf: vec![0; builder.capacity], len: 0 },
94            state: WriterState {
95                header: header_state,
96                flexible: builder.flexible,
97                first_field_count: None,
98                fields_written: 0,
99                panicked: false,
100            },
101        }
102    }
103
104    /// Serialize a single record using Serde.
105    ///
106    pub fn serialize<S: Serialize>(&mut self, record: S) -> Result<()> {
107        if let HeaderState::Write = self.state.header {
108            let wrote_header = serialize_header(self, &record)?;
109            if wrote_header {
110                self.write_terminator()?;
111                self.state.header = HeaderState::DidWrite;
112            } else {
113                self.state.header = HeaderState::DidNotWrite;
114            };
115        }
116        serialize(self, &record)?;
117        self.write_terminator()?;
118        Ok(())
119    }
120
121    /// Write a single field.
122    pub fn write_field<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
123        self.write_field_impl(field)
124    }
125
126    /// Implementation of write_field.
127    ///
128    /// This is a separate method so we can force the compiler to inline it
129    /// into write_record.
130    #[inline(always)]
131    fn write_field_impl<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
132        if self.state.fields_written > 0 {
133            self.write_delimiter()?;
134        }
135        let mut field = field.as_ref();
136        loop {
137            let (res, nin, nout) = self.core.field(field, self.buf.writable());
138            field = &field[nin..];
139            self.buf.written(nout);
140            match res {
141                WriteResult::InputEmpty => {
142                    self.state.fields_written += 1;
143                    return Ok(());
144                }
145                WriteResult::OutputFull => self.flush_buf()?,
146            }
147        }
148    }
149
150    /// Flush the contents of the internal buffer to the underlying writer.
151    ///
152    /// If there was a problem writing to the underlying writer, then an error
153    /// is returned.
154    ///
155    /// Note that this also flushes the underlying writer.
156    pub fn flush(&mut self) -> io::Result<()> {
157        self.flush_buf()?;
158        self.wtr.flush()?;
159        Ok(())
160    }
161
162    /// Flush the contents of the internal buffer to the underlying writer,
163    /// without flushing the underlying writer.
164    fn flush_buf(&mut self) -> io::Result<()> {
165        self.state.panicked = true;
166        let result = self.wtr.write_all(self.buf.readable());
167        self.state.panicked = false;
168        result?;
169        self.buf.clear();
170        Ok(())
171    }
172
173    /// Returns slice with the accumulated data.
174    /// Caller is responsible for calling `flush()` before this call, 
175    /// otherwise returned vector may not contain all the data.
176    pub fn data(&mut self) -> &[u8] {
177        self.wtr.get_mut().as_slice()
178    }
179
180    /// Clears Writer internal vector, but not buffer.
181    // TODO: See note about removing double buffering
182    pub fn clear(&mut self) {
183        self.wtr.get_mut().clear();
184        self.wtr.set_position(0);
185    }
186
187    /// Write a CSV delimiter.
188    fn write_delimiter(&mut self) -> Result<()> {
189        loop {
190            let (res, nout) = self.core.delimiter(self.buf.writable());
191            self.buf.written(nout);
192            match res {
193                WriteResult::InputEmpty => return Ok(()),
194                WriteResult::OutputFull => self.flush_buf()?,
195            }
196        }
197    }
198
199    /// Write a CSV terminator.
200    fn write_terminator(&mut self) -> Result<()> {
201        self.check_field_count()?;
202        loop {
203            let (res, nout) = self.core.terminator(self.buf.writable());
204            self.buf.written(nout);
205            match res {
206                WriteResult::InputEmpty => {
207                    self.state.fields_written = 0;
208                    return Ok(());
209                }
210                WriteResult::OutputFull => self.flush_buf()?,
211            }
212        }
213    }
214
215    fn check_field_count(&mut self) -> Result<()> {
216        if !self.state.flexible {
217            match self.state.first_field_count {
218                None => {
219                    self.state.first_field_count =
220                        Some(self.state.fields_written);
221                }
222                Some(expected) if expected != self.state.fields_written => {
223                    return Err(Error::new(ErrorKind::UnequalLengths {
224                        pos: None,
225                        expected_len: expected,
226                        len: self.state.fields_written,
227                    }))
228                }
229                Some(_) => {}
230            }
231        }
232        Ok(())
233    }
234}
235
236impl Buffer {
237    /// Returns a slice of the buffer's current contents.
238    ///
239    /// The slice returned may be empty.
240    #[inline]
241    fn readable(&self) -> &[u8] {
242        &self.buf[..self.len]
243    }
244
245    /// Returns a mutable slice of the remaining space in this buffer.
246    ///
247    /// The slice returned may be empty.
248    #[inline]
249    fn writable(&mut self) -> &mut [u8] {
250        &mut self.buf[self.len..]
251    }
252
253    /// Indicates that `n` bytes have been written to this buffer.
254    #[inline]
255    fn written(&mut self, n: usize) {
256        self.len += n;
257    }
258
259    /// Clear the buffer.
260    #[inline]
261    fn clear(&mut self) {
262        self.len = 0;
263    }
264}
265
#[cfg(test)]
mod tests {
    use std::error::Error;

    use serde::{serde_if_integer128, Serialize};

    use crate::byte_record::ByteRecord;
    use crate::error::{ErrorKind, IntoInnerError};
    use crate::string_record::StringRecord;

    use super::{AsyncWriterBuilder, MemWriter};

    /// Consumes the writer and returns everything it produced as UTF-8 text.
    fn wtr_as_string(wtr: MemWriter) -> String {
        String::from_utf8(wtr.into_inner().unwrap()).unwrap()
    }

    // Test-only helpers on `MemWriter`.
    impl MemWriter {
        /// A writer using `AsyncWriterBuilder::new()`'s configuration.
        pub fn default() -> Self {
            Self::new(&AsyncWriterBuilder::new())
        }

        /// Flushes pending bytes and returns the accumulated output.
        ///
        /// If flushing fails, the writer is handed back inside the error.
        pub fn into_inner(mut self) -> Result<Vec<u8>, IntoInnerError<MemWriter>> {
            match self.flush() {
                // `MemWriter` implements `Drop`, so `wtr` cannot be moved out directly.
                // Swapping in an empty cursor avoids cloning the whole output vector.
                Ok(()) => Ok(std::mem::take(&mut self.wtr).into_inner()),
                Err(err) => Err(IntoInnerError::new(self, err)),
            }
        }

        /// Writes every field of `record` and then terminates the record.
        pub fn write_record<I, T>(&mut self, record: I) -> crate::error::Result<()>
        where
            I: IntoIterator<Item = T>,
            T: AsRef<[u8]>,
        {
            for field in record.into_iter() {
                self.write_field_impl(field)?;
            }
            self.write_terminator()
        }
    }

    #[test]
    fn one_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&["a", "b", "c"]).unwrap();

        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_string_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&StringRecord::from(vec!["a", "b", "c"])).unwrap();

        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_byte_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();

        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_empty_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&[""]).unwrap();

        assert_eq!(wtr_as_string(wtr), "\"\"\n");
    }

    #[test]
    fn two_empty_records() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&[""]).unwrap();
        wtr.write_record(&[""]).unwrap();

        assert_eq!(wtr_as_string(wtr), "\"\"\n\"\"\n");
    }

    #[test]
    fn unequal_records_bad() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        let err = wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap_err();
        match *err.kind() {
            ErrorKind::UnequalLengths { ref pos, expected_len, len } => {
                assert!(pos.is_none());
                assert_eq!(expected_len, 3);
                assert_eq!(len, 1);
            }
            ref x => {
                panic!("expected UnequalLengths error, but got '{:?}'", x);
            }
        }
    }

    #[test]
    fn unequal_records_ok() {
        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().flexible(true));
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\na\n");
    }

    #[test]
    fn write_field() -> Result<(), Box<dyn Error>> {
        let mut wtr = MemWriter::default();
        wtr.write_field("a")?;
        wtr.write_field("b")?;
        wtr.write_field("c")?;
        wtr.write_terminator()?;
        wtr.write_field("x")?;
        wtr.write_field("y")?;
        wtr.write_field("z")?;
        wtr.write_terminator()?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "a,b,c\nx,y,z\n");
        Ok(())
    }

    #[test]
    fn serialize_with_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }

        let mut wtr = MemWriter::default();
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "foo,bar,baz\n42,42.5,true\n");
    }

    #[test]
    fn serialize_no_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }

        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "42,42.5,true\n");
    }

    serde_if_integer128! {
        #[test]
        fn serialize_no_headers_128() {
            #[derive(Serialize)]
            struct Row {
                foo: i128,
                bar: f64,
                baz: bool,
            }

            let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
            wtr.serialize(Row {
                foo: 9_223_372_036_854_775_808,
                bar: 42.5,
                baz: true,
            }).unwrap();
            assert_eq!(wtr_as_string(wtr), "9223372036854775808,42.5,true\n");
        }
    }

    #[test]
    fn serialize_tuple() {
        let mut wtr = MemWriter::default();
        wtr.serialize((true, 1.3, "hi")).unwrap();
        assert_eq!(wtr_as_string(wtr), "true,1.3,hi\n");
    }

    #[test]
    fn serialize_struct() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row<'a> {
            city: &'a str,
            country: &'a str,
            // Serde allows us to name our headers exactly,
            // even if they don't match our struct field names.
            #[serde(rename = "popcount")]
            population: u64,
        }

        let mut wtr = MemWriter::default();
        wtr.serialize(Row {
            city: "Boston",
            country: "United States",
            population: 4628910,
        })?;
        wtr.serialize(Row {
            city: "Concord",
            country: "United States",
            population: 42695,
        })?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "\
city,country,popcount
Boston,United States,4628910
Concord,United States,42695
");
        Ok(())
    }

    #[test]
    fn serialize_enum() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row {
            label: String,
            value: Value,
        }
        #[derive(Serialize)]
        enum Value {
            Integer(i64),
            Float(f64),
        }

        let mut wtr = MemWriter::default();
        wtr.serialize(Row {
            label: "foo".to_string(),
            value: Value::Integer(3),
        })?;
        wtr.serialize(Row {
            label: "bar".to_string(),
            value: Value::Float(3.14),
        })?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "\
label,value
foo,3
bar,3.14
");
        Ok(())
    }

    #[test]
    fn serialize_vec() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row {
            label: String,
            values: Vec<f64>,
        }

        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
        wtr.serialize(Row {
            label: "foo".to_string(),
            values: vec![1.1234, 2.5678, 3.14],
        })?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "foo,1.1234,2.5678,3.14\n");
        Ok(())
    }
}