mz_avro/
writer.rs

1// Copyright 2018 Flavien Raynaud.
2// Copyright Materialize, Inc. and contributors. All rights reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License in the LICENSE file at the
7// root of this repository, or online at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17// This file is derived from the avro-rs project, available at
18// https://github.com/flavray/avro-rs. It was incorporated
19// directly into Materialize on March 3, 2020.
20//
21// The original source code is subject to the terms of the MIT license, a copy
22// of which can be found in the LICENSE file at the root of this repository.
23
24//! Logic handling writing in Avro format at user level.
25
26use std::collections::BTreeMap;
27use std::fmt;
28use std::io::{Seek, SeekFrom, Write};
29
30use anyhow::Error;
31use rand::random;
32
33use crate::Codec;
34use crate::decode::AvroRead;
35use crate::encode::{encode, encode_ref, encode_to_vec};
36use crate::reader::Header;
37use crate::schema::{Schema, SchemaPiece};
38use crate::types::{ToAvro, Value};
39
/// Size, in bytes, of a sync marker.
const SYNC_SIZE: usize = 16;
/// Number of buffered bytes at which appending a value triggers an automatic flush; also the
/// initial capacity of the `Writer` buffer.
const SYNC_INTERVAL: usize = 1000 * SYNC_SIZE; // TODO: parametrize in Writer

/// Magic bytes placed at the very start of the file header: ASCII `Obj` followed by the byte 1.
const AVRO_OBJECT_HEADER: &[u8] = b"Obj\x01";
44
/// Error reported when Avro data fails validation.
#[derive(Debug)]
pub struct ValidationError(String);

impl ValidationError {
    /// Builds a `ValidationError` that carries `msg` as its description.
    pub fn new<S: Into<String>>(msg: S) -> ValidationError {
        Self(msg.into())
    }
}

impl fmt::Display for ValidationError {
    /// Renders the error as `Validation error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ValidationError(message) = self;
        write!(f, "Validation error: {}", message)
    }
}

impl std::error::Error for ValidationError {}
65
/// Main interface for writing Avro Object Container Files.
///
/// Values are validated, encoded and accumulated in an in-memory buffer; the buffer is written
/// out as one block (value count, byte length, payload, sync marker) on `flush`.
pub struct Writer<W> {
    /// Schema that every appended value is validated against and encoded with.
    schema: Schema,
    /// Destination that receives the file header and the data blocks.
    writer: W,
    /// Encoded values accumulated since the last flush.
    buffer: Vec<u8>,
    /// Number of values currently held in `buffer`.
    num_values: usize,
    /// Compression codec; `None` means no compression and no `avro.codec` header entry.
    codec: Option<Codec>,
    /// Sync marker appended to the header and written after every flushed block.
    marker: [u8; 16],
    /// Whether `writer` already contains the file header.
    has_header: bool,
}
76
77impl<W: Write> Writer<W> {
78    /// Creates a `Writer` for the `Schema` and something implementing the [`std::io::Write`]
79    /// trait to write to.
80    ///
81    /// This uses the no-compression [`Codec::Null`] when appending records.
82    pub fn new(schema: Schema, writer: W) -> Writer<W> {
83        Self::with_codec(schema, writer, Codec::Null)
84    }
85
86    /// Creates a `Writer` given a [`Schema`] and a specific compression [`Codec`]
87    pub fn with_codec(schema: Schema, writer: W, codec: Codec) -> Writer<W> {
88        Writer::with_codec_opt(schema, writer, Some(codec))
89    }
90
91    /// Create a `Writer` with the given parameters.
92    ///
93    /// All parameters have the same meaning as `with_codec`, but if `codec` is `None`
94    /// then no compression will be used and the `avro.codec` field in the header will be
95    /// omitted.
96    pub fn with_codec_opt(schema: Schema, writer: W, codec: Option<Codec>) -> Writer<W> {
97        let mut marker = [0; 16];
98        for i in 0..16 {
99            marker[i] = random::<u8>();
100        }
101
102        Writer {
103            schema,
104            writer,
105            buffer: Vec::with_capacity(SYNC_INTERVAL),
106            num_values: 0,
107            codec,
108            marker,
109            has_header: false,
110        }
111    }
112
113    /// Creates a `Writer` that appends to an existing OCF file.
114    pub fn append_to(mut file: W) -> Result<Writer<W>, Error>
115    where
116        W: AvroRead + Seek + Unpin + Send,
117    {
118        let header = Header::from_reader(&mut file)?;
119        let (schema, marker, codec) = header.into_parts();
120        file.seek(SeekFrom::End(0))?;
121        Ok(Writer {
122            schema,
123            writer: file,
124            buffer: Vec::with_capacity(SYNC_INTERVAL),
125            num_values: 0,
126            codec: Some(codec),
127            marker,
128            has_header: true,
129        })
130    }
131
132    /// Get a reference to the `Schema` associated to a `Writer`.
133    pub fn schema(&self) -> &Schema {
134        &self.schema
135    }
136
137    /// Append a compatible value (implementing the `ToAvro` trait) to a `Writer`, also performing
138    /// schema validation.
139    ///
140    /// Return the number of bytes written (it might be 0, see below).
141    ///
142    /// **NOTE** This function is not guaranteed to perform any actual write, since it relies on
143    /// internal buffering for performance reasons. If you want to be sure the value has been
144    /// written, then call [`flush`](struct.Writer.html#method.flush).
145    pub fn append<T: ToAvro>(&mut self, value: T) -> Result<usize, Error> {
146        let n = if !self.has_header {
147            let header = self.header()?;
148            let n = self.append_bytes(header.as_ref())?;
149            self.has_header = true;
150            n
151        } else {
152            0
153        };
154        let avro = value.avro();
155        write_value_ref(&self.schema, &avro, &mut self.buffer)?;
156
157        self.num_values += 1;
158
159        if self.buffer.len() >= SYNC_INTERVAL {
160            return self.flush().map(|b| b + n);
161        }
162
163        Ok(n)
164    }
165
166    /// Append a compatible value to a `Writer`, also performing schema validation.
167    ///
168    /// Return the number of bytes written (it might be 0, see below).
169    ///
170    /// **NOTE** This function is not guaranteed to perform any actual write, since it relies on
171    /// internal buffering for performance reasons. If you want to be sure the value has been
172    /// written, then call [`flush`](struct.Writer.html#method.flush).
173    pub fn append_value_ref(&mut self, value: &Value) -> Result<usize, Error> {
174        let n = if !self.has_header {
175            let header = self.header()?;
176            let n = self.append_bytes(header.as_ref())?;
177            self.has_header = true;
178            n
179        } else {
180            0
181        };
182
183        write_value_ref(&self.schema, value, &mut self.buffer)?;
184
185        self.num_values += 1;
186
187        if self.buffer.len() >= SYNC_INTERVAL {
188            return self.flush().map(|b| b + n);
189        }
190
191        Ok(n)
192    }
193
194    /// Extend a `Writer` with an `Iterator` of compatible values (implementing the `ToAvro`
195    /// trait), also performing schema validation.
196    ///
197    /// Return the number of bytes written.
198    ///
199    /// **NOTE** This function forces the written data to be flushed (an implicit
200    /// call to [`flush`](struct.Writer.html#method.flush) is performed).
201    pub fn extend<I, T: ToAvro>(&mut self, values: I) -> Result<usize, Error>
202    where
203        I: IntoIterator<Item = T>,
204    {
205        /*
206        https://github.com/rust-lang/rfcs/issues/811 :(
207        let mut stream = values
208            .filter_map(|value| value.serialize(&mut self.serializer).ok())
209            .map(|value| value.encode(self.schema))
210            .collect::<Option<Vec<_>>>()
211            .ok_or_else(|| err_msg("value does not match given schema"))?
212            .into_iter()
213            .fold(Vec::new(), |mut acc, stream| {
214                num_values += 1;
215                acc.extend(stream); acc
216            });
217        */
218
219        let mut num_bytes = 0;
220        for value in values {
221            num_bytes += self.append(value)?;
222        }
223        num_bytes += self.flush()?;
224
225        Ok(num_bytes)
226    }
227
228    /// Extend a `Writer` by appending each `Value` from a slice, while also performing schema
229    /// validation on each value appended.
230    ///
231    /// Return the number of bytes written.
232    ///
233    /// **NOTE** This function forces the written data to be flushed (an implicit
234    /// call to [`flush`](struct.Writer.html#method.flush) is performed).
235    pub fn extend_from_slice(&mut self, values: &[Value]) -> Result<usize, Error> {
236        let mut num_bytes = 0;
237        for value in values {
238            num_bytes += self.append_value_ref(value)?;
239        }
240        num_bytes += self.flush()?;
241
242        Ok(num_bytes)
243    }
244
245    /// Flush the content appended to a `Writer`. Call this function to make sure all the content
246    /// has been written before releasing the `Writer`.
247    ///
248    /// Return the number of bytes written.
249    pub fn flush(&mut self) -> Result<usize, Error> {
250        if self.num_values == 0 {
251            return Ok(0);
252        }
253
254        let compressor = self.codec.unwrap_or(Codec::Null);
255        compressor.compress(&mut self.buffer)?;
256
257        let num_values = self.num_values;
258        let stream_len = self.buffer.len();
259
260        let ls = Schema {
261            named: vec![],
262            indices: Default::default(),
263            top: SchemaPiece::Long.into(),
264        };
265
266        let num_bytes = self.append_raw(&num_values.avro(), &ls)?
267            + self.append_raw(&stream_len.avro(), &ls)?
268            + self.writer.write(self.buffer.as_ref())?
269            + self.append_marker()?;
270
271        self.buffer.clear();
272        self.num_values = 0;
273
274        Ok(num_bytes)
275    }
276
277    /// Return what the `Writer` is writing to, consuming the `Writer` itself.
278    ///
279    /// **NOTE** This function doesn't guarantee that everything gets written before consuming the
280    /// buffer. Please call [`flush`](struct.Writer.html#method.flush) before.
281    pub fn into_inner(self) -> W {
282        self.writer
283    }
284
285    /// Generate and append synchronization marker to the payload.
286    fn append_marker(&mut self) -> Result<usize, Error> {
287        // using .writer.write directly to avoid mutable borrow of self
288        // with ref borrowing of self.marker
289        Ok(self.writer.write(&self.marker)?)
290    }
291
292    /// Append a raw Avro Value to the payload avoiding to encode it again.
293    fn append_raw(&mut self, value: &Value, schema: &Schema) -> Result<usize, Error> {
294        self.append_bytes(encode_to_vec(value, schema).as_ref())
295    }
296
297    /// Append pure bytes to the payload.
298    fn append_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> {
299        Ok(self.writer.write(bytes)?)
300    }
301
302    /// Create an Avro header based on schema, codec and sync marker.
303    fn header(&self) -> Result<Vec<u8>, Error> {
304        let schema_bytes = serde_json::to_string(&self.schema)?.into_bytes();
305
306        let mut metadata = BTreeMap::new();
307        metadata.insert("avro.schema", Value::Bytes(schema_bytes));
308        if let Some(codec) = self.codec {
309            metadata.insert("avro.codec", codec.avro());
310        };
311
312        let mut header = Vec::new();
313        header.extend_from_slice(AVRO_OBJECT_HEADER);
314        encode(
315            &metadata.avro(),
316            &Schema {
317                named: vec![],
318                indices: Default::default(),
319                top: SchemaPiece::Map(Box::new(SchemaPiece::Bytes.into())).into(),
320            },
321            &mut header,
322        );
323        header.extend_from_slice(&self.marker);
324
325        Ok(header)
326    }
327}
328
329/// Encode a compatible value (implementing the `ToAvro` trait) into Avro format, also performing
330/// schema validation.
331///
332/// This is a function which gets the bytes buffer where to write as parameter instead of
333/// creating a new one like `to_avro_datum`.
334pub fn write_avro_datum<T: ToAvro>(
335    schema: &Schema,
336    value: T,
337    buffer: &mut Vec<u8>,
338) -> Result<(), Error> {
339    let avro = value.avro();
340    if !avro.validate(schema.top_node()) {
341        return Err(ValidationError::new("value does not match schema").into());
342    }
343    encode(&avro, schema, buffer);
344    Ok(())
345}
346
347fn write_value_ref(schema: &Schema, value: &Value, buffer: &mut Vec<u8>) -> Result<(), Error> {
348    if !value.validate(schema.top_node()) {
349        return Err(ValidationError::new("value does not match schema").into());
350    }
351    encode_ref(value, schema.top_node(), buffer);
352    Ok(())
353}
354
355/// Encode a compatible value (implementing the `ToAvro` trait) into Avro format, also
356/// performing schema validation.
357///
358/// **NOTE** This function has a quite small niche of usage and does NOT generate headers and sync
359/// markers; use [`Writer`](struct.Writer.html) to be fully Avro-compatible if you don't know what
360/// you are doing, instead.
361pub fn to_avro_datum<T: ToAvro>(schema: &Schema, value: T) -> Result<Vec<u8>, Error> {
362    let mut buffer = Vec::new();
363    write_avro_datum(schema, value, &mut buffer)?;
364    Ok(buffer)
365}
366
367#[cfg(test)]
368mod tests {
369    use std::io::Cursor;
370    use std::str::FromStr;
371
372    use crate::Reader;
373    use crate::types::Record;
374    use crate::util::zig_i64;
375
376    use super::*;
377
    // Record schema shared by the tests: a long field `a` (default 42) and a string field `b`.
    static SCHEMA: &str = r#"
            {
                "type": "record",
                "name": "test",
                "fields": [
                    {"name": "a", "type": "long", "default": 42},
                    {"name": "b", "type": "string"}
                ]
            }
        "#;
    // Union schema with two variants: `null` at index 0 and `long` at index 1.
    static UNION_SCHEMA: &str = r#"
            ["null", "long"]
        "#;
391
392    #[mz_ore::test]
393    fn test_to_avro_datum() {
394        let schema = Schema::from_str(SCHEMA).unwrap();
395        let mut record = Record::new(schema.top_node()).unwrap();
396        record.put("a", 27i64);
397        record.put("b", "foo");
398
399        let mut expected = Vec::new();
400        zig_i64(27, &mut expected);
401        zig_i64(3, &mut expected);
402        expected.extend(vec![b'f', b'o', b'o'].into_iter());
403
404        assert_eq!(to_avro_datum(&schema, record).unwrap(), expected);
405    }
406
407    #[mz_ore::test]
408    fn test_union() {
409        let schema = Schema::from_str(UNION_SCHEMA).unwrap();
410        let union = Value::Union {
411            index: 1,
412            inner: Box::new(Value::Long(3)),
413            n_variants: 2,
414            null_variant: Some(0),
415        };
416
417        let mut expected = Vec::new();
418        zig_i64(1, &mut expected);
419        zig_i64(3, &mut expected);
420
421        assert_eq!(to_avro_datum(&schema, union).unwrap(), expected);
422    }
423
424    #[mz_ore::test]
425    fn test_writer_append() {
426        let schema = Schema::from_str(SCHEMA).unwrap();
427        let mut writer = Writer::new(schema.clone(), Vec::new());
428
429        let mut record = Record::new(schema.top_node()).unwrap();
430        record.put("a", 27i64);
431        record.put("b", "foo");
432
433        let n1 = writer.append(record.clone()).unwrap();
434        let n2 = writer.append(record.clone()).unwrap();
435        let n3 = writer.flush().unwrap();
436        let result = writer.into_inner();
437
438        assert_eq!(n1 + n2 + n3, result.len());
439
440        let mut header = Vec::new();
441        header.extend(vec![b'O', b'b', b'j', b'\x01']);
442
443        let mut data = Vec::new();
444        zig_i64(27, &mut data);
445        zig_i64(3, &mut data);
446        data.extend(vec![b'f', b'o', b'o'].into_iter());
447        let data_copy = data.clone();
448        data.extend(data_copy);
449
450        // starts with magic
451        assert_eq!(
452            result
453                .iter()
454                .cloned()
455                .take(header.len())
456                .collect::<Vec<u8>>(),
457            header
458        );
459        // ends with data and sync marker
460        assert_eq!(
461            result
462                .iter()
463                .cloned()
464                .rev()
465                .skip(16)
466                .take(data.len())
467                .collect::<Vec<u8>>()
468                .into_iter()
469                .rev()
470                .collect::<Vec<u8>>(),
471            data
472        );
473    }
474
475    #[mz_ore::test]
476    fn test_writer_extend() {
477        let schema = Schema::from_str(SCHEMA).unwrap();
478        let mut writer = Writer::new(schema.clone(), Vec::new());
479
480        let mut record = Record::new(schema.top_node()).unwrap();
481        record.put("a", 27i64);
482        record.put("b", "foo");
483        let record_copy = record.clone();
484        let records = vec![record, record_copy];
485
486        let n1 = writer.extend(records.into_iter()).unwrap();
487        let n2 = writer.flush().unwrap();
488        let result = writer.into_inner();
489
490        assert_eq!(n1 + n2, result.len());
491
492        let mut header = Vec::new();
493        header.extend(vec![b'O', b'b', b'j', b'\x01']);
494
495        let mut data = Vec::new();
496        zig_i64(27, &mut data);
497        zig_i64(3, &mut data);
498        data.extend(vec![b'f', b'o', b'o'].into_iter());
499        let data_copy = data.clone();
500        data.extend(data_copy);
501
502        // starts with magic
503        assert_eq!(
504            result
505                .iter()
506                .cloned()
507                .take(header.len())
508                .collect::<Vec<u8>>(),
509            header
510        );
511        // ends with data and sync marker
512        assert_eq!(
513            result
514                .iter()
515                .cloned()
516                .rev()
517                .skip(16)
518                .take(data.len())
519                .collect::<Vec<u8>>()
520                .into_iter()
521                .rev()
522                .collect::<Vec<u8>>(),
523            data
524        );
525    }
526
527    #[mz_ore::test]
528    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `deflateInit2_` on OS `linux`
529    fn test_writer_with_codec() {
530        let schema = Schema::from_str(SCHEMA).unwrap();
531        let mut writer = Writer::with_codec(schema.clone(), Vec::new(), Codec::Deflate);
532
533        let mut record = Record::new(schema.top_node()).unwrap();
534        record.put("a", 27i64);
535        record.put("b", "foo");
536
537        let n1 = writer.append(record.clone()).unwrap();
538        let n2 = writer.append(record.clone()).unwrap();
539        let n3 = writer.flush().unwrap();
540        let result = writer.into_inner();
541
542        assert_eq!(n1 + n2 + n3, result.len());
543
544        let mut header = Vec::new();
545        header.extend(vec![b'O', b'b', b'j', b'\x01']);
546
547        let mut data = Vec::new();
548        zig_i64(27, &mut data);
549        zig_i64(3, &mut data);
550        data.extend(vec![b'f', b'o', b'o'].into_iter());
551        let data_copy = data.clone();
552        data.extend(data_copy);
553        Codec::Deflate.compress(&mut data).unwrap();
554
555        // starts with magic
556        assert_eq!(
557            result
558                .iter()
559                .cloned()
560                .take(header.len())
561                .collect::<Vec<u8>>(),
562            header
563        );
564        // ends with data and sync marker
565        assert_eq!(
566            result
567                .iter()
568                .cloned()
569                .rev()
570                .skip(16)
571                .take(data.len())
572                .collect::<Vec<u8>>()
573                .into_iter()
574                .rev()
575                .collect::<Vec<u8>>(),
576            data
577        );
578    }
579
580    #[mz_ore::test]
581    #[cfg_attr(miri, ignore)] // slow
582    fn test_writer_roundtrip() {
583        let schema = Schema::from_str(SCHEMA).unwrap();
584        let make_record = |a: i64, b| {
585            let mut record = Record::new(schema.top_node()).unwrap();
586            record.put("a", a);
587            record.put("b", b);
588            record.avro()
589        };
590
591        let mut buf = Vec::new();
592
593        // Write out a file with two blocks.
594        {
595            let mut writer = Writer::new(schema.clone(), &mut buf);
596            writer.append(make_record(27, "foo")).unwrap();
597            writer.flush().unwrap();
598            writer.append(make_record(54, "bar")).unwrap();
599            writer.flush().unwrap();
600        }
601
602        // Add another block from a new writer, part i.
603        {
604            let mut writer = Writer::append_to(Cursor::new(&mut buf)).unwrap();
605            writer.append(make_record(42, "baz")).unwrap();
606            writer.flush().unwrap();
607        }
608
609        // Add another block from a new writer, part ii.
610        {
611            let mut writer = Writer::append_to(Cursor::new(&mut buf)).unwrap();
612            writer.append(make_record(84, "zar")).unwrap();
613            writer.flush().unwrap();
614        }
615
616        // Ensure all four blocks appear in the file.
617        let reader = Reader::new(&buf[..]).unwrap();
618        let actual: Result<Vec<_>, _> = reader.collect();
619        let actual = actual.unwrap();
620        assert_eq!(
621            vec![
622                make_record(27, "foo"),
623                make_record(54, "bar"),
624                make_record(42, "baz"),
625                make_record(84, "zar")
626            ],
627            actual
628        );
629    }
630}