mz_pgtest/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! pgtest is a Postgres wire protocol tester using datadriven test files. It
11//! can be used to send [specific
12//! messages](https://www.postgresql.org/docs/current/protocol-message-formats.html)
13//! to any Postgres-compatible server and record received messages.
14//!
15//! The following datadriven directives are supported. They support a
16//! `conn=name` argument to specify a non-default connection.
17//! - `send`: Sends input messages to the server. Arguments, if needed, are
18//! specified using JSON. Refer to the associated types to see supported
19//! arguments. Arguments can be omitted to use defaults.
20//! - `until`: Waits until input messages have been received from the server.
21//! Additional messages are accumulated and returned as well.
22//!
23//! The first time a `conn=name` argument is specified, `cluster=name` can also
//! be specified to set the session's cluster on initial connection.
25//!
26//! During debugging, set the environment variable `PGTEST_VERBOSE=1` to see
27//! messages sent and received.
28//!
29//! Supported `send` types:
30//! - [`Query`](struct.Query.html)
31//! - [`Parse`](struct.Parse.html)
32//! - [`Describe`](struct.Describe.html)
33//! - [`Bind`](struct.Bind.html)
34//! - [`Execute`](struct.Execute.html)
//! - `Sync`
//! - `CopyData`, `CopyFail` (argument is a JSON string), and `CopyDone`
36//!
37//! Supported `until` arguments:
//! - `no_error_fields` causes `ErrorResponse` and `NoticeResponse` messages to
//! have empty contents. Useful when none of our fields match Postgres. For
//! example `until no_error_fields`.
41//! - `err_field_typs` specifies the set of error message fields
42//! ([reference](https://www.postgresql.org/docs/current/protocol-error-fields.html)).
43//! The default is `CMS` (code, message, severity). For example: `until
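//! err_field_typs=SC` would return the severity and code fields in any
//! ErrorResponse or NoticeResponse message.
//! - `ignore` names message types to omit from the output entirely (they are
//! still consumed). For example: `until ignore=NoticeResponse`.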
46//!
47//! For example, to execute a simple prepared statement:
48//! ```pgtest
49//! send
50//! Parse {"query": "SELECT $1::text, 1 + $2::int4"}
51//! Bind {"values": ["blah", "4"]}
52//! Execute
53//! Sync
54//! ----
55//!
56//! until
57//! ReadyForQuery
58//! ----
59//! ParseComplete
60//! BindComplete
61//! DataRow {"fields":["blah","5"]}
62//! CommandComplete {"tag":"SELECT 1"}
63//! ReadyForQuery {"status":"I"}
64//! ```
65//!
66//! # Usage while writing tests
67//!
68//! The expected way to use this while writing tests is to generate output from
69//! a postgres server. Use the `pgtest-mz` directory if our output differs
70//! incompatibly from postgres. Write your test, excluding any lines after the
71//! `----` of the `until` directive. For example:
72//! ```pgtest
73//! send
74//! Query {"query": "SELECT 1"}
75//! ----
76//!
77//! until
78//! ReadyForQuery
79//! ----
80//! ```
81//! Then run the pgtest binary, enabling rewrites and pointing it at postgres:
82//! ```shell
83//! REWRITE=1 cargo run --bin mz-pgtest -- test/pgtest/test.pt --addr localhost:5432 --user postgres
84//! ```
85//! This will generate the expected output for the `until` directive. Now rerun
86//! against a running Materialize server:
87//! ```shell
88//! cargo run --bin mz-pgtest -- test/pgtest/test.pt
89//! ```
90
91use std::collections::{BTreeMap, BTreeSet};
92use std::fmt;
93use std::io::{ErrorKind, Read, Write};
94use std::net::TcpStream;
95use std::time::{Duration, Instant};
96
97use anyhow::{anyhow, bail};
98use bytes::{BufMut, BytesMut};
99use fallible_iterator::FallibleIterator;
100use mz_ore::collections::CollectionExt;
101use postgres_protocol::IsNull;
102use postgres_protocol::message::backend::Message;
103use postgres_protocol::message::frontend;
104use serde::{Deserialize, Serialize};
105
106struct PgConn {
107    stream: TcpStream,
108    recv_buf: BytesMut,
109    send_buf: BytesMut,
110    timeout: Duration,
111    verbose: bool,
112}
113
114impl PgConn {
115    fn new<'a>(
116        addr: &str,
117        user: &'a str,
118        timeout: Duration,
119        verbose: bool,
120        mut options: Vec<(&'a str, &'a str)>,
121    ) -> anyhow::Result<Self> {
122        let mut conn = Self {
123            stream: TcpStream::connect(addr)?,
124            recv_buf: BytesMut::new(),
125            send_buf: BytesMut::new(),
126            timeout,
127            verbose,
128        };
129
130        conn.stream.set_read_timeout(Some(timeout))?;
131        options.insert(0, ("user", user));
132        options.insert(0, ("welcome_message", "off"));
133        conn.send(|buf| frontend::startup_message(options, buf).unwrap())?;
134        match conn.recv()?.1 {
135            Message::AuthenticationOk => {}
136            _ => bail!("expected AuthenticationOk"),
137        };
138        conn.until(vec!["ReadyForQuery"], vec!['C', 'S', 'M'], BTreeSet::new())?;
139        Ok(conn)
140    }
141
142    fn send<F: FnOnce(&mut BytesMut)>(&mut self, f: F) -> anyhow::Result<()> {
143        self.send_buf.clear();
144        f(&mut self.send_buf);
145        self.stream.write_all(&self.send_buf)?;
146        Ok(())
147    }
148    fn until(
149        &mut self,
150        until: Vec<&str>,
151        err_field_typs: Vec<char>,
152        ignore: BTreeSet<String>,
153    ) -> anyhow::Result<Vec<String>> {
154        let mut msgs = Vec::with_capacity(until.len());
155        for expect in until {
156            loop {
157                let (ch, msg) = match self.recv() {
158                    Ok((ch, msg)) => (ch, msg),
159                    Err(err) => bail!("{}: waiting for {}, saw {:#?}", err, expect, msgs),
160                };
161                let (typ, args) = match msg {
162                    Message::ReadyForQuery(body) => (
163                        "ReadyForQuery",
164                        serde_json::to_string(&ReadyForQuery {
165                            status: char::from(body.status()).to_string(),
166                        })?,
167                    ),
168                    Message::RowDescription(body) => (
169                        "RowDescription",
170                        serde_json::to_string(&RowDescription {
171                            fields: body
172                                .fields()
173                                .map(|f| {
174                                    Ok(Field {
175                                        name: f.name().to_string(),
176                                    })
177                                })
178                                .collect()
179                                .unwrap(),
180                        })?,
181                    ),
182                    Message::DataRow(body) => {
183                        let buf = body.buffer();
184                        (
185                            "DataRow",
186                            serde_json::to_string(&DataRow {
187                                fields: body
188                                    .ranges()
189                                    .map(|range| {
190                                        match range {
191                                            Some(range) => {
192                                                // Attempt to convert to a String. If not utf8, print as array of bytes instead.
193                                                Ok(String::from_utf8(
194                                                    buf[range.start..range.end].to_vec(),
195                                                )
196                                                .unwrap_or_else(|_| {
197                                                    format!(
198                                                        "{:?}",
199                                                        buf[range.start..range.end].to_vec()
200                                                    )
201                                                }))
202                                            }
203                                            None => Ok("NULL".into()),
204                                        }
205                                    })
206                                    .collect()
207                                    .unwrap(),
208                            })?,
209                        )
210                    }
211                    Message::CommandComplete(body) => (
212                        "CommandComplete",
213                        serde_json::to_string(&CommandComplete {
214                            tag: body.tag().unwrap().to_string(),
215                        })?,
216                    ),
217                    Message::ParseComplete => ("ParseComplete", "".to_string()),
218                    Message::BindComplete => ("BindComplete", "".to_string()),
219                    Message::PortalSuspended => ("PortalSuspended", "".to_string()),
220                    Message::ErrorResponse(body) => (
221                        "ErrorResponse",
222                        serde_json::to_string(&ErrorResponse {
223                            fields: body
224                                .fields()
225                                .filter_map(|f| {
226                                    let typ = char::from(f.type_());
227                                    if err_field_typs.contains(&typ) {
228                                        Ok(Some(ErrorField {
229                                            typ,
230                                            value: String::from_utf8_lossy(f.value_bytes())
231                                                .into_owned(),
232                                        }))
233                                    } else {
234                                        Ok(None)
235                                    }
236                                })
237                                .collect()
238                                .unwrap(),
239                        })?,
240                    ),
241                    Message::NoticeResponse(body) => (
242                        "NoticeResponse",
243                        serde_json::to_string(&ErrorResponse {
244                            fields: body
245                                .fields()
246                                .filter_map(|f| {
247                                    let typ = char::from(f.type_());
248                                    if err_field_typs.contains(&typ) {
249                                        Ok(Some(ErrorField {
250                                            typ,
251                                            value: String::from_utf8_lossy(f.value_bytes())
252                                                .into_owned(),
253                                        }))
254                                    } else {
255                                        Ok(None)
256                                    }
257                                })
258                                .collect()
259                                .unwrap(),
260                        })?,
261                    ),
262                    Message::CopyOutResponse(body) => (
263                        "CopyOut",
264                        serde_json::to_string(&CopyOut {
265                            format: format_name(body.format()),
266                            column_formats: body
267                                .column_formats()
268                                .map(|format| Ok(format_name(format)))
269                                .collect()
270                                .unwrap(),
271                        })?,
272                    ),
273                    Message::CopyInResponse(body) => (
274                        "CopyIn",
275                        serde_json::to_string(&CopyOut {
276                            format: format_name(body.format()),
277                            column_formats: body
278                                .column_formats()
279                                .map(|format| Ok(format_name(format)))
280                                .collect()
281                                .unwrap(),
282                        })?,
283                    ),
284                    Message::CopyData(body) => (
285                        "CopyData",
286                        serde_json::to_string(
287                            &std::str::from_utf8(body.data())
288                                .map(|s| s.to_string())
289                                .unwrap_or_else(|_| format!("{:?}", body.data())),
290                        )?,
291                    ),
292                    Message::CopyDone => ("CopyDone", "".to_string()),
293                    Message::ParameterDescription(body) => (
294                        "ParameterDescription",
295                        serde_json::to_string(&ParameterDescription {
296                            parameters: body.parameters().collect().unwrap(),
297                        })?,
298                    ),
299                    Message::ParameterStatus(_) => continue,
300                    Message::NoData => ("NoData", "".to_string()),
301                    Message::EmptyQueryResponse => ("EmptyQueryResponse", "".to_string()),
302                    _ => ("UNKNOWN", format!("'{}'", ch)),
303                };
304                if self.verbose {
305                    println!("RECV {}: {:?}", ch, typ);
306                }
307                if ignore.contains(typ) {
308                    continue;
309                }
310                let mut s = typ.to_string();
311                if !args.is_empty() {
312                    s.push(' ');
313                    s.push_str(&args);
314                }
315                msgs.push(s);
316                if expect == typ {
317                    break;
318                }
319            }
320        }
321        Ok(msgs)
322    }
323    /// Returns the PostgreSQL message format and the `Message`.
324    ///
325    /// An error is returned if a new message is not received within the timeout.
326    pub fn recv(&mut self) -> anyhow::Result<(char, Message)> {
327        let mut buf = [0; 1024];
328        let until = Instant::now();
329        loop {
330            if until.elapsed() > self.timeout {
331                bail!("timeout after {:?} waiting for new message", self.timeout);
332            }
333            let mut ch: char = '0';
334            if self.recv_buf.len() > 0 {
335                ch = char::from(self.recv_buf[0]);
336            }
337            if let Some(msg) = Message::parse(&mut self.recv_buf)? {
338                return Ok((ch, msg));
339            };
340            // If there was no message, read more bytes.
341            let sz = match self.stream.read(&mut buf) {
342                Ok(n) => n,
343                // According to the `read` docs, this is a non-fatal retryable error.
344                // https://doc.rust-lang.org/std/io/trait.Read.html#errors
345                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
346                Err(e) => return Err(anyhow!(e)),
347            };
348            self.recv_buf.extend_from_slice(&buf[..sz]);
349        }
350    }
351}
352
353const DEFAULT_CONN: &str = "";
354
355pub struct PgTest {
356    addr: String,
357    user: String,
358    timeout: Duration,
359    conns: BTreeMap<String, PgConn>,
360    verbose: bool,
361}
362
363impl PgTest {
364    pub fn new(addr: String, user: String, timeout: Duration) -> anyhow::Result<Self> {
365        let verbose = std::env::var_os("PGTEST_VERBOSE").is_some();
366        let conn = PgConn::new(&addr, &user, timeout.clone(), verbose, vec![])?;
367        let mut conns = BTreeMap::new();
368        conns.insert(DEFAULT_CONN.to_string(), conn);
369
370        Ok(PgTest {
371            addr,
372            user,
373            timeout,
374            conns,
375            verbose,
376        })
377    }
378
379    // Returns the named connection. If this is the first time that connection is
380    // seen, sends options.
381    fn get_conn(
382        &mut self,
383        name: Option<String>,
384        options: Vec<(&str, &str)>,
385    ) -> anyhow::Result<&mut PgConn> {
386        let name = name.unwrap_or_else(|| DEFAULT_CONN.to_string());
387        if !self.conns.contains_key(&name) {
388            let conn = PgConn::new(
389                &self.addr,
390                &self.user,
391                self.timeout.clone(),
392                self.verbose,
393                options,
394            )?;
395            self.conns.insert(name.clone(), conn);
396        }
397        Ok(self.conns.get_mut(&name).expect("must exist"))
398    }
399
400    pub fn send<F: Fn(&mut BytesMut)>(
401        &mut self,
402        conn: Option<String>,
403        options: Vec<(&str, &str)>,
404        f: F,
405    ) -> anyhow::Result<()> {
406        let conn = self.get_conn(conn, options)?;
407        conn.send(f)
408    }
409
410    pub fn until(
411        &mut self,
412        conn: Option<String>,
413        options: Vec<(&str, &str)>,
414        until: Vec<&str>,
415        err_field_typs: Vec<char>,
416        ignore: BTreeSet<String>,
417    ) -> anyhow::Result<Vec<String>> {
418        let conn = self.get_conn(conn, options)?;
419        conn.until(until, err_field_typs, ignore)
420    }
421}
422
423// Backend messages
424
425#[derive(Serialize)]
426pub struct ReadyForQuery {
427    pub status: String,
428}
429
430#[derive(Serialize)]
431pub struct RowDescription {
432    pub fields: Vec<Field>,
433}
434
435#[derive(Serialize)]
436pub struct Field {
437    pub name: String,
438}
439
440#[derive(Serialize)]
441pub struct DataRow {
442    pub fields: Vec<String>,
443}
444
445#[derive(Serialize)]
446pub struct CopyOut {
447    pub format: String,
448    pub column_formats: Vec<String>,
449}
450
451#[derive(Serialize)]
452pub struct ParameterDescription {
453    parameters: Vec<u32>,
454}
455
456#[derive(Serialize)]
457pub struct CommandComplete {
458    pub tag: String,
459}
460
461#[derive(Serialize)]
462pub struct ErrorResponse {
463    pub fields: Vec<ErrorField>,
464}
465
466#[derive(Serialize)]
467pub struct ErrorField {
468    pub typ: char,
469    pub value: String,
470}
471
472impl Drop for PgTest {
473    fn drop(&mut self) {
474        for conn in self.conns.values_mut() {
475            let _ = conn.send(frontend::terminate);
476        }
477    }
478}
479
/// Names a Postgres wire format code: `0` is `"text"` and `1` is `"binary"`.
/// Any other value, including one that does not fit in a `u16`, is rendered
/// as `"unknown: <value>"`.
fn format_name<T>(format: T) -> String
where
    T: Copy + TryInto<u16> + fmt::Display,
{
    let code: Option<u16> = format.try_into().ok();
    match code {
        Some(0) => String::from("text"),
        Some(1) => String::from("binary"),
        _ => format!("unknown: {}", format),
    }
}
490
491pub fn walk(addr: String, user: String, timeout: Duration, dir: &str) {
492    datadriven::walk(dir, |tf| run_test(tf, addr.clone(), user.clone(), timeout));
493}
494
495pub fn run_test(tf: &mut datadriven::TestFile, addr: String, user: String, timeout: Duration) {
496    let mut pgt = PgTest::new(addr, user, timeout).unwrap();
497    tf.run(|tc| -> String {
498        let lines = tc.input.lines();
499        let mut args = tc.args.clone();
500        let conn: Option<String> = args
501            .remove("conn")
502            .map(|args| Some(args.into_first()))
503            .unwrap_or(None);
504        let mut options: Vec<(&str, &str)> = Vec::new();
505        let cluster = args.remove("cluster");
506        if let Some(cluster) = &cluster {
507            let cluster = cluster.into_first();
508            options.push(("cluster", cluster.as_str()));
509        }
510        match tc.directive.as_str() {
511            "send" => {
512                for line in lines {
513                    if pgt.verbose {
514                        println!("SEND {}", line);
515                    }
516                    let mut line = line.splitn(2, ' ');
517                    let typ = line.next().unwrap_or("");
518                    let args = line.next().unwrap_or("{}");
519                    pgt.send(conn.clone(), options.clone(), |buf| match typ {
520                        "Query" => {
521                            let v: Query = serde_json::from_str(args).unwrap();
522                            frontend::query(&v.query, buf).unwrap();
523                        }
524                        "Parse" => {
525                            let v: Parse = serde_json::from_str(args).unwrap();
526                            frontend::parse(
527                                &v.name.unwrap_or_else(|| "".into()),
528                                &v.query,
529                                vec![],
530                                buf,
531                            )
532                            .unwrap();
533                        }
534                        "Sync" => frontend::sync(buf),
535                        "Bind" => {
536                            let v: Bind = serde_json::from_str(args).unwrap();
537                            let values = v.values.unwrap_or_default();
538                            if frontend::bind(
539                                &v.portal.unwrap_or_else(|| "".into()),
540                                &v.statement.unwrap_or_else(|| "".into()),
541                                vec![], // formats
542                                values, // values
543                                |t, buf| {
544                                    buf.put_slice(t.as_bytes());
545                                    Ok(IsNull::No)
546                                }, // serializer
547                                v.result_formats.unwrap_or_default(),
548                                buf,
549                            )
550                            .is_err()
551                            {
552                                panic!("bind error");
553                            }
554                        }
555                        "Describe" => {
556                            let v: Describe = serde_json::from_str(args).unwrap();
557                            frontend::describe(
558                                v.variant.unwrap_or_else(|| "S".into()).as_bytes()[0],
559                                &v.name.unwrap_or_else(|| "".into()),
560                                buf,
561                            )
562                            .unwrap();
563                        }
564                        "Execute" => {
565                            let v: Execute = serde_json::from_str(args).unwrap();
566                            frontend::execute(
567                                &v.portal.unwrap_or_else(|| "".into()),
568                                v.max_rows.unwrap_or(0),
569                                buf,
570                            )
571                            .unwrap();
572                        }
573                        "CopyData" => {
574                            let v: String = serde_json::from_str(args).unwrap();
575                            frontend::CopyData::new(v.as_bytes()).unwrap().write(buf);
576                        }
577                        "CopyDone" => {
578                            frontend::copy_done(buf);
579                        }
580                        "CopyFail" => {
581                            let v: String = serde_json::from_str(args).unwrap();
582                            frontend::copy_fail(&v, buf).unwrap();
583                        }
584                        _ => panic!("unknown message type {}", typ),
585                    })
586                    .unwrap();
587                }
588                "".to_string()
589            }
590            "until" => {
591                // Our error field values don't always match postgres. Default to reporting
592                // the error code (C) and message (M), but allow the user to specify which ones
593                // they want.
594                let err_field_typs = if let Some(_) = args.remove("no_error_fields") {
595                    vec![]
596                } else {
597                    match args.remove("err_field_typs") {
598                        Some(typs) => typs.join("").chars().collect(),
599                        None => vec!['C', 'S', 'M'],
600                    }
601                };
602                let mut ignore = BTreeSet::new();
603                if let Some(values) = args.remove("ignore") {
604                    for v in values {
605                        ignore.insert(v);
606                    }
607                }
608                if !args.is_empty() {
609                    panic!("extra until arguments: {:?}", args);
610                }
611                format!(
612                    "{}\n",
613                    pgt.until(conn, options, lines.collect(), err_field_typs, ignore)
614                        .unwrap()
615                        .join("\n")
616                )
617            }
618            _ => panic!("unknown directive {}", tc.input),
619        }
620    })
621}
622
623// Frontend messages
624
625#[derive(Deserialize)]
626#[serde(deny_unknown_fields)]
627pub struct Query {
628    pub query: String,
629}
630
631#[derive(Deserialize)]
632#[serde(deny_unknown_fields)]
633pub struct Parse {
634    pub name: Option<String>,
635    pub query: String,
636}
637
638#[derive(Deserialize)]
639#[serde(deny_unknown_fields)]
640pub struct Bind {
641    pub portal: Option<String>,
642    pub statement: Option<String>,
643    pub values: Option<Vec<String>>,
644    pub result_formats: Option<Vec<i16>>,
645}
646
647#[derive(Deserialize)]
648#[serde(deny_unknown_fields)]
649pub struct Execute {
650    pub portal: Option<String>,
651    pub max_rows: Option<i32>,
652}
653
654#[derive(Deserialize)]
655#[serde(deny_unknown_fields)]
656pub struct Describe {
657    pub variant: Option<String>,
658    pub name: Option<String>,
659}