1use std::collections::BTreeMap;
27use std::fmt;
28use std::io::{Seek, SeekFrom, Write};
29
30use anyhow::Error;
31use rand::random;
32
33use crate::Codec;
34use crate::decode::AvroRead;
35use crate::encode::{encode, encode_ref, encode_to_vec};
36use crate::reader::Header;
37use crate::schema::{Schema, SchemaPiece};
38use crate::types::{ToAvro, Value};
39
/// Size in bytes of the Avro sync marker (the writer's marker arrays are also 16 bytes).
const SYNC_SIZE: usize = 16;
/// Buffered-byte threshold at which the writer emits the pending data block on its own.
const SYNC_INTERVAL: usize = SYNC_SIZE * 1000;
/// Magic bytes that open every Avro object container file: `Obj` followed by version `1`.
const AVRO_OBJECT_HEADER: &[u8] = b"Obj\x01";
44
/// Error produced when a value does not conform to the schema it is written with.
#[derive(Debug)]
pub struct ValidationError(String);

impl ValidationError {
    /// Builds a validation error carrying `msg` as its description.
    pub fn new<S: Into<String>>(msg: S) -> ValidationError {
        Self(msg.into())
    }
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let ValidationError(msg) = self;
        write!(f, "Validation error: {}", msg)
    }
}

impl std::error::Error for ValidationError {}
65
66pub struct Writer<W> {
68 schema: Schema,
69 writer: W,
70 buffer: Vec<u8>,
71 num_values: usize,
72 codec: Option<Codec>,
73 marker: [u8; 16],
74 has_header: bool,
75}
76
77impl<W: Write> Writer<W> {
78 pub fn new(schema: Schema, writer: W) -> Writer<W> {
83 Self::with_codec(schema, writer, Codec::Null)
84 }
85
86 pub fn with_codec(schema: Schema, writer: W, codec: Codec) -> Writer<W> {
88 Writer::with_codec_opt(schema, writer, Some(codec))
89 }
90
91 pub fn with_codec_opt(schema: Schema, writer: W, codec: Option<Codec>) -> Writer<W> {
97 let mut marker = [0; 16];
98 for i in 0..16 {
99 marker[i] = random::<u8>();
100 }
101
102 Writer {
103 schema,
104 writer,
105 buffer: Vec::with_capacity(SYNC_INTERVAL),
106 num_values: 0,
107 codec,
108 marker,
109 has_header: false,
110 }
111 }
112
113 pub fn append_to(mut file: W) -> Result<Writer<W>, Error>
115 where
116 W: AvroRead + Seek + Unpin + Send,
117 {
118 let header = Header::from_reader(&mut file)?;
119 let (schema, marker, codec) = header.into_parts();
120 file.seek(SeekFrom::End(0))?;
121 Ok(Writer {
122 schema,
123 writer: file,
124 buffer: Vec::with_capacity(SYNC_INTERVAL),
125 num_values: 0,
126 codec: Some(codec),
127 marker,
128 has_header: true,
129 })
130 }
131
132 pub fn schema(&self) -> &Schema {
134 &self.schema
135 }
136
137 pub fn append<T: ToAvro>(&mut self, value: T) -> Result<usize, Error> {
146 let n = if !self.has_header {
147 let header = self.header()?;
148 let n = self.append_bytes(header.as_ref())?;
149 self.has_header = true;
150 n
151 } else {
152 0
153 };
154 let avro = value.avro();
155 write_value_ref(&self.schema, &avro, &mut self.buffer)?;
156
157 self.num_values += 1;
158
159 if self.buffer.len() >= SYNC_INTERVAL {
160 return self.flush().map(|b| b + n);
161 }
162
163 Ok(n)
164 }
165
166 pub fn append_value_ref(&mut self, value: &Value) -> Result<usize, Error> {
174 let n = if !self.has_header {
175 let header = self.header()?;
176 let n = self.append_bytes(header.as_ref())?;
177 self.has_header = true;
178 n
179 } else {
180 0
181 };
182
183 write_value_ref(&self.schema, value, &mut self.buffer)?;
184
185 self.num_values += 1;
186
187 if self.buffer.len() >= SYNC_INTERVAL {
188 return self.flush().map(|b| b + n);
189 }
190
191 Ok(n)
192 }
193
194 pub fn extend<I, T: ToAvro>(&mut self, values: I) -> Result<usize, Error>
202 where
203 I: IntoIterator<Item = T>,
204 {
205 let mut num_bytes = 0;
220 for value in values {
221 num_bytes += self.append(value)?;
222 }
223 num_bytes += self.flush()?;
224
225 Ok(num_bytes)
226 }
227
228 pub fn extend_from_slice(&mut self, values: &[Value]) -> Result<usize, Error> {
236 let mut num_bytes = 0;
237 for value in values {
238 num_bytes += self.append_value_ref(value)?;
239 }
240 num_bytes += self.flush()?;
241
242 Ok(num_bytes)
243 }
244
245 pub fn flush(&mut self) -> Result<usize, Error> {
250 if self.num_values == 0 {
251 return Ok(0);
252 }
253
254 let compressor = self.codec.unwrap_or(Codec::Null);
255 compressor.compress(&mut self.buffer)?;
256
257 let num_values = self.num_values;
258 let stream_len = self.buffer.len();
259
260 let ls = Schema {
261 named: vec![],
262 indices: Default::default(),
263 top: SchemaPiece::Long.into(),
264 };
265
266 let num_bytes = self.append_raw(&num_values.avro(), &ls)?
267 + self.append_raw(&stream_len.avro(), &ls)?
268 + self.writer.write(self.buffer.as_ref())?
269 + self.append_marker()?;
270
271 self.buffer.clear();
272 self.num_values = 0;
273
274 Ok(num_bytes)
275 }
276
277 pub fn into_inner(self) -> W {
282 self.writer
283 }
284
285 fn append_marker(&mut self) -> Result<usize, Error> {
287 Ok(self.writer.write(&self.marker)?)
290 }
291
292 fn append_raw(&mut self, value: &Value, schema: &Schema) -> Result<usize, Error> {
294 self.append_bytes(encode_to_vec(value, schema).as_ref())
295 }
296
297 fn append_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> {
299 Ok(self.writer.write(bytes)?)
300 }
301
302 fn header(&self) -> Result<Vec<u8>, Error> {
304 let schema_bytes = serde_json::to_string(&self.schema)?.into_bytes();
305
306 let mut metadata = BTreeMap::new();
307 metadata.insert("avro.schema", Value::Bytes(schema_bytes));
308 if let Some(codec) = self.codec {
309 metadata.insert("avro.codec", codec.avro());
310 };
311
312 let mut header = Vec::new();
313 header.extend_from_slice(AVRO_OBJECT_HEADER);
314 encode(
315 &metadata.avro(),
316 &Schema {
317 named: vec![],
318 indices: Default::default(),
319 top: SchemaPiece::Map(Box::new(SchemaPiece::Bytes.into())).into(),
320 },
321 &mut header,
322 );
323 header.extend_from_slice(&self.marker);
324
325 Ok(header)
326 }
327}
328
329pub fn write_avro_datum<T: ToAvro>(
335 schema: &Schema,
336 value: T,
337 buffer: &mut Vec<u8>,
338) -> Result<(), Error> {
339 let avro = value.avro();
340 if !avro.validate(schema.top_node()) {
341 return Err(ValidationError::new("value does not match schema").into());
342 }
343 encode(&avro, schema, buffer);
344 Ok(())
345}
346
347fn write_value_ref(schema: &Schema, value: &Value, buffer: &mut Vec<u8>) -> Result<(), Error> {
348 if !value.validate(schema.top_node()) {
349 return Err(ValidationError::new("value does not match schema").into());
350 }
351 encode_ref(value, schema.top_node(), buffer);
352 Ok(())
353}
354
355pub fn to_avro_datum<T: ToAvro>(schema: &Schema, value: T) -> Result<Vec<u8>, Error> {
362 let mut buffer = Vec::new();
363 write_avro_datum(schema, value, &mut buffer)?;
364 Ok(buffer)
365}
366
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::str::FromStr;

    use crate::Reader;
    use crate::types::Record;
    use crate::util::zig_i64;

    use super::*;

    static SCHEMA: &str = r#"
        {
            "type": "record",
            "name": "test",
            "fields": [
                {"name": "a", "type": "long", "default": 42},
                {"name": "b", "type": "string"}
            ]
        }
    "#;
    static UNION_SCHEMA: &str = r#"
        ["null", "long"]
    "#;

    /// Length of the sync marker that ends every block.
    const MARKER_LEN: usize = 16;

    /// Binary encoding of the record `{a: 27, b: "foo"}` under `SCHEMA`.
    fn encoded_record() -> Vec<u8> {
        let mut bytes = Vec::new();
        zig_i64(27, &mut bytes);
        zig_i64(3, &mut bytes);
        bytes.extend_from_slice(b"foo");
        bytes
    }

    /// Up to `len` leading bytes of `output`.
    fn leading(output: &[u8], len: usize) -> &[u8] {
        &output[..len.min(output.len())]
    }

    /// Up to `len` bytes immediately preceding the trailing sync marker of `output`.
    fn before_marker(output: &[u8], len: usize) -> &[u8] {
        let end = output.len().saturating_sub(MARKER_LEN);
        &output[end.saturating_sub(len)..end]
    }

    /// Checks the container magic at the start and the last block's payload at the end.
    fn assert_layout(output: &[u8], payload: &[u8]) {
        let magic = [b'O', b'b', b'j', b'\x01'];
        assert_eq!(leading(output, magic.len()), &magic[..]);
        assert_eq!(before_marker(output, payload.len()), payload);
    }

    #[mz_ore::test]
    fn test_to_avro_datum() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        assert_eq!(to_avro_datum(&schema, record).unwrap(), encoded_record());
    }

    #[mz_ore::test]
    fn test_union() {
        let schema = Schema::from_str(UNION_SCHEMA).unwrap();
        let union = Value::Union {
            index: 1,
            inner: Box::new(Value::Long(3)),
            n_variants: 2,
            null_variant: Some(0),
        };

        let mut expected = Vec::new();
        zig_i64(1, &mut expected);
        zig_i64(3, &mut expected);

        assert_eq!(to_avro_datum(&schema, union).unwrap(), expected);
    }

    #[mz_ore::test]
    fn test_writer_append() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut writer = Writer::new(schema.clone(), Vec::new());

        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        let n1 = writer.append(record.clone()).unwrap();
        let n2 = writer.append(record).unwrap();
        let n3 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2 + n3, result.len());
        assert_layout(&result, &encoded_record().repeat(2));
    }

    #[mz_ore::test]
    fn test_writer_extend() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut writer = Writer::new(schema.clone(), Vec::new());

        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");
        let records = vec![record.clone(), record];

        let n1 = writer.extend(records).unwrap();
        let n2 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2, result.len());
        assert_layout(&result, &encoded_record().repeat(2));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_writer_with_codec() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let mut writer = Writer::with_codec(schema.clone(), Vec::new(), Codec::Deflate);

        let mut record = Record::new(schema.top_node()).unwrap();
        record.put("a", 27i64);
        record.put("b", "foo");

        let n1 = writer.append(record.clone()).unwrap();
        let n2 = writer.append(record).unwrap();
        let n3 = writer.flush().unwrap();
        let result = writer.into_inner();

        assert_eq!(n1 + n2 + n3, result.len());

        let mut payload = encoded_record().repeat(2);
        Codec::Deflate.compress(&mut payload).unwrap();
        assert_layout(&result, &payload);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn test_writer_roundtrip() {
        let schema = Schema::from_str(SCHEMA).unwrap();
        let make_record = |a: i64, b| {
            let mut record = Record::new(schema.top_node()).unwrap();
            record.put("a", a);
            record.put("b", b);
            record.avro()
        };

        let mut buf = Vec::new();

        {
            // Fresh file: two records, each flushed into its own block.
            let mut writer = Writer::new(schema.clone(), &mut buf);
            for &(a, b) in &[(27, "foo"), (54, "bar")] {
                writer.append(make_record(a, b)).unwrap();
                writer.flush().unwrap();
            }
        }

        // Reopen the existing file twice, appending one more block each time.
        for &(a, b) in &[(42, "baz"), (84, "zar")] {
            let mut writer = Writer::append_to(Cursor::new(&mut buf)).unwrap();
            writer.append(make_record(a, b)).unwrap();
            writer.flush().unwrap();
        }

        let reader = Reader::new(&buf[..]).unwrap();
        let actual: Result<Vec<_>, _> = reader.collect();
        let actual = actual.unwrap();
        assert_eq!(
            vec![
                make_record(27, "foo"),
                make_record(54, "bar"),
                make_record(42, "baz"),
                make_record(84, "zar")
            ],
            actual
        );
    }
}