use std::io;
use std::io::Write;
use csv_core::{
self, WriteResult, Writer as CoreWriter
};
use serde::Serialize;
use crate::error::{Error, ErrorKind, Result};
use crate::serializer::{serialize, serialize_header};
use crate::AsyncWriterBuilder;
/// CSV writer that accumulates its output in memory.
///
/// Field, delimiter and terminator bytes are produced by `core` into the
/// fixed-size staging buffer `buf`; staged bytes are appended to the
/// in-memory cursor `wtr` whenever `buf` fills up or on `flush`. The
/// accumulated bytes are read back with `data` and discarded with `clear`.
#[derive(Debug)]
pub struct MemWriter {
    /// Encoder that produces the field, delimiter and terminator bytes.
    core: CoreWriter,
    /// In-memory destination that staged bytes are flushed into.
    wtr: io::Cursor<Vec<u8>>,
    /// Staging area that `core` writes into before bytes reach `wtr`.
    buf: Buffer,
    /// Header, field-count and panic bookkeeping.
    state: WriterState,
}
/// Mutable bookkeeping used by the writing methods of `MemWriter`.
#[derive(Debug)]
struct WriterState {
    /// Where `serialize` stands with respect to emitting a header row.
    header: HeaderState,
    /// When `true`, records may contain differing numbers of fields.
    flexible: bool,
    /// Field count of the first terminated record (recorded only when
    /// `flexible` is false); every later record must match it.
    first_field_count: Option<u64>,
    /// Number of fields written so far in the current, unterminated record.
    fields_written: u64,
    /// Raised while `flush_buf` writes into the cursor. If it is still set
    /// afterwards, that write panicked part-way and `Drop` must not flush.
    panicked: bool,
}
/// Progress of header-row handling for `MemWriter::serialize`.
#[derive(Debug)]
enum HeaderState {
    /// Headers are enabled and the next `serialize` call should try to
    /// write one.
    Write,
    /// A header row has been written.
    DidWrite,
    /// A header was attempted but `serialize_header` reported that nothing
    /// was written; no further attempt is made.
    DidNotWrite,
    /// Headers are disabled for this writer.
    None,
}
/// Fixed-size staging area for encoded CSV output.
///
/// `buf[..len]` holds bytes waiting to be flushed; `buf[len..]` is free space.
#[derive(Debug)]
struct Buffer {
    /// Backing storage, sized when the writer is constructed.
    buf: Vec<u8>,
    /// Number of filled bytes at the front of `buf`.
    len: usize,
}
impl Drop for MemWriter {
    /// Flushes any staged bytes on a best-effort basis as the writer goes away.
    ///
    /// Nothing is attempted when `state.panicked` is still set, i.e. when an
    /// earlier `flush_buf` write never finished. Flush errors are discarded
    /// because `drop` has no way to report them.
    fn drop(&mut self) {
        if self.state.panicked {
            return;
        }
        // Deliberately ignored: there is no caller to hand the error to.
        let _ = self.flush();
    }
}
impl MemWriter {
pub fn new(builder: &AsyncWriterBuilder) -> Self {
let header_state = if builder.has_headers {
HeaderState::Write
} else {
HeaderState::None
};
MemWriter {
core: builder.builder.build(),
wtr: io::Cursor::new(Vec::new()),
buf: Buffer { buf: vec![0; builder.capacity], len: 0 },
state: WriterState {
header: header_state,
flexible: builder.flexible,
first_field_count: None,
fields_written: 0,
panicked: false,
},
}
}
pub fn serialize<S: Serialize>(&mut self, record: S) -> Result<()> {
if let HeaderState::Write = self.state.header {
let wrote_header = serialize_header(self, &record)?;
if wrote_header {
self.write_terminator()?;
self.state.header = HeaderState::DidWrite;
} else {
self.state.header = HeaderState::DidNotWrite;
};
}
serialize(self, &record)?;
self.write_terminator()?;
Ok(())
}
pub fn write_field<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
self.write_field_impl(field)
}
#[inline(always)]
fn write_field_impl<T: AsRef<[u8]>>(&mut self, field: T) -> Result<()> {
if self.state.fields_written > 0 {
self.write_delimiter()?;
}
let mut field = field.as_ref();
loop {
let (res, nin, nout) = self.core.field(field, self.buf.writable());
field = &field[nin..];
self.buf.written(nout);
match res {
WriteResult::InputEmpty => {
self.state.fields_written += 1;
return Ok(());
}
WriteResult::OutputFull => self.flush_buf()?,
}
}
}
pub fn flush(&mut self) -> io::Result<()> {
self.flush_buf()?;
self.wtr.flush()?;
Ok(())
}
fn flush_buf(&mut self) -> io::Result<()> {
self.state.panicked = true;
let result = self.wtr.write_all(self.buf.readable());
self.state.panicked = false;
result?;
self.buf.clear();
Ok(())
}
pub fn data(&mut self) -> &[u8] {
self.wtr.get_mut().as_slice()
}
pub fn clear(&mut self) {
self.wtr.get_mut().clear();
self.wtr.set_position(0);
}
fn write_delimiter(&mut self) -> Result<()> {
loop {
let (res, nout) = self.core.delimiter(self.buf.writable());
self.buf.written(nout);
match res {
WriteResult::InputEmpty => return Ok(()),
WriteResult::OutputFull => self.flush_buf()?,
}
}
}
fn write_terminator(&mut self) -> Result<()> {
self.check_field_count()?;
loop {
let (res, nout) = self.core.terminator(self.buf.writable());
self.buf.written(nout);
match res {
WriteResult::InputEmpty => {
self.state.fields_written = 0;
return Ok(());
}
WriteResult::OutputFull => self.flush_buf()?,
}
}
}
fn check_field_count(&mut self) -> Result<()> {
if !self.state.flexible {
match self.state.first_field_count {
None => {
self.state.first_field_count =
Some(self.state.fields_written);
}
Some(expected) if expected != self.state.fields_written => {
return Err(Error::new(ErrorKind::UnequalLengths {
pos: None,
expected_len: expected,
len: self.state.fields_written,
}))
}
Some(_) => {}
}
}
Ok(())
}
}
impl Buffer {
    /// Forgets every staged byte so the whole storage is writable again.
    ///
    /// The backing `Vec` keeps its size and contents; only the fill marker
    /// moves back to zero.
    #[inline]
    fn clear(&mut self) {
        self.len = 0;
    }

    /// Records that `n` more bytes at the start of `writable()` were filled.
    #[inline]
    fn written(&mut self, n: usize) {
        self.len += n;
    }

    /// The filled prefix: bytes produced so far that still await a flush.
    #[inline]
    fn readable(&self) -> &[u8] {
        let filled = self.len;
        &self.buf[0..filled]
    }

    /// The unfilled tail, available for the core writer to write into.
    #[inline]
    fn writable(&mut self) -> &mut [u8] {
        let first_free = self.len;
        &mut self.buf[first_free..]
    }
}
#[cfg(test)]
mod tests {
    use std::error::Error;

    use serde::{serde_if_integer128, Serialize};

    use crate::byte_record::ByteRecord;
    use crate::error::{ErrorKind, IntoInnerError};
    use crate::string_record::StringRecord;

    use super::{AsyncWriterBuilder, MemWriter};

    /// Flushes `wtr` and decodes everything it produced as UTF-8.
    fn wtr_as_string(wtr: MemWriter) -> String {
        String::from_utf8(wtr.into_inner().unwrap()).unwrap()
    }

    /// Test-only conveniences so the tests can drive `MemWriter` the same way
    /// they would drive a synchronous CSV writer.
    impl MemWriter {
        /// Builds a writer using the builder's default configuration.
        pub fn default() -> Self {
            Self::new(&AsyncWriterBuilder::new())
        }

        /// Flushes the writer and returns all bytes it produced.
        ///
        /// If the flush fails, the writer is handed back inside the error.
        pub fn into_inner(mut self) -> Result<Vec<u8>, IntoInnerError<MemWriter>> {
            match self.flush() {
                // `MemWriter` implements `Drop`, so `self.wtr` cannot be moved
                // out; take the vector rather than cloning it. The successful
                // flush emptied the staging buffer, so the flush performed by
                // `Drop` afterwards writes nothing.
                Ok(()) => Ok(std::mem::take(self.wtr.get_mut())),
                Err(err) => Err(IntoInnerError::new(self, err)),
            }
        }

        /// Writes every field of `record`, then terminates the record.
        pub fn write_record<I, T>(&mut self, record: I) -> crate::error::Result<()>
        where
            I: IntoIterator<Item = T>,
            T: AsRef<[u8]>,
        {
            for field in record.into_iter() {
                self.write_field_impl(field)?;
            }
            self.write_terminator()
        }
    }

    #[test]
    fn one_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&["a", "b", "c"]).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_string_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&StringRecord::from(vec!["a", "b", "c"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_byte_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\n");
    }

    #[test]
    fn one_empty_record() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&[""]).unwrap();
        assert_eq!(wtr_as_string(wtr), "\"\"\n");
    }

    #[test]
    fn two_empty_records() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&[""]).unwrap();
        wtr.write_record(&[""]).unwrap();
        assert_eq!(wtr_as_string(wtr), "\"\"\n\"\"\n");
    }

    #[test]
    fn unequal_records_bad() {
        let mut wtr = MemWriter::default();
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        let err = wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap_err();
        match *err.kind() {
            ErrorKind::UnequalLengths { ref pos, expected_len, len } => {
                assert!(pos.is_none());
                assert_eq!(expected_len, 3);
                assert_eq!(len, 1);
            }
            ref x => {
                panic!("expected UnequalLengths error, but got '{:?}'", x);
            }
        }
    }

    #[test]
    fn unequal_records_ok() {
        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().flexible(true));
        wtr.write_record(&ByteRecord::from(vec!["a", "b", "c"])).unwrap();
        wtr.write_record(&ByteRecord::from(vec!["a"])).unwrap();
        assert_eq!(wtr_as_string(wtr), "a,b,c\na\n");
    }

    #[test]
    fn write_field() -> Result<(), Box<dyn Error>> {
        let mut wtr = MemWriter::default();
        wtr.write_field("a")?;
        wtr.write_field("b")?;
        wtr.write_field("c")?;
        wtr.write_terminator()?;
        wtr.write_field("x")?;
        wtr.write_field("y")?;
        wtr.write_field("z")?;
        wtr.write_terminator()?;
        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "a,b,c\nx,y,z\n");
        Ok(())
    }

    #[test]
    fn serialize_with_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }
        let mut wtr = MemWriter::default();
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "foo,bar,baz\n42,42.5,true\n");
    }

    #[test]
    fn serialize_no_headers() {
        #[derive(Serialize)]
        struct Row {
            foo: i32,
            bar: f64,
            baz: bool,
        }
        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
        wtr.serialize(Row { foo: 42, bar: 42.5, baz: true }).unwrap();
        assert_eq!(wtr_as_string(wtr), "42,42.5,true\n");
    }

    serde_if_integer128! {
        #[test]
        fn serialize_no_headers_128() {
            #[derive(Serialize)]
            struct Row {
                foo: i128,
                bar: f64,
                baz: bool,
            }
            let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
            wtr.serialize(Row {
                foo: 9_223_372_036_854_775_808,
                bar: 42.5,
                baz: true,
            }).unwrap();
            assert_eq!(wtr_as_string(wtr), "9223372036854775808,42.5,true\n");
        }
    }

    #[test]
    fn serialize_tuple() {
        let mut wtr = MemWriter::default();
        wtr.serialize((true, 1.3, "hi")).unwrap();
        assert_eq!(wtr_as_string(wtr), "true,1.3,hi\n");
    }

    #[test]
    fn serialize_struct() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row<'a> {
            city: &'a str,
            country: &'a str,
            #[serde(rename = "popcount")]
            population: u64,
        }
        let mut wtr = MemWriter::default();
        wtr.serialize(Row {
            city: "Boston",
            country: "United States",
            population: 4628910,
        })?;
        wtr.serialize(Row {
            city: "Concord",
            country: "United States",
            population: 42695,
        })?;
        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(
            data,
            "city,country,popcount\n\
             Boston,United States,4628910\n\
             Concord,United States,42695\n"
        );
        Ok(())
    }

    #[test]
    fn serialize_enum() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row {
            label: String,
            value: Value,
        }
        #[derive(Serialize)]
        enum Value {
            Integer(i64),
            Float(f64),
        }
        let mut wtr = MemWriter::default();
        wtr.serialize(Row {
            label: "foo".to_string(),
            value: Value::Integer(3),
        })?;
        wtr.serialize(Row {
            label: "bar".to_string(),
            value: Value::Float(3.14),
        })?;
        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(
            data,
            "label,value\n\
             foo,3\n\
             bar,3.14\n"
        );
        Ok(())
    }

    #[test]
    fn serialize_vec() -> Result<(), Box<dyn Error>> {
        #[derive(Serialize)]
        struct Row {
            label: String,
            values: Vec<f64>,
        }
        let mut wtr = MemWriter::new(&AsyncWriterBuilder::new().has_headers(false));
        wtr.serialize(Row {
            label: "foo".to_string(),
            values: vec![1.1234, 2.5678, 3.14],
        })?;
        let data = String::from_utf8(wtr.into_inner()?)?;
        assert_eq!(data, "foo,1.1234,2.5678,3.14\n");
        Ok(())
    }
}