1use std::collections::{BTreeMap, BTreeSet};
92use std::fmt;
93use std::io::{ErrorKind, Read, Write};
94use std::net::TcpStream;
95use std::time::{Duration, Instant};
96
97use anyhow::{anyhow, bail};
98use bytes::{BufMut, BytesMut};
99use fallible_iterator::FallibleIterator;
100use mz_ore::collections::CollectionExt;
101use postgres_protocol::IsNull;
102use postgres_protocol::message::backend::Message;
103use postgres_protocol::message::frontend;
104use serde::{Deserialize, Serialize};
105
106struct PgConn {
107 stream: TcpStream,
108 recv_buf: BytesMut,
109 send_buf: BytesMut,
110 timeout: Duration,
111 verbose: bool,
112}
113
114impl PgConn {
115 fn new<'a>(
116 addr: &str,
117 user: &'a str,
118 timeout: Duration,
119 verbose: bool,
120 mut options: Vec<(&'a str, &'a str)>,
121 ) -> anyhow::Result<Self> {
122 let mut conn = Self {
123 stream: TcpStream::connect(addr)?,
124 recv_buf: BytesMut::new(),
125 send_buf: BytesMut::new(),
126 timeout,
127 verbose,
128 };
129
130 conn.stream.set_read_timeout(Some(timeout))?;
131 options.insert(0, ("user", user));
132 options.insert(0, ("welcome_message", "off"));
133 conn.send(|buf| frontend::startup_message(options, buf).unwrap())?;
134 match conn.recv()?.1 {
135 Message::AuthenticationOk => {}
136 _ => bail!("expected AuthenticationOk"),
137 };
138 conn.until(vec!["ReadyForQuery"], vec!['C', 'S', 'M'], BTreeSet::new())?;
139 Ok(conn)
140 }
141
142 fn send<F: FnOnce(&mut BytesMut)>(&mut self, f: F) -> anyhow::Result<()> {
143 self.send_buf.clear();
144 f(&mut self.send_buf);
145 self.stream.write_all(&self.send_buf)?;
146 Ok(())
147 }
148 fn until(
149 &mut self,
150 until: Vec<&str>,
151 err_field_typs: Vec<char>,
152 ignore: BTreeSet<String>,
153 ) -> anyhow::Result<Vec<String>> {
154 let mut msgs = Vec::with_capacity(until.len());
155 for expect in until {
156 loop {
157 let (ch, msg) = match self.recv() {
158 Ok((ch, msg)) => (ch, msg),
159 Err(err) => bail!("{}: waiting for {}, saw {:#?}", err, expect, msgs),
160 };
161 let (typ, args) = match msg {
162 Message::ReadyForQuery(body) => (
163 "ReadyForQuery",
164 serde_json::to_string(&ReadyForQuery {
165 status: char::from(body.status()).to_string(),
166 })?,
167 ),
168 Message::RowDescription(body) => (
169 "RowDescription",
170 serde_json::to_string(&RowDescription {
171 fields: body
172 .fields()
173 .map(|f| {
174 Ok(Field {
175 name: f.name().to_string(),
176 })
177 })
178 .collect()
179 .unwrap(),
180 })?,
181 ),
182 Message::DataRow(body) => {
183 let buf = body.buffer();
184 (
185 "DataRow",
186 serde_json::to_string(&DataRow {
187 fields: body
188 .ranges()
189 .map(|range| {
190 match range {
191 Some(range) => {
192 Ok(String::from_utf8(
194 buf[range.start..range.end].to_vec(),
195 )
196 .unwrap_or_else(|_| {
197 format!(
198 "{:?}",
199 buf[range.start..range.end].to_vec()
200 )
201 }))
202 }
203 None => Ok("NULL".into()),
204 }
205 })
206 .collect()
207 .unwrap(),
208 })?,
209 )
210 }
211 Message::CommandComplete(body) => (
212 "CommandComplete",
213 serde_json::to_string(&CommandComplete {
214 tag: body.tag().unwrap().to_string(),
215 })?,
216 ),
217 Message::ParseComplete => ("ParseComplete", "".to_string()),
218 Message::BindComplete => ("BindComplete", "".to_string()),
219 Message::PortalSuspended => ("PortalSuspended", "".to_string()),
220 Message::ErrorResponse(body) => (
221 "ErrorResponse",
222 serde_json::to_string(&ErrorResponse {
223 fields: body
224 .fields()
225 .filter_map(|f| {
226 let typ = char::from(f.type_());
227 if err_field_typs.contains(&typ) {
228 Ok(Some(ErrorField {
229 typ,
230 value: String::from_utf8_lossy(f.value_bytes())
231 .into_owned(),
232 }))
233 } else {
234 Ok(None)
235 }
236 })
237 .collect()
238 .unwrap(),
239 })?,
240 ),
241 Message::NoticeResponse(body) => (
242 "NoticeResponse",
243 serde_json::to_string(&ErrorResponse {
244 fields: body
245 .fields()
246 .filter_map(|f| {
247 let typ = char::from(f.type_());
248 if err_field_typs.contains(&typ) {
249 Ok(Some(ErrorField {
250 typ,
251 value: String::from_utf8_lossy(f.value_bytes())
252 .into_owned(),
253 }))
254 } else {
255 Ok(None)
256 }
257 })
258 .collect()
259 .unwrap(),
260 })?,
261 ),
262 Message::CopyOutResponse(body) => (
263 "CopyOut",
264 serde_json::to_string(&CopyOut {
265 format: format_name(body.format()),
266 column_formats: body
267 .column_formats()
268 .map(|format| Ok(format_name(format)))
269 .collect()
270 .unwrap(),
271 })?,
272 ),
273 Message::CopyInResponse(body) => (
274 "CopyIn",
275 serde_json::to_string(&CopyOut {
276 format: format_name(body.format()),
277 column_formats: body
278 .column_formats()
279 .map(|format| Ok(format_name(format)))
280 .collect()
281 .unwrap(),
282 })?,
283 ),
284 Message::CopyData(body) => (
285 "CopyData",
286 serde_json::to_string(
287 &std::str::from_utf8(body.data())
288 .map(|s| s.to_string())
289 .unwrap_or_else(|_| format!("{:?}", body.data())),
290 )?,
291 ),
292 Message::CopyDone => ("CopyDone", "".to_string()),
293 Message::ParameterDescription(body) => (
294 "ParameterDescription",
295 serde_json::to_string(&ParameterDescription {
296 parameters: body.parameters().collect().unwrap(),
297 })?,
298 ),
299 Message::ParameterStatus(_) => continue,
300 Message::NoData => ("NoData", "".to_string()),
301 Message::EmptyQueryResponse => ("EmptyQueryResponse", "".to_string()),
302 _ => ("UNKNOWN", format!("'{}'", ch)),
303 };
304 if self.verbose {
305 println!("RECV {}: {:?}", ch, typ);
306 }
307 if ignore.contains(typ) {
308 continue;
309 }
310 let mut s = typ.to_string();
311 if !args.is_empty() {
312 s.push(' ');
313 s.push_str(&args);
314 }
315 msgs.push(s);
316 if expect == typ {
317 break;
318 }
319 }
320 }
321 Ok(msgs)
322 }
323 pub fn recv(&mut self) -> anyhow::Result<(char, Message)> {
327 let mut buf = [0; 1024];
328 let until = Instant::now();
329 loop {
330 if until.elapsed() > self.timeout {
331 bail!("timeout after {:?} waiting for new message", self.timeout);
332 }
333 let mut ch: char = '0';
334 if self.recv_buf.len() > 0 {
335 ch = char::from(self.recv_buf[0]);
336 }
337 if let Some(msg) = Message::parse(&mut self.recv_buf)? {
338 return Ok((ch, msg));
339 };
340 let sz = match self.stream.read(&mut buf) {
342 Ok(n) => n,
343 Err(e) if e.kind() == ErrorKind::Interrupted => continue,
346 Err(e) => return Err(anyhow!(e)),
347 };
348 self.recv_buf.extend_from_slice(&buf[..sz]);
349 }
350 }
351}
352
/// Key under which `PgTest` stores the connection used when a directive
/// carries no `conn=` argument.
const DEFAULT_CONN: &str = "";
354
355pub struct PgTest {
356 addr: String,
357 user: String,
358 timeout: Duration,
359 conns: BTreeMap<String, PgConn>,
360 verbose: bool,
361}
362
363impl PgTest {
364 pub fn new(addr: String, user: String, timeout: Duration) -> anyhow::Result<Self> {
365 let verbose = std::env::var_os("PGTEST_VERBOSE").is_some();
366 let conn = PgConn::new(&addr, &user, timeout.clone(), verbose, vec![])?;
367 let mut conns = BTreeMap::new();
368 conns.insert(DEFAULT_CONN.to_string(), conn);
369
370 Ok(PgTest {
371 addr,
372 user,
373 timeout,
374 conns,
375 verbose,
376 })
377 }
378
379 fn get_conn(
382 &mut self,
383 name: Option<String>,
384 options: Vec<(&str, &str)>,
385 ) -> anyhow::Result<&mut PgConn> {
386 let name = name.unwrap_or_else(|| DEFAULT_CONN.to_string());
387 if !self.conns.contains_key(&name) {
388 let conn = PgConn::new(
389 &self.addr,
390 &self.user,
391 self.timeout.clone(),
392 self.verbose,
393 options,
394 )?;
395 self.conns.insert(name.clone(), conn);
396 }
397 Ok(self.conns.get_mut(&name).expect("must exist"))
398 }
399
400 pub fn send<F: Fn(&mut BytesMut)>(
401 &mut self,
402 conn: Option<String>,
403 options: Vec<(&str, &str)>,
404 f: F,
405 ) -> anyhow::Result<()> {
406 let conn = self.get_conn(conn, options)?;
407 conn.send(f)
408 }
409
410 pub fn until(
411 &mut self,
412 conn: Option<String>,
413 options: Vec<(&str, &str)>,
414 until: Vec<&str>,
415 err_field_typs: Vec<char>,
416 ignore: BTreeSet<String>,
417 ) -> anyhow::Result<Vec<String>> {
418 let conn = self.get_conn(conn, options)?;
419 conn.until(until, err_field_typs, ignore)
420 }
421}
422
/// JSON payload rendered for a `ReadyForQuery` message.
#[derive(Serialize)]
pub struct ReadyForQuery {
    /// The message's status byte, as a one-character string.
    pub status: String,
}

/// JSON payload rendered for a `RowDescription` message.
#[derive(Serialize)]
pub struct RowDescription {
    /// One entry per described column.
    pub fields: Vec<Field>,
}

/// A described column; only its name is recorded.
#[derive(Serialize)]
pub struct Field {
    pub name: String,
}

/// JSON payload rendered for a `DataRow` message.
#[derive(Serialize)]
pub struct DataRow {
    /// Column values as text: NULL renders as `"NULL"`, and non-UTF-8 values
    /// render as their debug-printed bytes.
    pub fields: Vec<String>,
}

/// JSON payload rendered for both `CopyOutResponse` and `CopyInResponse`.
#[derive(Serialize)]
pub struct CopyOut {
    /// Overall format name (see `format_name`).
    pub format: String,
    /// Per-column format names (see `format_name`).
    pub column_formats: Vec<String>,
}

/// JSON payload rendered for a `ParameterDescription` message.
#[derive(Serialize)]
pub struct ParameterDescription {
    // Values from the message's parameter list (type OIDs per the wire
    // protocol -- NOTE(review): confirm against postgres_protocol docs).
    parameters: Vec<u32>,
}

/// JSON payload rendered for a `CommandComplete` message.
#[derive(Serialize)]
pub struct CommandComplete {
    /// The command tag string.
    pub tag: String,
}

/// JSON payload rendered for both `ErrorResponse` and `NoticeResponse`.
#[derive(Serialize)]
pub struct ErrorResponse {
    /// Only the fields whose type code was requested by the test.
    pub fields: Vec<ErrorField>,
}

/// One error/notice field.
#[derive(Serialize)]
pub struct ErrorField {
    /// Single-character field type code.
    pub typ: char,
    /// Field value, decoded lossily as UTF-8.
    pub value: String,
}
471
472impl Drop for PgTest {
473 fn drop(&mut self) {
474 for conn in self.conns.values_mut() {
475 let _ = conn.send(frontend::terminate);
476 }
477 }
478}
479
/// Renders a wire-protocol format code as a readable name.
///
/// `0` becomes `"text"` and `1` becomes `"binary"`. Any other code, including
/// values that do not fit in a `u16`, becomes `"unknown: <code>"`.
fn format_name<T>(format: T) -> String
where
    T: Copy + TryInto<u16> + fmt::Display,
{
    let code: Option<u16> = TryInto::<u16>::try_into(format).ok();
    let name = match code {
        Some(0) => "text",
        Some(1) => "binary",
        _ => return format!("unknown: {}", format),
    };
    name.to_string()
}
490
491pub fn walk(addr: String, user: String, timeout: Duration, dir: &str) {
492 datadriven::walk(dir, |tf| run_test(tf, addr.clone(), user.clone(), timeout));
493}
494
495pub fn run_test(tf: &mut datadriven::TestFile, addr: String, user: String, timeout: Duration) {
496 let mut pgt = PgTest::new(addr, user, timeout).unwrap();
497 tf.run(|tc| -> String {
498 let lines = tc.input.lines();
499 let mut args = tc.args.clone();
500 let conn: Option<String> = args
501 .remove("conn")
502 .map(|args| Some(args.into_first()))
503 .unwrap_or(None);
504 let mut options: Vec<(&str, &str)> = Vec::new();
505 let cluster = args.remove("cluster");
506 if let Some(cluster) = &cluster {
507 let cluster = cluster.into_first();
508 options.push(("cluster", cluster.as_str()));
509 }
510 match tc.directive.as_str() {
511 "send" => {
512 for line in lines {
513 if pgt.verbose {
514 println!("SEND {}", line);
515 }
516 let mut line = line.splitn(2, ' ');
517 let typ = line.next().unwrap_or("");
518 let args = line.next().unwrap_or("{}");
519 pgt.send(conn.clone(), options.clone(), |buf| match typ {
520 "Query" => {
521 let v: Query = serde_json::from_str(args).unwrap();
522 frontend::query(&v.query, buf).unwrap();
523 }
524 "Parse" => {
525 let v: Parse = serde_json::from_str(args).unwrap();
526 frontend::parse(
527 &v.name.unwrap_or_else(|| "".into()),
528 &v.query,
529 vec![],
530 buf,
531 )
532 .unwrap();
533 }
534 "Sync" => frontend::sync(buf),
535 "Bind" => {
536 let v: Bind = serde_json::from_str(args).unwrap();
537 let values = v.values.unwrap_or_default();
538 if frontend::bind(
539 &v.portal.unwrap_or_else(|| "".into()),
540 &v.statement.unwrap_or_else(|| "".into()),
541 vec![], values, |t, buf| {
544 buf.put_slice(t.as_bytes());
545 Ok(IsNull::No)
546 }, v.result_formats.unwrap_or_default(),
548 buf,
549 )
550 .is_err()
551 {
552 panic!("bind error");
553 }
554 }
555 "Describe" => {
556 let v: Describe = serde_json::from_str(args).unwrap();
557 frontend::describe(
558 v.variant.unwrap_or_else(|| "S".into()).as_bytes()[0],
559 &v.name.unwrap_or_else(|| "".into()),
560 buf,
561 )
562 .unwrap();
563 }
564 "Execute" => {
565 let v: Execute = serde_json::from_str(args).unwrap();
566 frontend::execute(
567 &v.portal.unwrap_or_else(|| "".into()),
568 v.max_rows.unwrap_or(0),
569 buf,
570 )
571 .unwrap();
572 }
573 "CopyData" => {
574 let v: String = serde_json::from_str(args).unwrap();
575 frontend::CopyData::new(v.as_bytes()).unwrap().write(buf);
576 }
577 "CopyDone" => {
578 frontend::copy_done(buf);
579 }
580 "CopyFail" => {
581 let v: String = serde_json::from_str(args).unwrap();
582 frontend::copy_fail(&v, buf).unwrap();
583 }
584 _ => panic!("unknown message type {}", typ),
585 })
586 .unwrap();
587 }
588 "".to_string()
589 }
590 "until" => {
591 let err_field_typs = if let Some(_) = args.remove("no_error_fields") {
595 vec![]
596 } else {
597 match args.remove("err_field_typs") {
598 Some(typs) => typs.join("").chars().collect(),
599 None => vec!['C', 'S', 'M'],
600 }
601 };
602 let mut ignore = BTreeSet::new();
603 if let Some(values) = args.remove("ignore") {
604 for v in values {
605 ignore.insert(v);
606 }
607 }
608 if !args.is_empty() {
609 panic!("extra until arguments: {:?}", args);
610 }
611 format!(
612 "{}\n",
613 pgt.until(conn, options, lines.collect(), err_field_typs, ignore)
614 .unwrap()
615 .join("\n")
616 )
617 }
618 _ => panic!("unknown directive {}", tc.input),
619 }
620 })
621}
622
623#[derive(Deserialize)]
626#[serde(deny_unknown_fields)]
627pub struct Query {
628 pub query: String,
629}
630
631#[derive(Deserialize)]
632#[serde(deny_unknown_fields)]
633pub struct Parse {
634 pub name: Option<String>,
635 pub query: String,
636}
637
638#[derive(Deserialize)]
639#[serde(deny_unknown_fields)]
640pub struct Bind {
641 pub portal: Option<String>,
642 pub statement: Option<String>,
643 pub values: Option<Vec<String>>,
644 pub result_formats: Option<Vec<i16>>,
645}
646
647#[derive(Deserialize)]
648#[serde(deny_unknown_fields)]
649pub struct Execute {
650 pub portal: Option<String>,
651 pub max_rows: Option<i32>,
652}
653
654#[derive(Deserialize)]
655#[serde(deny_unknown_fields)]
656pub struct Describe {
657 pub variant: Option<String>,
658 pub name: Option<String>,
659}