1use std::collections::BTreeMap;
27use std::fmt;
28use std::io::{Seek, SeekFrom, Write};
29
30use anyhow::Error;
31use rand::random;
32
33use crate::Codec;
34use crate::decode::AvroRead;
35use crate::encode::{encode, encode_ref, encode_to_vec};
36use crate::reader::Header;
37use crate::schema::{Schema, SchemaPiece};
38use crate::types::{ToAvro, Value};
39
/// Length, in bytes, of a sync marker (the `marker` field of `Writer` has this length).
const SYNC_SIZE: usize = 16;
/// Number of buffered bytes at which `Writer` flushes the current block on its own;
/// also used as the initial capacity of the block buffer.
const SYNC_INTERVAL: usize = SYNC_SIZE * 1000;
/// Magic bytes written at the very start of the file header: ASCII `Obj` followed by `1`.
const AVRO_OBJECT_HEADER: &[u8] = b"Obj\x01";
44
/// Error reported when a value cannot be written because it does not match the schema.
///
/// The wrapped `String` is the human-readable description of the problem.
#[derive(Debug)]
pub struct ValidationError(String);
48
49impl ValidationError {
50 pub fn new<S>(msg: S) -> ValidationError
51 where
52 S: Into<String>,
53 {
54 ValidationError(msg.into())
55 }
56}
57
58impl fmt::Display for ValidationError {
59 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
60 write!(f, "Validation error: {}", self.0)
61 }
62}
63
64impl std::error::Error for ValidationError {}
65
66pub struct Writer<W> {
68 schema: Schema,
69 writer: W,
70 buffer: Vec<u8>,
71 num_values: usize,
72 codec: Option<Codec>,
73 marker: [u8; 16],
74 has_header: bool,
75}
76
77impl<W: Write> Writer<W> {
78 pub fn new(schema: Schema, writer: W) -> Writer<W> {
83 Self::with_codec(schema, writer, Codec::Null)
84 }
85
86 pub fn with_codec(schema: Schema, writer: W, codec: Codec) -> Writer<W> {
88 Writer::with_codec_opt(schema, writer, Some(codec))
89 }
90
91 pub fn with_codec_opt(schema: Schema, writer: W, codec: Option<Codec>) -> Writer<W> {
97 let mut marker = [0; 16];
98 for i in 0..16 {
99 marker[i] = random::<u8>();
100 }
101
102 Writer {
103 schema,
104 writer,
105 buffer: Vec::with_capacity(SYNC_INTERVAL),
106 num_values: 0,
107 codec,
108 marker,
109 has_header: false,
110 }
111 }
112
113 pub fn append_to(mut file: W) -> Result<Writer<W>, Error>
115 where
116 W: AvroRead + Seek + Unpin + Send,
117 {
118 let header = Header::from_reader(&mut file)?;
119 let (schema, marker, codec) = header.into_parts();
120 file.seek(SeekFrom::End(0))?;
121 Ok(Writer {
122 schema,
123 writer: file,
124 buffer: Vec::with_capacity(SYNC_INTERVAL),
125 num_values: 0,
126 codec: Some(codec),
127 marker,
128 has_header: true,
129 })
130 }
131
132 pub fn schema(&self) -> &Schema {
134 &self.schema
135 }
136
137 pub fn append<T: ToAvro>(&mut self, value: T) -> Result<usize, Error> {
146 let n = if !self.has_header {
147 let header = self.header()?;
148 let n = self.append_bytes(header.as_ref())?;
149 self.has_header = true;
150 n
151 } else {
152 0
153 };
154 let avro = value.avro();
155 write_value_ref(&self.schema, &avro, &mut self.buffer)?;
156
157 self.num_values += 1;
158
159 if self.buffer.len() >= SYNC_INTERVAL {
160 return self.flush().map(|b| b + n);
161 }
162
163 Ok(n)
164 }
165
166 pub fn append_value_ref(&mut self, value: &Value) -> Result<usize, Error> {
174 let n = if !self.has_header {
175 let header = self.header()?;
176 let n = self.append_bytes(header.as_ref())?;
177 self.has_header = true;
178 n
179 } else {
180 0
181 };
182
183 write_value_ref(&self.schema, value, &mut self.buffer)?;
184
185 self.num_values += 1;
186
187 if self.buffer.len() >= SYNC_INTERVAL {
188 return self.flush().map(|b| b + n);
189 }
190
191 Ok(n)
192 }
193
194 pub fn extend<I, T: ToAvro>(&mut self, values: I) -> Result<usize, Error>
202 where
203 I: IntoIterator<Item = T>,
204 {
205 let mut num_bytes = 0;
220 for value in values {
221 num_bytes += self.append(value)?;
222 }
223 num_bytes += self.flush()?;
224
225 Ok(num_bytes)
226 }
227
228 pub fn extend_from_slice(&mut self, values: &[Value]) -> Result<usize, Error> {
236 let mut num_bytes = 0;
237 for value in values {
238 num_bytes += self.append_value_ref(value)?;
239 }
240 num_bytes += self.flush()?;
241
242 Ok(num_bytes)
243 }
244
245 pub fn flush(&mut self) -> Result<usize, Error> {
250 if self.num_values == 0 {
251 return Ok(0);
252 }
253
254 let compressor = self.codec.unwrap_or(Codec::Null);
255 compressor.compress(&mut self.buffer)?;
256
257 let num_values = self.num_values;
258 let stream_len = self.buffer.len();
259
260 let ls = Schema {
261 named: vec![],
262 indices: Default::default(),
263 top: SchemaPiece::Long.into(),
264 };
265
266 let num_bytes = self.append_raw(&num_values.avro(), &ls)?
267 + self.append_raw(&stream_len.avro(), &ls)?
268 + self.writer.write(self.buffer.as_ref())?
269 + self.append_marker()?;
270
271 self.buffer.clear();
272 self.num_values = 0;
273
274 Ok(num_bytes)
275 }
276
277 pub fn into_inner(self) -> W {
282 self.writer
283 }
284
285 fn append_marker(&mut self) -> Result<usize, Error> {
287 Ok(self.writer.write(&self.marker)?)
290 }
291
292 fn append_raw(&mut self, value: &Value, schema: &Schema) -> Result<usize, Error> {
294 self.append_bytes(encode_to_vec(value, schema).as_ref())
295 }
296
297 fn append_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> {
299 Ok(self.writer.write(bytes)?)
300 }
301
302 fn header(&self) -> Result<Vec<u8>, Error> {
304 let schema_bytes = serde_json::to_string(&self.schema)?.into_bytes();
305
306 let mut metadata = BTreeMap::new();
307 metadata.insert("avro.schema", Value::Bytes(schema_bytes));
308 if let Some(codec) = self.codec {
309 metadata.insert("avro.codec", codec.avro());
310 };
311
312 let mut header = Vec::new();
313 header.extend_from_slice(AVRO_OBJECT_HEADER);
314 encode(
315 &metadata.avro(),
316 &Schema {
317 named: vec![],
318 indices: Default::default(),
319 top: SchemaPiece::Map(Box::new(SchemaPiece::Bytes.into())).into(),
320 },
321 &mut header,
322 );
323 header.extend_from_slice(&self.marker);
324
325 Ok(header)
326 }
327}
328
329pub fn write_avro_datum<T: ToAvro>(
335 schema: &Schema,
336 value: T,
337 buffer: &mut Vec<u8>,
338) -> Result<(), Error> {
339 let avro = value.avro();
340 if !avro.validate(schema.top_node()) {
341 return Err(ValidationError::new("value does not match schema").into());
342 }
343 encode(&avro, schema, buffer);
344 Ok(())
345}
346
347fn write_value_ref(schema: &Schema, value: &Value, buffer: &mut Vec<u8>) -> Result<(), Error> {
348 if !value.validate(schema.top_node()) {
349 return Err(ValidationError::new("value does not match schema").into());
350 }
351 encode_ref(value, schema.top_node(), buffer);
352 Ok(())
353}
354
355pub fn to_avro_datum<T: ToAvro>(schema: &Schema, value: T) -> Result<Vec<u8>, Error> {
362 let mut buffer = Vec::new();
363 write_avro_datum(schema, value, &mut buffer)?;
364 Ok(buffer)
365}
366
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::str::FromStr;

    use serde::{Deserialize, Serialize};

    use crate::Reader;
    use crate::types::Record;
    use crate::util::zig_i64;

    use super::*;

    /// Record schema used by most tests: a long field `a` and a string field `b`.
    static SCHEMA: &str = r#"
        {
            "type": "record",
            "name": "test",
            "fields": [
                {"name": "a", "type": "long", "default": 42},
                {"name": "b", "type": "string"}
            ]
        }
    "#;
    /// Nullable-long union schema used by `test_union`.
    static UNION_SCHEMA: &str = r#"
        ["null", "long"]
    "#;

    /// Serde counterpart of the `SCHEMA` record layout; not referenced by the
    /// tests in this module.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    struct TestSerdeSerialize {
        a: i64,
        b: String,
    }

    /// Returns the expected encoding of `copies` consecutive `{a: 27, b: "foo"}`
    /// records: zig-zag `27`, zig-zag string length `3`, then the bytes `foo`.
    fn encoded_foo_records(copies: usize) -> Vec<u8> {
        let mut single = Vec::new();
        zig_i64(27, &mut single);
        zig_i64(3, &mut single);
        single.extend_from_slice(b"foo");
        single.repeat(copies)
    }

    /// Asserts that `file` begins with the `Obj\x01` magic bytes and that `block`
    /// is located immediately before the trailing 16-byte sync marker.
    #[track_caller]
    fn assert_framing(file: &[u8], block: &[u8]) {
        let magic = [b'O', b'b', b'j', b'\x01'];
        let prefix_len = magic.len().min(file.len());
        assert_eq!(&file[..prefix_len], &magic[..]);

        // Clamp the window so a too-short file fails the comparison instead of
        // panicking on an out-of-range slice.
        let block_end = file.len().saturating_sub(16);
        let block_start = block_end.saturating_sub(block.len());
        assert_eq!(&file[block_start..block_end], block);
    }

    #[mz_ore::test]
    fn test_to_avro_datum() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        let expected = encoded_foo_records(1);

        assert_eq!(to_avro_datum(&schema, record).unwrap(), expected);
    }

    #[mz_ore::test]
    fn test_union() {
        let schema = Schema::from_str(UNION_SCHEMA).unwrap();
        let union = Value::Union {
            index: 1,
            inner: Box::new(Value::Long(3)),
            n_variants: 2,
            null_variant: Some(0),
        };

        // Variant index first, then the long payload.
        let mut expected = Vec::new();
        zig_i64(1, &mut expected);
        zig_i64(3, &mut expected);

        assert_eq!(to_avro_datum(&schema, union).unwrap(), expected);
    }

    #[mz_ore::test]
    fn test_writer_append() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut writer = Writer::new(schema.clone(), Vec::new());

        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        let n1 = writer.append(record.clone()).unwrap();
        let n2 = writer.append(record).unwrap();
        let n3 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2 + n3, result.len());
        assert_framing(&result, &encoded_foo_records(2));
    }

    #[mz_ore::test]
    fn test_writer_extend() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut writer = Writer::new(schema.clone(), Vec::new());

        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");
        let records = vec![record.clone(), record];

        let n1 = writer.extend(records).unwrap();
        let n2 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2, result.len());
        assert_framing(&result, &encoded_foo_records(2));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_writer_with_codec() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut writer = Writer::with_codec(schema.clone(), Vec::new(), Codec::Deflate);

        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        let n1 = writer.append(record.clone()).unwrap();
        let n2 = writer.append(record).unwrap();
        let n3 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2 + n3, result.len());

        let mut data = encoded_foo_records(2);
        Codec::Deflate.compress(&mut data).unwrap();

        assert_framing(&result, &data);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_writer_roundtrip() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let make_record = |a: i64, b| {
            let mut record = Record::new(schema.top_node()).unwrap();
            record.put("a", a);
            record.put("b", b);
            record.avro()
        };

        let mut buf = Vec::new();

        // Create the file and write two blocks.
        {
            let mut writer = Writer::new(schema.clone(), &mut buf);
            writer.append(make_record(27, "foo")).unwrap();
            writer.flush().unwrap();
            writer.append(make_record(54, "bar")).unwrap();
            writer.flush().unwrap();
        }

        // Reopen the file twice, adding one more block each time.
        for (a, b) in [(42, "baz"), (84, "zar")] {
            let mut writer = Writer::append_to(Cursor::new(&mut buf)).unwrap();
            writer.append(make_record(a, b)).unwrap();
            writer.flush().unwrap();
        }

        // Everything written across the three sessions must read back in order.
        let reader = Reader::new(&buf[..]).unwrap();
        let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(
            vec![
                make_record(27, "foo"),
                make_record(54, "bar"),
                make_record(42, "baz"),
                make_record(84, "zar")
            ],
            actual
        );
    }
}