1use std::io;
2use std::io::Write;
3
4use csv_core::{
5 self, WriteResult, Writer as CoreWriter
6};
7use serde::Serialize;
8
9use crate::error::{Error, ErrorKind, Result};
10use crate::serializer::{serialize, serialize_header};
11use crate::AsyncWriterBuilder;
12
13#[derive(Debug)]
20pub struct MemWriter {
21 core: CoreWriter,
22 wtr: io::Cursor<Vec<u8>>,
23 buf: Buffer,
24 state: WriterState,
25}
26
27#[derive(Debug)]
28struct WriterState {
29 header: HeaderState,
31 flexible: bool,
33 first_field_count: Option<u64>,
37 fields_written: u64,
40 panicked: bool,
44}
45
/// Progress of header emission for a `MemWriter`.
#[derive(Debug)]
enum HeaderState {
    /// Headers were requested and have not been attempted yet; the next
    /// serialized record triggers the attempt.
    Write,
    /// The header serializer reported that a header row was written.
    DidWrite,
    /// The header serializer was run but reported that no header row was
    /// written.
    DidNotWrite,
    /// Headers were not requested by the builder.
    None,
}
60
/// A fixed-size staging area with a fill marker.
///
/// The first `len` bytes of `buf` hold data waiting to be flushed; the rest
/// is free space available for writing.
#[derive(Debug)]
struct Buffer {
    /// Backing storage; its length is the total staging capacity.
    buf: Vec<u8>,
    /// Number of bytes at the front of `buf` that currently hold data.
    len: usize,
}
72
73impl Drop for MemWriter {
74 fn drop(&mut self) {
75 if !self.state.panicked {
76 let _ = self.flush();
77 }
78 }
79}
80
81impl MemWriter {
82 pub fn new(builder: &AsyncWriterBuilder) -> Self {
85 let header_state = if builder.has_headers {
86 HeaderState::Write
87 } else {
88 HeaderState::None
89 };
90 MemWriter {
91 core: builder.builder.build(),
92 wtr: io::Cursor::new(Vec::new()),
93 buf: Buffer { buf: vec![0; builder.capacity], len: 0 },
94 state: WriterState {
95 header: header_state,
96 flexible: builder.flexible,
97 first_field_count: None,
98 fields_written: 0,
99 panicked: false,
100 },
101 }
102 }
103
104 pub fn serialize<S: Serialize>(&mut self, record: S) -> Result<()> {
107 if let HeaderState::Write = self.state.header {
108 let wrote_header = serialize_header(self, &record)?;
109 if wrote_header {
110 self.write_terminator()?;
111 self.state.header = HeaderState::DidWrite;
112 } else {
113 self.state.header = HeaderState::DidNotWrite;
114 };
115 }
116 serialize(self, &record)?;
117 self.write_terminator()?;
118 Ok(())
119 }
120
121 pub fn write_field<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
123 self.write_field_impl(field)
124 }
125
126 #[inline(always)]
131 fn write_field_impl<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
132 if self.state.fields_written > 0 {
133 self.write_delimiter()?;
134 }
135 let mut field = field.as_ref();
136 loop {
137 let (res, nin, nout) = self.core.field(field, self.buf.writable());
138 field = &field[nin..];
139 self.buf.written(nout);
140 match res {
141 WriteResult::InputEmpty => {
142 self.state.fields_written += 1;
143 return Ok(());
144 }
145 WriteResult::OutputFull => self.flush_buf()?,
146 }
147 }
148 }
149
150 pub fn flush(&mut self) -> io::Result<()> {
157 self.flush_buf()?;
158 self.wtr.flush()?;
159 Ok(())
160 }
161
162 fn flush_buf(&mut self) -> io::Result<()> {
165 self.state.panicked = true;
166 let result = self.wtr.write_all(self.buf.readable());
167 self.state.panicked = false;
168 result?;
169 self.buf.clear();
170 Ok(())
171 }
172
173 pub fn data(&mut self) -> &[u8] {
177 self.wtr.get_mut().as_slice()
178 }
179
180 pub fn clear(&mut self) {
183 self.wtr.get_mut().clear();
184 self.wtr.set_position(0);
185 }
186
187 fn write_delimiter(&mut self) -> Result<()> {
189 loop {
190 let (res, nout) = self.core.delimiter(self.buf.writable());
191 self.buf.written(nout);
192 match res {
193 WriteResult::InputEmpty => return Ok(()),
194 WriteResult::OutputFull => self.flush_buf()?,
195 }
196 }
197 }
198
199 fn write_terminator(&mut self) -> Result<()> {
201 self.check_field_count()?;
202 loop {
203 let (res, nout) = self.core.terminator(self.buf.writable());
204 self.buf.written(nout);
205 match res {
206 WriteResult::InputEmpty => {
207 self.state.fields_written = 0;
208 return Ok(());
209 }
210 WriteResult::OutputFull => self.flush_buf()?,
211 }
212 }
213 }
214
215 fn check_field_count(&mut self) -> Result<()> {
216 if !self.state.flexible {
217 match self.state.first_field_count {
218 None => {
219 self.state.first_field_count =
220 Some(self.state.fields_written);
221 }
222 Some(expected) if expected != self.state.fields_written => {
223 return Err(Error::new(ErrorKind::UnequalLengths {
224 pos: None,
225 expected_len: expected,
226 len: self.state.fields_written,
227 }))
228 }
229 Some(_) => {}
230 }
231 }
232 Ok(())
233 }
234}
235
236impl Buffer {
237 #[inline]
241 fn readable(&self) -> &[u8] {
242 &self.buf[..self.len]
243 }
244
245 #[inline]
249 fn writable(&mut self) -> &mut [u8] {
250 &mut self.buf[self.len..]
251 }
252
253 #[inline]
255 fn written(&mut self, n: usize) {
256 self.len += n;
257 }
258
259 #[inline]
261 fn clear(&mut self) {
262 self.len = 0;
263 }
264}
265
#[cfg(test)]
mod tests {
    use std::error::Error;

    use serde::{serde_if_integer128, Serialize};

    use crate::byte_record::ByteRecord;
    use crate::error::{ErrorKind, IntoInnerError};
    use crate::string_record::StringRecord;

    use super::{MemWriter, AsyncWriterBuilder};

    /// Flushes `wtr` and returns everything it produced as a `String`.
    fn wtr_as_string(wtr: MemWriter) -> String {
        String::from_utf8(wtr.into_inner().unwrap()).unwrap()
    }

    /// Test-only conveniences on top of the production API.
    impl MemWriter {
        /// Creates a writer from an unmodified `AsyncWriterBuilder`.
        pub fn default() -> Self {
            Self::new(&AsyncWriterBuilder::new())
        }

        /// Flushes the writer and returns the bytes it accumulated.
        ///
        /// On a flush failure the writer is handed back inside the error.
        pub fn into_inner(
            mut self,
        ) -> Result<Vec<u8>, IntoInnerError<MemWriter>> {
            match self.flush() {
                // `MemWriter` implements `Drop`, so its fields cannot be
                // moved out directly; take the bytes and leave an empty
                // vector behind instead of cloning the whole cursor.
                Ok(()) => Ok(std::mem::take(self.wtr.get_mut())),
                Err(err) => Err(IntoInnerError::new(self, err)),
            }
        }

        /// Writes every item of `record` as a field, then ends the record.
        pub fn write_record<I, T>(&mut self, record: I) -> crate::error::Result<()>
        where
            I: IntoIterator<Item = T>,
            T: AsRef<[u8]>,
        {
            for field in record {
                self.write_field_impl(field)?;
            }
            self.write_terminator()
        }
    }

    #[test]
    fn one_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&["a", "b", "c"]).unwrap();

        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_string_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&StringRecord::from(vec!["a", "b", "c"])).unwrap();

        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_byte_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();

        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_empty_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&[""]).unwrap();

        assert_eq!(wtr_as_string(wtr), "\"\"\n");
    }

    #[test]
    fn two_empty_records() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&[""]).unwrap();
        wtr.write_record(&[""]).unwrap();

        assert_eq!(wtr_as_string(wtr), "\"\"\n\"\"\n");
    }

    #[test]
    fn unequal_records_bad() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        let err = wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap_err();
        match *err.kind() {
            ErrorKind::UnequalLengths { ref pos, expected_len, len } => {
                assert!(pos.is_none());
                assert_eq!(expected_len, 3);
                assert_eq!(len, 1);
            }
            ref x => {
                panic!("expected UnequalLengths error, but got '{:?}'", x);
            }
        }
    }

    #[test]
    fn unequal_records_ok() {
        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().flexible(true));
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\na\n");
    }

    #[test]
    fn write_field() -> Result<(), Box<dyn Error>> {
        let mut wtr = MemWriter::default();
        wtr.write_field("a")?;
        wtr.write_field("b")?;
        wtr.write_field("c")?;
        wtr.write_terminator()?;
        wtr.write_field("x")?;
        wtr.write_field("y")?;
        wtr.write_field("z")?;
        wtr.write_terminator()?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "a,b,c\nx,y,z\n");
        Ok(())
    }

    // `data` exposes only flushed bytes, and `clear` resets the destination
    // so the next flush starts from an empty buffer.
    #[test]
    fn data_and_clear() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&["a", "b"]).unwrap();
        wtr.flush().unwrap();
        assert_eq!(wtr.data(), &b"a,b\n"[..]);

        wtr.clear();
        assert!(wtr.data().is_empty());

        wtr.write_record(&["c", "d"]).unwrap();
        wtr.flush().unwrap();
        assert_eq!(wtr.data(), &b"c,d\n"[..]);
    }

    #[test]
    fn serialize_with_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }

        let mut wtr = MemWriter::default();
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "foo,bar,baz\n42,42.5,true\n");
    }

    #[test]
    fn serialize_no_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }

        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "42,42.5,true\n");
    }

    serde_if_integer128! {
        #[test]
        fn serialize_no_headers_128() {
            #[derive(Serialize)]
            struct Row {
                foo: i128,
                bar: f64,
                baz: bool,
            }

            let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
            wtr.serialize(Row {
                foo: 9_223_372_036_854_775_808,
                bar: 42.5,
                baz: true,
            }).unwrap();
            assert_eq!(wtr_as_string(wtr), "9223372036854775808,42.5,true\n");
        }
    }

    #[test]
    fn serialize_tuple() {
        let mut wtr = MemWriter::default();
        wtr.serialize((true, 1.3, "hi")).unwrap();
        assert_eq!(wtr_as_string(wtr), "true,1.3,hi\n");
    }

    #[test]
    fn serialize_struct() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row<'a> {
            city: &'a str,
            country: &'a str,
            // The header must use the renamed column, not the field name.
            #[serde(rename = "popcount")]
            population: u64,
        }

        let mut wtr = MemWriter::default();
        wtr.serialize(Row {
            city: "Boston",
            country: "United States",
            population: 4628910,
        })?;
        wtr.serialize(Row {
            city: "Concord",
            country: "United States",
            population: 42695,
        })?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "\
city,country,popcount
Boston,United States,4628910
Concord,United States,42695
");
        Ok(())
    }

    #[test]
    fn serialize_enum() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row {
            label: String,
            value: Value,
        }
        #[derive(Serialize)]
        enum Value {
            Integer(i64),
            Float(f64),
        }

        let mut wtr = MemWriter::default();
        wtr.serialize(Row {
            label: "foo".to_string(),
            value: Value::Integer(3),
        })?;
        wtr.serialize(Row {
            label: "bar".to_string(),
            value: Value::Float(3.14),
        })?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "\
label,value
foo,3
bar,3.14
");
        Ok(())
    }

    #[test]
    fn serialize_vec() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row {
            label: String,
            values: Vec<f64>,
        }

        let mut wtr = MemWriter::new(
            &AsyncWriterBuilder::new()
                .has_headers(false)
        );
        wtr.serialize(Row {
            label: "foo".to_string(),
            values: vec![1.1234, 2.5678, 3.14],
        })?;

        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "foo,1.1234,2.5678,3.14\n");
        Ok(())
    }
}