mz_avro/
writer.rs

1// Copyright 2018 Flavien Raynaud.
2// Copyright Materialize, Inc. and contributors. All rights reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License in the LICENSE file at the
7// root of this repository, or online at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17// This file is derived from the avro-rs project, available at
18// https://github.com/flavray/avro-rs. It was incorporated
19// directly into Materialize on March 3, 2020.
20//
21// The original source code is subject to the terms of the MIT license, a copy
22// of which can be found in the LICENSE file at the root of this repository.
23
24//! Logic handling writing in Avro format at user level.
25
26use std::collections::BTreeMap;
27use std::fmt;
28use std::io::{Seek, SeekFrom, Write};
29
30use anyhow::Error;
31use rand::random;
32
33use crate::Codec;
34use crate::decode::AvroRead;
35use crate::encode::{encode, encode_ref, encode_to_vec};
36use crate::reader::Header;
37use crate::schema::{Schema, SchemaPiece};
38use crate::types::{ToAvro, Value};
39
/// Length, in bytes, of the synchronization marker.
const SYNC_SIZE: usize = 16;
/// Buffered-byte threshold at which an append triggers a flush.
const SYNC_INTERVAL: usize = SYNC_SIZE * 1000; // TODO: parametrize in Writer

/// Magic bytes that open every header produced by this module: `Obj` followed by `1`.
const AVRO_OBJECT_HEADER: &[u8] = b"Obj\x01";
44
/// Error raised when a value fails validation against an Avro schema.
///
/// Wraps a human-readable message; its `Display` output is that message
/// prefixed with `Validation error: `.
#[derive(Debug)]
pub struct ValidationError(String);

impl ValidationError {
    /// Builds a `ValidationError` from anything convertible into a `String`.
    pub fn new<S: Into<String>>(msg: S) -> ValidationError {
        Self(msg.into())
    }
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self(message) = self;
        f.write_str("Validation error: ")?;
        f.write_str(message)
    }
}

impl std::error::Error for ValidationError {}
65
/// Main interface for writing Avro Object Container Files.
///
/// Values are validated against `schema`, encoded into an in-memory buffer,
/// and emitted to `writer` one block at a time when the buffer is flushed.
pub struct Writer<W> {
    /// Schema every appended value is validated against; also serialized
    /// into the file header.
    schema: Schema,
    /// Destination that the header and data blocks are written to.
    writer: W,
    /// Encoded values that have not yet been flushed as a block.
    buffer: Vec<u8>,
    /// Number of values currently held in `buffer`.
    num_values: usize,
    /// Compression codec. `None` omits `avro.codec` from the header and
    /// flushes blocks as with `Codec::Null`.
    codec: Option<Codec>,
    /// Sync marker written at the end of the header and after every block.
    marker: [u8; 16],
    /// Whether a header is already present in the output, so it must not be
    /// written again.
    has_header: bool,
}
76
77impl<W: Write> Writer<W> {
78    /// Creates a `Writer` for the `Schema` and something implementing the [`std::io::Write`]
79    /// trait to write to.
80    ///
81    /// This uses the no-compression [`Codec::Null`] when appending records.
82    pub fn new(schema: Schema, writer: W) -> Writer<W> {
83        Self::with_codec(schema, writer, Codec::Null)
84    }
85
86    /// Creates a `Writer` given a [`Schema`] and a specific compression [`Codec`]
87    pub fn with_codec(schema: Schema, writer: W, codec: Codec) -> Writer<W> {
88        Writer::with_codec_opt(schema, writer, Some(codec))
89    }
90
91    /// Create a `Writer` with the given parameters.
92    ///
93    /// All parameters have the same meaning as `with_codec`, but if `codec` is `None`
94    /// then no compression will be used and the `avro.codec` field in the header will be
95    /// omitted.
96    pub fn with_codec_opt(schema: Schema, writer: W, codec: Option<Codec>) -> Writer<W> {
97        let mut marker = [0; 16];
98        for i in 0..16 {
99            marker[i] = random::<u8>();
100        }
101
102        Writer {
103            schema,
104            writer,
105            buffer: Vec::with_capacity(SYNC_INTERVAL),
106            num_values: 0,
107            codec,
108            marker,
109            has_header: false,
110        }
111    }
112
113    /// Creates a `Writer` that appends to an existing OCF file.
114    pub fn append_to(mut file: W) -> Result<Writer<W>, Error>
115    where
116        W: AvroRead + Seek + Unpin + Send,
117    {
118        let header = Header::from_reader(&mut file)?;
119        let (schema, marker, codec) = header.into_parts();
120        file.seek(SeekFrom::End(0))?;
121        Ok(Writer {
122            schema,
123            writer: file,
124            buffer: Vec::with_capacity(SYNC_INTERVAL),
125            num_values: 0,
126            codec: Some(codec),
127            marker,
128            has_header: true,
129        })
130    }
131
132    /// Get a reference to the `Schema` associated to a `Writer`.
133    pub fn schema(&self) -> &Schema {
134        &self.schema
135    }
136
137    /// Append a compatible value (implementing the `ToAvro` trait) to a `Writer`, also performing
138    /// schema validation.
139    ///
140    /// Return the number of bytes written (it might be 0, see below).
141    ///
142    /// **NOTE** This function is not guaranteed to perform any actual write, since it relies on
143    /// internal buffering for performance reasons. If you want to be sure the value has been
144    /// written, then call [`flush`](struct.Writer.html#method.flush).
145    pub fn append<T: ToAvro>(&mut self, value: T) -> Result<usize, Error> {
146        let n = if !self.has_header {
147            let header = self.header()?;
148            let n = self.append_bytes(header.as_ref())?;
149            self.has_header = true;
150            n
151        } else {
152            0
153        };
154        let avro = value.avro();
155        write_value_ref(&self.schema, &avro, &mut self.buffer)?;
156
157        self.num_values += 1;
158
159        if self.buffer.len() >= SYNC_INTERVAL {
160            return self.flush().map(|b| b + n);
161        }
162
163        Ok(n)
164    }
165
166    /// Append a compatible value to a `Writer`, also performing schema validation.
167    ///
168    /// Return the number of bytes written (it might be 0, see below).
169    ///
170    /// **NOTE** This function is not guaranteed to perform any actual write, since it relies on
171    /// internal buffering for performance reasons. If you want to be sure the value has been
172    /// written, then call [`flush`](struct.Writer.html#method.flush).
173    pub fn append_value_ref(&mut self, value: &Value) -> Result<usize, Error> {
174        let n = if !self.has_header {
175            let header = self.header()?;
176            let n = self.append_bytes(header.as_ref())?;
177            self.has_header = true;
178            n
179        } else {
180            0
181        };
182
183        write_value_ref(&self.schema, value, &mut self.buffer)?;
184
185        self.num_values += 1;
186
187        if self.buffer.len() >= SYNC_INTERVAL {
188            return self.flush().map(|b| b + n);
189        }
190
191        Ok(n)
192    }
193
194    /// Extend a `Writer` with an `Iterator` of compatible values (implementing the `ToAvro`
195    /// trait), also performing schema validation.
196    ///
197    /// Return the number of bytes written.
198    ///
199    /// **NOTE** This function forces the written data to be flushed (an implicit
200    /// call to [`flush`](struct.Writer.html#method.flush) is performed).
201    pub fn extend<I, T: ToAvro>(&mut self, values: I) -> Result<usize, Error>
202    where
203        I: IntoIterator<Item = T>,
204    {
205        /*
206        https://github.com/rust-lang/rfcs/issues/811 :(
207        let mut stream = values
208            .filter_map(|value| value.serialize(&mut self.serializer).ok())
209            .map(|value| value.encode(self.schema))
210            .collect::<Option<Vec<_>>>()
211            .ok_or_else(|| err_msg("value does not match given schema"))?
212            .into_iter()
213            .fold(Vec::new(), |mut acc, stream| {
214                num_values += 1;
215                acc.extend(stream); acc
216            });
217        */
218
219        let mut num_bytes = 0;
220        for value in values {
221            num_bytes += self.append(value)?;
222        }
223        num_bytes += self.flush()?;
224
225        Ok(num_bytes)
226    }
227
228    /// Extend a `Writer` by appending each `Value` from a slice, while also performing schema
229    /// validation on each value appended.
230    ///
231    /// Return the number of bytes written.
232    ///
233    /// **NOTE** This function forces the written data to be flushed (an implicit
234    /// call to [`flush`](struct.Writer.html#method.flush) is performed).
235    pub fn extend_from_slice(&mut self, values: &[Value]) -> Result<usize, Error> {
236        let mut num_bytes = 0;
237        for value in values {
238            num_bytes += self.append_value_ref(value)?;
239        }
240        num_bytes += self.flush()?;
241
242        Ok(num_bytes)
243    }
244
245    /// Flush the content appended to a `Writer`. Call this function to make sure all the content
246    /// has been written before releasing the `Writer`.
247    ///
248    /// Return the number of bytes written.
249    pub fn flush(&mut self) -> Result<usize, Error> {
250        if self.num_values == 0 {
251            return Ok(0);
252        }
253
254        let compressor = self.codec.unwrap_or(Codec::Null);
255        compressor.compress(&mut self.buffer)?;
256
257        let num_values = self.num_values;
258        let stream_len = self.buffer.len();
259
260        let ls = Schema {
261            named: vec![],
262            indices: Default::default(),
263            top: SchemaPiece::Long.into(),
264        };
265
266        let num_bytes = self.append_raw(&num_values.avro(), &ls)?
267            + self.append_raw(&stream_len.avro(), &ls)?
268            + self.writer.write(self.buffer.as_ref())?
269            + self.append_marker()?;
270
271        self.buffer.clear();
272        self.num_values = 0;
273
274        Ok(num_bytes)
275    }
276
277    /// Return what the `Writer` is writing to, consuming the `Writer` itself.
278    ///
279    /// **NOTE** This function doesn't guarantee that everything gets written before consuming the
280    /// buffer. Please call [`flush`](struct.Writer.html#method.flush) before.
281    pub fn into_inner(self) -> W {
282        self.writer
283    }
284
285    /// Generate and append synchronization marker to the payload.
286    fn append_marker(&mut self) -> Result<usize, Error> {
287        // using .writer.write directly to avoid mutable borrow of self
288        // with ref borrowing of self.marker
289        Ok(self.writer.write(&self.marker)?)
290    }
291
292    /// Append a raw Avro Value to the payload avoiding to encode it again.
293    fn append_raw(&mut self, value: &Value, schema: &Schema) -> Result<usize, Error> {
294        self.append_bytes(encode_to_vec(value, schema).as_ref())
295    }
296
297    /// Append pure bytes to the payload.
298    fn append_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> {
299        Ok(self.writer.write(bytes)?)
300    }
301
302    /// Create an Avro header based on schema, codec and sync marker.
303    fn header(&self) -> Result<Vec<u8>, Error> {
304        let schema_bytes = serde_json::to_string(&self.schema)?.into_bytes();
305
306        let mut metadata = BTreeMap::new();
307        metadata.insert("avro.schema", Value::Bytes(schema_bytes));
308        if let Some(codec) = self.codec {
309            metadata.insert("avro.codec", codec.avro());
310        };
311
312        let mut header = Vec::new();
313        header.extend_from_slice(AVRO_OBJECT_HEADER);
314        encode(
315            &metadata.avro(),
316            &Schema {
317                named: vec![],
318                indices: Default::default(),
319                top: SchemaPiece::Map(Box::new(SchemaPiece::Bytes.into())).into(),
320            },
321            &mut header,
322        );
323        header.extend_from_slice(&self.marker);
324
325        Ok(header)
326    }
327}
328
329/// Encode a compatible value (implementing the `ToAvro` trait) into Avro format, also performing
330/// schema validation.
331///
332/// This is a function which gets the bytes buffer where to write as parameter instead of
333/// creating a new one like `to_avro_datum`.
334pub fn write_avro_datum<T: ToAvro>(
335    schema: &Schema,
336    value: T,
337    buffer: &mut Vec<u8>,
338) -> Result<(), Error> {
339    let avro = value.avro();
340    if !avro.validate(schema.top_node()) {
341        return Err(ValidationError::new("value does not match schema").into());
342    }
343    encode(&avro, schema, buffer);
344    Ok(())
345}
346
347fn write_value_ref(schema: &Schema, value: &Value, buffer: &mut Vec<u8>) -> Result<(), Error> {
348    if !value.validate(schema.top_node()) {
349        return Err(ValidationError::new("value does not match schema").into());
350    }
351    encode_ref(value, schema.top_node(), buffer);
352    Ok(())
353}
354
355/// Encode a compatible value (implementing the `ToAvro` trait) into Avro format, also
356/// performing schema validation.
357///
358/// **NOTE** This function has a quite small niche of usage and does NOT generate headers and sync
359/// markers; use [`Writer`](struct.Writer.html) to be fully Avro-compatible if you don't know what
360/// you are doing, instead.
361pub fn to_avro_datum<T: ToAvro>(schema: &Schema, value: T) -> Result<Vec<u8>, Error> {
362    let mut buffer = Vec::new();
363    write_avro_datum(schema, value, &mut buffer)?;
364    Ok(buffer)
365}
366
367#[cfg(test)]
368mod tests {
369    use std::io::Cursor;
370    use std::str::FromStr;
371
372    use serde::{Deserialize, Serialize};
373
374    use crate::Reader;
375    use crate::types::Record;
376    use crate::util::zig_i64;
377
378    use super::*;
379
380    static SCHEMA: &str = r#"
381            {
382                "type": "record",
383                "name": "test",
384                "fields": [
385                    {"name": "a", "type": "long", "default": 42},
386                    {"name": "b", "type": "string"}
387                ]
388            }
389        "#;
390    static UNION_SCHEMA: &str = r#"
391            ["null", "long"]
392        "#;
393
394    #[mz_ore::test]
395    fn test_to_avro_datum() {
396        let schema = Schema::from_str(SCHEMA).unwrap();
397        let mut record = Record::new(schema.top_node()).unwrap();
398        record.put("a", 27i64);
399        record.put("b", "foo");
400
401        let mut expected = Vec::new();
402        zig_i64(27, &mut expected);
403        zig_i64(3, &mut expected);
404        expected.extend(vec![b'f', b'o', b'o'].into_iter());
405
406        assert_eq!(to_avro_datum(&schema, record).unwrap(), expected);
407    }
408
409    #[mz_ore::test]
410    fn test_union() {
411        let schema = Schema::from_str(UNION_SCHEMA).unwrap();
412        let union = Value::Union {
413            index: 1,
414            inner: Box::new(Value::Long(3)),
415            n_variants: 2,
416            null_variant: Some(0),
417        };
418
419        let mut expected = Vec::new();
420        zig_i64(1, &mut expected);
421        zig_i64(3, &mut expected);
422
423        assert_eq!(to_avro_datum(&schema, union).unwrap(), expected);
424    }
425
426    #[mz_ore::test]
427    fn test_writer_append() {
428        let schema = Schema::from_str(SCHEMA).unwrap();
429        let mut writer = Writer::new(schema.clone(), Vec::new());
430
431        let mut record = Record::new(schema.top_node()).unwrap();
432        record.put("a", 27i64);
433        record.put("b", "foo");
434
435        let n1 = writer.append(record.clone()).unwrap();
436        let n2 = writer.append(record.clone()).unwrap();
437        let n3 = writer.flush().unwrap();
438        let result = writer.into_inner();
439
440        assert_eq!(n1 + n2 + n3, result.len());
441
442        let mut header = Vec::new();
443        header.extend(vec![b'O', b'b', b'j', b'\x01']);
444
445        let mut data = Vec::new();
446        zig_i64(27, &mut data);
447        zig_i64(3, &mut data);
448        data.extend(vec![b'f', b'o', b'o'].into_iter());
449        let data_copy = data.clone();
450        data.extend(data_copy);
451
452        // starts with magic
453        assert_eq!(
454            result
455                .iter()
456                .cloned()
457                .take(header.len())
458                .collect::<Vec<u8>>(),
459            header
460        );
461        // ends with data and sync marker
462        assert_eq!(
463            result
464                .iter()
465                .cloned()
466                .rev()
467                .skip(16)
468                .take(data.len())
469                .collect::<Vec<u8>>()
470                .into_iter()
471                .rev()
472                .collect::<Vec<u8>>(),
473            data
474        );
475    }
476
477    #[mz_ore::test]
478    fn test_writer_extend() {
479        let schema = Schema::from_str(SCHEMA).unwrap();
480        let mut writer = Writer::new(schema.clone(), Vec::new());
481
482        let mut record = Record::new(schema.top_node()).unwrap();
483        record.put("a", 27i64);
484        record.put("b", "foo");
485        let record_copy = record.clone();
486        let records = vec![record, record_copy];
487
488        let n1 = writer.extend(records.into_iter()).unwrap();
489        let n2 = writer.flush().unwrap();
490        let result = writer.into_inner();
491
492        assert_eq!(n1 + n2, result.len());
493
494        let mut header = Vec::new();
495        header.extend(vec![b'O', b'b', b'j', b'\x01']);
496
497        let mut data = Vec::new();
498        zig_i64(27, &mut data);
499        zig_i64(3, &mut data);
500        data.extend(vec![b'f', b'o', b'o'].into_iter());
501        let data_copy = data.clone();
502        data.extend(data_copy);
503
504        // starts with magic
505        assert_eq!(
506            result
507                .iter()
508                .cloned()
509                .take(header.len())
510                .collect::<Vec<u8>>(),
511            header
512        );
513        // ends with data and sync marker
514        assert_eq!(
515            result
516                .iter()
517                .cloned()
518                .rev()
519                .skip(16)
520                .take(data.len())
521                .collect::<Vec<u8>>()
522                .into_iter()
523                .rev()
524                .collect::<Vec<u8>>(),
525            data
526        );
527    }
528
    /// Serde-serializable struct whose fields mirror the names and types of
    /// the fields in `SCHEMA` (`a`: long, `b`: string).
    ///
    /// NOTE(review): no test visible in this module references this type;
    /// confirm whether it is still needed.
    #[derive(Debug, Clone, Deserialize, Serialize)]
    struct TestSerdeSerialize {
        a: i64,
        b: String,
    }
534
535    #[mz_ore::test]
536    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `deflateInit2_` on OS `linux`
537    fn test_writer_with_codec() {
538        let schema = Schema::from_str(SCHEMA).unwrap();
539        let mut writer = Writer::with_codec(schema.clone(), Vec::new(), Codec::Deflate);
540
541        let mut record = Record::new(schema.top_node()).unwrap();
542        record.put("a", 27i64);
543        record.put("b", "foo");
544
545        let n1 = writer.append(record.clone()).unwrap();
546        let n2 = writer.append(record.clone()).unwrap();
547        let n3 = writer.flush().unwrap();
548        let result = writer.into_inner();
549
550        assert_eq!(n1 + n2 + n3, result.len());
551
552        let mut header = Vec::new();
553        header.extend(vec![b'O', b'b', b'j', b'\x01']);
554
555        let mut data = Vec::new();
556        zig_i64(27, &mut data);
557        zig_i64(3, &mut data);
558        data.extend(vec![b'f', b'o', b'o'].into_iter());
559        let data_copy = data.clone();
560        data.extend(data_copy);
561        Codec::Deflate.compress(&mut data).unwrap();
562
563        // starts with magic
564        assert_eq!(
565            result
566                .iter()
567                .cloned()
568                .take(header.len())
569                .collect::<Vec<u8>>(),
570            header
571        );
572        // ends with data and sync marker
573        assert_eq!(
574            result
575                .iter()
576                .cloned()
577                .rev()
578                .skip(16)
579                .take(data.len())
580                .collect::<Vec<u8>>()
581                .into_iter()
582                .rev()
583                .collect::<Vec<u8>>(),
584            data
585        );
586    }
587
588    #[mz_ore::test]
589    #[cfg_attr(miri, ignore)] // slow
590    fn test_writer_roundtrip() {
591        let schema = Schema::from_str(SCHEMA).unwrap();
592        let make_record = |a: i64, b| {
593            let mut record = Record::new(schema.top_node()).unwrap();
594            record.put("a", a);
595            record.put("b", b);
596            record.avro()
597        };
598
599        let mut buf = Vec::new();
600
601        // Write out a file with two blocks.
602        {
603            let mut writer = Writer::new(schema.clone(), &mut buf);
604            writer.append(make_record(27, "foo")).unwrap();
605            writer.flush().unwrap();
606            writer.append(make_record(54, "bar")).unwrap();
607            writer.flush().unwrap();
608        }
609
610        // Add another block from a new writer, part i.
611        {
612            let mut writer = Writer::append_to(Cursor::new(&mut buf)).unwrap();
613            writer.append(make_record(42, "baz")).unwrap();
614            writer.flush().unwrap();
615        }
616
617        // Add another block from a new writer, part ii.
618        {
619            let mut writer = Writer::append_to(Cursor::new(&mut buf)).unwrap();
620            writer.append(make_record(84, "zar")).unwrap();
621            writer.flush().unwrap();
622        }
623
624        // Ensure all four blocks appear in the file.
625        let reader = Reader::new(&buf[..]).unwrap();
626        let actual: Result<Vec<_>, _> = reader.collect();
627        let actual = actual.unwrap();
628        assert_eq!(
629            vec![
630                make_record(27, "foo"),
631                make_record(54, "bar"),
632                make_record(42, "baz"),
633                make_record(84, "zar")
634            ],
635            actual
636        );
637    }
638}