1use std::collections::BTreeMap;
11use std::convert::TryFrom;
12use std::future::Future;
13use std::sync::Arc;
14use std::time::Instant;
15use std::{iter, mem};
16
17use byteorder::{ByteOrder, NetworkEndian};
18use futures::future::{BoxFuture, FutureExt, pending};
19use itertools::izip;
20use mz_adapter::client::RecordFirstRowStream;
21use mz_adapter::session::{
22 EndTransactionAction, InProgressRows, Portal, PortalState, SessionConfig, TransactionStatus,
23};
24use mz_adapter::statement_logging::{StatementEndedExecutionReason, StatementExecutionStrategy};
25use mz_adapter::{
26 AdapterError, AdapterNotice, ExecuteContextExtra, ExecuteResponse, PeekResponseUnary,
27 RowsFuture, verify_datum_desc,
28};
29use mz_auth::password::Password;
30use mz_frontegg_auth::Authenticator as FronteggAuthentication;
31use mz_ore::cast::CastFrom;
32use mz_ore::netio::AsyncReady;
33use mz_ore::str::StrExt;
34use mz_ore::{assert_none, assert_ok, instrument};
35use mz_pgcopy::{CopyCsvFormatParams, CopyFormatParams, CopyTextFormatParams};
36use mz_pgwire_common::{
37 ConnectionCounter, ErrorResponse, Format, FrontendMessage, Severity, VERSION_3, VERSIONS,
38};
39use mz_repr::user::InternalUserMetadata;
40use mz_repr::{
41 CatalogItemId, ColumnIndex, Datum, RelationDesc, RelationType, RowArena, RowIterator, RowRef,
42 ScalarType,
43};
44use mz_server_core::TlsMode;
45use mz_sql::ast::display::AstDisplay;
46use mz_sql::ast::{CopyDirection, CopyStatement, FetchDirection, Ident, Raw, Statement};
47use mz_sql::parse::StatementParseResult;
48use mz_sql::plan::{CopyFormat, ExecuteTimeout, StatementDesc};
49use mz_sql::session::metadata::SessionMetadata;
50use mz_sql::session::user::INTERNAL_USER_NAMES;
51use mz_sql::session::vars::{MAX_COPY_FROM_SIZE, Var, VarInput};
52use postgres::error::SqlState;
53use tokio::io::{self, AsyncRead, AsyncWrite};
54use tokio::select;
55use tokio::sync::mpsc::UnboundedReceiver;
56use tokio::time::{self};
57use tokio_stream::wrappers::UnboundedReceiverStream;
58use tracing::{Instrument, debug, debug_span, warn};
59use uuid::Uuid;
60
61use crate::codec::FramedConn;
62use crate::message::{self, BackendMessage};
63
64pub fn match_handshake(buf: &[u8]) -> bool {
68 if buf.len() < 8 {
78 return false;
79 }
80 let version = NetworkEndian::read_i32(&buf[4..8]);
81 VERSIONS.contains(&version)
82}
83
/// Parameters for the [`run`] function.
pub struct RunParams<'a, A> {
    /// The server's TLS mode, checked against the connection via
    /// `ensure_tls_compatibility`.
    pub tls_mode: Option<TlsMode>,
    /// Client used to create the session and register it with the adapter.
    pub adapter_client: mz_adapter::Client,
    /// The framed connection to the client.
    pub conn: &'a mut FramedConn<A>,
    /// Unique identifier for this connection, recorded in the session config.
    pub conn_uuid: Uuid,
    /// The protocol version requested in the client's startup message.
    pub version: i32,
    /// Startup parameters sent by the client (for example `user` and `options`).
    pub params: BTreeMap<String, String>,
    /// Frontegg authenticator. When set, the client must supply a password,
    /// which is checked by Frontegg.
    pub frontegg: Option<&'a FronteggAuthentication>,
    /// Whether to check the client's password with the adapter. Only consulted
    /// when `frontegg` is `None`.
    pub use_self_hosted_auth: bool,
    /// Whether this is the internal server, which only accepts internal users.
    pub internal: bool,
    /// Counter from which a connection slot is allocated for the session's user.
    pub active_connection_counter: ConnectionCounter,
    /// Helm chart version to record in the session config, if known.
    pub helm_chart_version: Option<String>,
}
110
/// Runs a pgwire connection to completion.
///
/// Startup proceeds in order:
/// 1. The requested protocol version is checked.
/// 2. The user name is validated for this server kind (internal or external).
/// 3. TLS compatibility is verified.
/// 4. The client is authenticated via Frontegg, via self-hosted password
///    authentication, or not at all.
/// 5. The startup parameters are applied to a new session.
/// 6. A connection slot is allocated.
/// 7. The session is registered with the adapter.
///
/// After the startup messages are sent, a [`StateMachine`] handles the rest of
/// the connection.
///
/// A startup failure is reported to the client as a fatal `ErrorResponse`, and
/// the result of sending that response is returned.
///
/// # Errors
///
/// Returns an error if communicating with the client fails, or if the state
/// machine returns an error.
#[mz_ore::instrument(level = "debug")]
pub async fn run<'a, A>(
    RunParams {
        tls_mode,
        adapter_client,
        conn,
        conn_uuid,
        version,
        mut params,
        frontegg,
        use_self_hosted_auth,
        internal,
        active_connection_counter,
        helm_chart_version,
    }: RunParams<'a, A>,
) -> Result<(), io::Error>
where
    A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin,
{
    // Only protocol version 3 is supported.
    if version != VERSION_3 {
        return conn
            .send(ErrorResponse::fatal(
                SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION,
                "server does not support the client's requested protocol version",
            ))
            .await;
    }

    // A missing `user` startup parameter is treated as the empty user name.
    let user = params.remove("user").unwrap_or_else(String::new);

    if internal {
        // The internal server only accepts the internal users.
        if !INTERNAL_USER_NAMES.contains(&user) {
            let msg = format!("unauthorized login to user '{user}'");
            return conn
                .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
                .await;
        }
    } else {
        // The external server rejects reserved role names.
        if mz_adapter::catalog::is_reserved_role_name(user.as_str()) {
            let msg = format!("unauthorized login to user '{user}'");
            return conn
                .send(ErrorResponse::fatal(SqlState::INSUFFICIENT_PRIVILEGE, msg))
                .await;
        }
    }

    if let Err(err) = conn.inner().ensure_tls_compatibility(&tls_mode) {
        return conn.send(err).await;
    }

    // Authenticate and create the session. `expired` is a future that resolves
    // when authentication expires. Only the Frontegg branch can expire; the
    // other branches use a future that never resolves.
    let (mut session, expired) = if let Some(frontegg) = frontegg {
        // Frontegg: ask for a cleartext password and check it with Frontegg.
        conn.send(BackendMessage::AuthenticationCleartextPassword)
            .await?;
        conn.flush().await?;
        let password = match conn.recv().await? {
            Some(FrontendMessage::Password { password }) => password,
            _ => {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "expected Password message",
                    ))
                    .await;
            }
        };

        let auth_response = frontegg.authenticate(&user, &password).await;
        match auth_response {
            Ok(mut auth_session) => {
                // The session's user name is taken from the auth session, not
                // from the `user` startup parameter.
                let session = adapter_client.new_session(SessionConfig {
                    conn_id: conn.conn_id().clone(),
                    uuid: conn_uuid,
                    user: auth_session.user().into(),
                    client_ip: conn.peer_addr().clone(),
                    external_metadata_rx: Some(auth_session.external_metadata_rx()),
                    internal_user_metadata: None,
                    helm_chart_version,
                });
                let expired = async move { auth_session.expired().await };
                (session, expired.left_future())
            }
            Err(err) => {
                warn!(?err, "pgwire connection failed authentication");
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_PASSWORD,
                        "invalid password",
                    ))
                    .await;
            }
        }
    } else if use_self_hosted_auth {
        // Self-hosted: ask for a cleartext password and check it with the adapter.
        conn.send(BackendMessage::AuthenticationCleartextPassword)
            .await?;
        conn.flush().await?;
        let password = match conn.recv().await? {
            Some(FrontendMessage::Password { password }) => Password(password),
            _ => {
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_AUTHORIZATION_SPECIFICATION,
                        "expected Password message",
                    ))
                    .await;
            }
        };
        let auth_response = match adapter_client.authenticate(&user, &password).await {
            Ok(resp) => resp,
            Err(err) => {
                warn!(?err, "pgwire connection failed authentication");
                return conn
                    .send(ErrorResponse::fatal(
                        SqlState::INVALID_PASSWORD,
                        "invalid password",
                    ))
                    .await;
            }
        };
        let session = adapter_client.new_session(SessionConfig {
            conn_id: conn.conn_id().clone(),
            uuid: conn_uuid,
            user,
            client_ip: conn.peer_addr().clone(),
            external_metadata_rx: None,
            // Record the superuser flag reported by the adapter's authentication.
            internal_user_metadata: Some(InternalUserMetadata {
                superuser: auth_response.superuser,
            }),
            helm_chart_version,
        });
        // Self-hosted sessions never expire.
        let auth_session = pending().right_future();
        (session, auth_session)
    } else {
        // No authentication: create the session for the requested user directly.
        let session = adapter_client.new_session(SessionConfig {
            conn_id: conn.conn_id().clone(),
            uuid: conn_uuid,
            user,
            client_ip: conn.peer_addr().clone(),
            external_metadata_rx: None,
            internal_user_metadata: None,
            helm_chart_version,
        });
        // Unauthenticated sessions never expire.
        let auth_session = pending().right_future();
        (session, auth_session)
    };

    // Apply the remaining startup parameters as session variables. The
    // `options` parameter can carry several settings. A setting that cannot be
    // parsed or applied produces a notice rather than failing the connection.
    let system_vars = adapter_client.get_system_vars().await;
    for (name, value) in params {
        let settings = match name.as_str() {
            "options" => match parse_options(&value) {
                Ok(opts) => opts,
                Err(()) => {
                    session.add_notice(AdapterNotice::BadStartupSetting {
                        name,
                        reason: "could not parse".into(),
                    });
                    continue;
                }
            },
            _ => vec![(name, value)],
        };
        for (key, val) in settings {
            // Set session-level (not transaction-local) values.
            const LOCAL: bool = false;
            if let Err(err) =
                session
                    .vars_mut()
                    .set(&system_vars, &key, VarInput::Flat(&val), LOCAL)
            {
                session.add_notice(AdapterNotice::BadStartupSetting {
                    name: key,
                    reason: err.to_string(),
                });
            }
        }
    }
    // Commit the variable changes made above.
    session
        .vars_mut()
        .end_transaction(EndTransactionAction::Commit);

    // Reserve a connection slot for this user. Binding to `_guard` (not `_`)
    // keeps the allocation alive until this function returns.
    let _guard = match active_connection_counter.allocate_connection(session.user()) {
        Ok(drop_connection) => drop_connection,
        Err(e) => {
            let e: AdapterError = e.into();
            return conn.send(e.into_response(Severity::Fatal)).await;
        }
    };

    // Register the session with the adapter.
    let mut adapter_client = match adapter_client.startup(session).await {
        Ok(adapter_client) => adapter_client,
        Err(e) => return conn.send(e.into_response(Severity::Fatal)).await,
    };

    // Startup reply: AuthenticationOk, ParameterStatus for each variable in the
    // notify set, BackendKeyData, any queued notices, then ReadyForQuery.
    let mut buf = vec![BackendMessage::AuthenticationOk];
    for var in adapter_client.session().vars().notify_set() {
        buf.push(BackendMessage::ParameterStatus(var.name(), var.value()));
    }
    buf.push(BackendMessage::BackendKeyData {
        conn_id: adapter_client.session().conn_id().unhandled(),
        secret_key: adapter_client.session().secret_key(),
    });
    buf.extend(
        adapter_client
            .session()
            .drain_notices()
            .into_iter()
            .map(|notice| BackendMessage::ErrorResponse(notice.into_response())),
    );
    buf.push(BackendMessage::ReadyForQuery(
        adapter_client.session().transaction().into(),
    ));
    conn.send_all(buf).await?;
    conn.flush().await?;

    let machine = StateMachine {
        conn,
        adapter_client,
        txn_needs_commit: false,
    };

    // Run the protocol until it finishes or authentication expires, whichever
    // comes first.
    select! {
        r = machine.run() => {
            // On error, make a best-effort attempt to tell the client before
            // returning it; failures while doing so are ignored.
            if let Err(err) = &r {
                let _ = conn
                    .send(ErrorResponse::fatal(
                        SqlState::CONNECTION_FAILURE,
                        err.to_string(),
                    ))
                    .await;
                let _ = conn.flush().await;
            }
            r
        },
        _ = expired => {
            conn
                .send(ErrorResponse::fatal(SqlState::INVALID_AUTHORIZATION_SPECIFICATION, "authentication expired"))
                .await?;
            conn.flush().await
        }
    }
}
379
380fn parse_options(value: &str) -> Result<Vec<(String, String)>, ()> {
385 let opts = split_options(value);
386 let mut pairs = Vec::with_capacity(opts.len());
387 let mut seen_prefix = false;
388 for opt in opts {
389 if !seen_prefix {
390 if opt == "-c" {
391 seen_prefix = true;
392 } else {
393 let (key, val) = parse_option(&opt)?;
394 pairs.push((key.to_owned(), val.to_owned()));
395 }
396 } else {
397 let (key, val) = opt.split_once('=').ok_or(())?;
398 pairs.push((key.to_owned(), val.to_owned()));
399 seen_prefix = false;
400 }
401 }
402 Ok(pairs)
403}
404
/// Parses a single `-cname=value` or `--name=value` word into `(name, value)`.
///
/// The word is split at its first `=`. The `-c` prefix is tried before `--`.
///
/// # Errors
///
/// Returns `Err(())` if the word has no `=`, or if the part before it starts
/// with neither prefix.
fn parse_option(option: &str) -> Result<(&str, &str), ()> {
    let (raw_key, value) = option.split_once('=').ok_or(())?;
    ["-c", "--"]
        .iter()
        .find_map(|prefix| raw_key.strip_prefix(*prefix))
        .map(|key| (key, value))
        .ok_or(())
}
417
/// Splits `value` into words separated by one or more space characters.
///
/// A backslash escapes the character that follows it, which is then kept
/// literally (so `\ ` keeps a space inside a word and `\\` yields one
/// backslash). A trailing unpaired backslash is dropped. Only U+0020 separates
/// words. Empty words are never produced.
fn split_options(value: &str) -> Vec<String> {
    let mut parts = Vec::new();
    // Words must be built up because escapes are removed, so they can't simply
    // be sub-slices of `value`.
    let mut token = String::new();
    let mut chars = value.chars();
    while let Some(ch) = chars.next() {
        match ch {
            '\\' => {
                // Keep the escaped character verbatim, if there is one.
                if let Some(escaped) = chars.next() {
                    token.push(escaped);
                }
            }
            ' ' => {
                // Runs of spaces separate words without producing empty ones.
                if !token.is_empty() {
                    parts.push(std::mem::take(&mut token));
                }
            }
            other => token.push(other),
        }
    }
    if !token.is_empty() {
        parts.push(token);
    }
    parts
}
460
/// The state of the [`StateMachine`] between client messages.
#[derive(Debug)]
enum State {
    /// Waiting for, and processing, the next message (see `advance_ready`).
    Ready,
    /// Discarding client messages until a `Sync` arrives (see `advance_drain`).
    Drain,
    /// The protocol loop should stop; `StateMachine::run` returns `Ok(())`.
    Done,
}
467
/// Drives the pgwire protocol for one connection whose session has already
/// been started with the adapter.
struct StateMachine<'a, A> {
    /// The connection to the client.
    conn: &'a mut FramedConn<A>,
    /// The adapter client for this connection's session.
    adapter_client: mz_adapter::SessionClient,
    /// Set when an Execute message ran in an implicit transaction. The
    /// transaction is then committed the next time `ensure_transaction` runs.
    /// Cleared by `end_transaction`.
    txn_needs_commit: bool,
}
473
/// How a `send_rows` call ended. `execute` maps this to a
/// `StatementEndedExecutionReason` for statement logging.
enum SendRowsEndedReason {
    /// Rows were sent without error.
    Success {
        /// Size of the result. NOTE(review): units not visible here — confirm.
        result_size: u64,
        /// Number of rows returned.
        rows_returned: u64,
    },
    /// Sending stopped because of an error, described by `error`.
    Errored {
        error: String,
    },
    /// Sending was canceled.
    Canceled,
}
484
/// Error message used when a statement other than a transaction-exit
/// statement is executed inside an aborted transaction.
const ABORTED_TXN_MSG: &str =
    "current transaction is aborted, commands ignored until end of transaction block";
487
488impl<'a, A> StateMachine<'a, A>
489where
490 A: AsyncRead + AsyncWrite + AsyncReady + Send + Sync + Unpin + 'a,
491{
492 #[allow(clippy::manual_async_fn)]
496 #[mz_ore::instrument(level = "debug")]
497 fn run(mut self) -> impl Future<Output = Result<(), io::Error>> + Send + 'a {
498 async move {
499 let mut state = State::Ready;
500 loop {
501 self.send_pending_notices().await?;
502 state = match state {
503 State::Ready => self.advance_ready().await?,
504 State::Drain => self.advance_drain().await?,
505 State::Done => return Ok(()),
506 };
507 self.adapter_client
508 .add_idle_in_transaction_session_timeout();
509 }
510 }
511 }
512
    /// Handles the next event in the `Ready` state and returns the next state.
    ///
    /// Waits for either a session timeout or a client message, preferring the
    /// timeout.
    ///
    /// - A timeout is reported to the client as a fatal error, and the session
    ///   is terminated.
    /// - A client message is otherwise dispatched to its handler.
    /// - Copy and password messages received in this state move the machine to
    ///   `State::Drain`.
    /// - `Terminate`, or a closed connection, moves it to `State::Done`.
    #[instrument(level = "debug")]
    async fn advance_ready(&mut self) -> Result<State, io::Error> {
        let message = select! {
            // Poll the branches in order, so a pending timeout wins over a
            // client message.
            biased;

            Some(timeout) = self.adapter_client.recv_timeout() => {
                let err: AdapterError = timeout.into();
                let conn_id = self.adapter_client.session().conn_id();
                tracing::warn!("session timed out, conn_id {}", conn_id);

                // Report the error before terminating the session.
                let error_response = err.into_response(Severity::Fatal);
                let error_state = self.error(error_response).await;

                self.adapter_client.terminate().await;

                // Wait for one more client message (or disconnect) before
                // returning. NOTE(review): presumably because the protocol only
                // lets the server send the error in reply to a client message;
                // confirm.
                let _ = self.conn.recv().await?;
                return error_state;
            },
            message = self.conn.recv() => message?,
        };

        // A message arrived, so the session is no longer idle.
        self.adapter_client
            .remove_idle_in_transaction_session_timeout();

        // Used as the `otel.name` of the root spans created below.
        let message_name = message.as_ref().map(|m| m.name()).unwrap_or_default();

        let next_state = match message {
            Some(FrontendMessage::Query { sql }) => {
                // Run the query under a new root span that follows from the
                // current span.
                let query_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                query_root_span.follows_from(tracing::Span::current());
                self.query(sql).instrument(query_root_span).await?
            }
            Some(FrontendMessage::Parse {
                name,
                sql,
                param_types,
            }) => self.parse(name, sql, param_types).await?,
            Some(FrontendMessage::Bind {
                portal_name,
                statement_name,
                param_formats,
                raw_params,
                result_formats,
            }) => {
                self.bind(
                    portal_name,
                    statement_name,
                    param_formats,
                    raw_params,
                    result_formats,
                )
                .await?
            }
            Some(FrontendMessage::Execute {
                portal_name,
                max_rows,
            }) => {
                let max_rows = match usize::try_from(max_rows) {
                    // Zero, or a value that doesn't fit in `usize` (such as a
                    // negative one), means no row limit.
                    Ok(0) | Err(_) => ExecuteCount::All,
                    Ok(n) => ExecuteCount::Count(n),
                };
                let execute_root_span =
                    tracing::info_span!(parent: None, "advance_ready", otel.name = message_name);
                execute_root_span.follows_from(tracing::Span::current());
                let state = self
                    .execute(
                        portal_name,
                        max_rows,
                        portal_exec_message,
                        None,
                        ExecuteTimeout::None,
                        None,
                    )
                    .instrument(execute_root_span)
                    .await?;
                // Don't commit an implicit transaction right away. Flag it so
                // that `ensure_transaction` commits it before the next statement
                // starts. NOTE(review): presumably deferred so a partially
                // fetched portal can still be resumed by a later Execute.
                if self.adapter_client.session().transaction().is_implicit() {
                    self.txn_needs_commit = true;
                }
                state
            }
            Some(FrontendMessage::DescribeStatement { name }) => {
                self.describe_statement(&name).await?
            }
            Some(FrontendMessage::DescribePortal { name }) => self.describe_portal(&name).await?,
            Some(FrontendMessage::CloseStatement { name }) => self.close_statement(name).await?,
            Some(FrontendMessage::ClosePortal { name }) => self.close_portal(name).await?,
            Some(FrontendMessage::Flush) => self.flush().await?,
            Some(FrontendMessage::Sync) => self.sync().await?,
            Some(FrontendMessage::Terminate) => State::Done,

            // These messages are not expected in the Ready state; discard input
            // until the next Sync.
            Some(FrontendMessage::CopyData(_))
            | Some(FrontendMessage::CopyDone)
            | Some(FrontendMessage::CopyFail(_))
            | Some(FrontendMessage::Password { .. }) => State::Drain,
            // The client closed the connection.
            None => State::Done,
        };

        Ok(next_state)
    }
637
638 async fn advance_drain(&mut self) -> Result<State, io::Error> {
639 let message = self.conn.recv().await?;
640 if message.is_some() {
641 self.adapter_client
642 .remove_idle_in_transaction_session_timeout();
643 }
644 match message {
645 Some(FrontendMessage::Sync) => self.sync().await,
646 None => Ok(State::Done),
647 _ => Ok(State::Drain),
648 }
649 }
650
    /// Runs one statement from a simple Query message through the unnamed
    /// portal, and returns the resulting state.
    ///
    /// The statement is:
    /// 1. declared as the unnamed portal;
    /// 2. rejected if it has parameters;
    /// 3. described with an all-text `RowDescription`, if it returns rows and
    ///    is not a COPY;
    /// 4. executed, with its response sent to the client.
    ///
    /// The unnamed portal is removed after execution. The early-return paths
    /// before execution do not remove it.
    #[instrument(level = "debug")]
    async fn one_query(&mut self, stmt: Statement<Raw>, sql: String) -> Result<State, io::Error> {
        const EMPTY_PORTAL: &str = "";
        if let Err(e) = self
            .adapter_client
            .declare(EMPTY_PORTAL.to_string(), stmt, sql)
            .await
        {
            return self.error(e.into_response(Severity::Error)).await;
        }

        let stmt_desc = self
            .adapter_client
            .session()
            .get_portal_unverified(EMPTY_PORTAL)
            .map(|portal| portal.desc.clone())
            .expect("unnamed portal should be present");
        // Simple queries carry no parameter values, so any parameter is undefined.
        if !stmt_desc.param_types.is_empty() {
            return self
                .error(ErrorResponse::error(
                    SqlState::UNDEFINED_PARAMETER,
                    "there is no parameter $1",
                ))
                .await;
        }

        // Describe the result rows (all in text format), except for COPY.
        if let Some(relation_desc) = &stmt_desc.relation_desc {
            if !stmt_desc.is_copy {
                let formats = vec![Format::Text; stmt_desc.arity()];
                self.send(BackendMessage::RowDescription(
                    message::encode_row_description(relation_desc, &formats),
                ))
                .await?;
            }
        }

        let result = match self
            .adapter_client
            .execute(EMPTY_PORTAL.to_string(), self.conn.wait_closed(), None)
            .await
        {
            Ok((response, execute_started)) => {
                self.send_pending_notices().await?;
                self.send_execute_response(
                    response,
                    stmt_desc.relation_desc,
                    EMPTY_PORTAL.to_string(),
                    ExecuteCount::All,
                    portal_exec_message,
                    None,
                    ExecuteTimeout::None,
                    execute_started,
                )
                .await
            }
            Err(e) => {
                self.send_pending_notices().await?;
                self.error(e.into_response(Severity::Error)).await
            }
        };

        self.adapter_client.session().remove_portal(EMPTY_PORTAL);

        result
    }
720
721 async fn ensure_transaction(&mut self, num_stmts: usize) -> Result<(), io::Error> {
722 if self.txn_needs_commit {
723 self.commit_transaction().await?;
724 }
725 let res = self.adapter_client.start_transaction(Some(num_stmts));
728 assert_ok!(res);
729 Ok(())
730 }
731
732 fn parse_sql<'b>(&self, sql: &'b str) -> Result<Vec<StatementParseResult<'b>>, ErrorResponse> {
733 match self.adapter_client.parse(sql) {
734 Ok(result) => result.map_err(|e| {
735 let pos = sql[..e.error.pos].chars().count() + 1;
738 ErrorResponse::error(SqlState::SYNTAX_ERROR, e.error.message).with_position(pos)
739 }),
740 Err(msg) => Err(ErrorResponse::error(SqlState::PROGRAM_LIMIT_EXCEEDED, msg)),
741 }
742 }
743
    /// Handles a simple Query message containing zero or more statements.
    ///
    /// - All statements are parsed up front. A parse error is reported, and
    ///   the connection returns to ready.
    /// - Each statement then runs via `one_query`, inside a transaction sized
    ///   for the whole message.
    /// - In an aborted transaction, the first statement that is not a
    ///   transaction-exit statement is rejected, and the rest of the message is
    ///   skipped.
    /// - An implicit transaction is committed at the end of the message.
    /// - An empty message yields `EmptyQueryResponse`.
    #[instrument(level = "debug")]
    async fn query(&mut self, sql: String) -> Result<State, io::Error> {
        // Parse before doing any transaction checks.
        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                self.error(err).await?;
                return self.ready().await;
            }
        };

        let num_stmts = stmts.len();

        for StatementParseResult { ast: stmt, sql } in stmts {
            // In an aborted transaction, only transaction-exit statements run.
            if self.is_aborted_txn() && !is_txn_exit_stmt(Some(&stmt)) {
                self.aborted_txn_error().await?;
                break;
            }

            // Called per statement rather than once. NOTE(review): presumably
            // because a COMMIT/ROLLBACK earlier in the message ends the
            // transaction and the next statement needs a new one.
            self.ensure_transaction(num_stmts).await?;

            match self.one_query(stmt, sql.to_string()).await? {
                State::Ready => (),
                State::Drain => break,
                State::Done => return Ok(State::Done),
            }
        }

        // Implicit transactions end with the Query message.
        {
            if self.adapter_client.session().transaction().is_implicit() {
                self.commit_transaction().await?;
            }
        }

        if num_stmts == 0 {
            self.send(BackendMessage::EmptyQueryResponse).await?;
        }

        self.ready().await
    }
797
    /// Handles a Parse message. Prepares statement `name` from `sql` using the
    /// given parameter type OIDs, and replies with `ParseComplete` on success.
    ///
    /// - An OID of 0 that is not a known type is treated as an unspecified
    ///   parameter type (`None`).
    /// - Unknown OIDs, types that cannot be converted to a `ScalarType`, parse
    ///   errors, and more than one statement in `sql` are reported via `error`.
    /// - An empty `sql` prepares an empty statement.
    #[instrument(level = "debug")]
    async fn parse(
        &mut self,
        name: String,
        sql: String,
        param_oids: Vec<u32>,
    ) -> Result<State, io::Error> {
        self.ensure_transaction(1).await?;

        let mut param_types = vec![];
        for oid in param_oids {
            match mz_pgrepr::Type::from_oid(oid) {
                Ok(ty) => match ScalarType::try_from(&ty) {
                    Ok(ty) => param_types.push(Some(ty)),
                    Err(err) => {
                        return self
                            .error(ErrorResponse::error(
                                SqlState::INVALID_PARAMETER_VALUE,
                                err.to_string(),
                            ))
                            .await;
                    }
                },
                // OID 0 means "unspecified": let the planner infer the type.
                Err(_) if oid == 0 => param_types.push(None),
                Err(e) => {
                    return self
                        .error(ErrorResponse::error(
                            SqlState::PROTOCOL_VIOLATION,
                            e.to_string(),
                        ))
                        .await;
                }
            }
        }

        let stmts = match self.parse_sql(&sql) {
            Ok(stmts) => stmts,
            Err(err) => {
                return self.error(err).await;
            }
        };
        // NOTE(review): PostgreSQL reports this condition with SQLSTATE 42601
        // (syntax_error); confirm that INTERNAL_ERROR is intended here.
        if stmts.len() > 1 {
            return self
                .error(ErrorResponse::error(
                    SqlState::INTERNAL_ERROR,
                    "cannot insert multiple commands into a prepared statement",
                ))
                .await;
        }
        let (maybe_stmt, sql) = match stmts.into_iter().next() {
            None => (None, ""),
            Some(StatementParseResult { ast, sql }) => (Some(ast), sql),
        };
        if self.is_aborted_txn() && !is_txn_exit_stmt(maybe_stmt.as_ref()) {
            return self.aborted_txn_error().await;
        }
        match self
            .adapter_client
            .prepare(name, maybe_stmt, sql.to_string(), param_types)
            .await
        {
            Ok(()) => {
                self.send(BackendMessage::ParseComplete).await?;
                Ok(State::Ready)
            }
            Err(e) => self.error(e.into_response(Severity::Error)).await,
        }
    }
867
    /// Commits the current transaction; see `end_transaction`.
    #[instrument(level = "debug")]
    async fn commit_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Commit).await
    }
873
    /// Rolls back the current transaction; see `end_transaction`.
    #[instrument(level = "debug")]
    async fn rollback_transaction(&mut self) -> Result<(), io::Error> {
        self.end_transaction(EndTransactionAction::Rollback).await
    }
879
880 #[instrument(level = "debug")]
882 async fn end_transaction(&mut self, action: EndTransactionAction) -> Result<(), io::Error> {
883 self.txn_needs_commit = false;
884 let resp = self.adapter_client.end_transaction(action).await;
885 if let Err(err) = resp {
886 self.send(BackendMessage::ErrorResponse(
887 err.into_response(Severity::Error),
888 ))
889 .await?;
890 }
891 Ok(())
892 }
893
    /// Handles a Bind message. Creates portal `portal_name` from prepared
    /// statement `statement_name`, using the supplied parameters and result
    /// formats.
    ///
    /// Validation, in order:
    /// 1. The prepared statement exists.
    /// 2. The parameter count matches.
    /// 3. The parameter formats are valid.
    /// 4. The transaction is not aborted, unless the statement exits it.
    /// 5. Each parameter decodes.
    /// 6. The result formats are valid.
    /// 7. No binary result format is requested for list, map, or aclitem
    ///    columns (except for `COPY ... TO`).
    ///
    /// Any failure is reported via `error`. On success, `BindComplete` is sent.
    #[instrument(level = "debug")]
    async fn bind(
        &mut self,
        portal_name: String,
        statement_name: String,
        param_formats: Vec<Format>,
        raw_params: Vec<Option<Vec<u8>>>,
        result_formats: Vec<Format>,
    ) -> Result<State, io::Error> {
        self.ensure_transaction(1).await?;

        // Capture the aborted state now. It is checked only after the parameter
        // count and formats have been validated.
        let aborted_txn = self.is_aborted_txn();
        let stmt = match self
            .adapter_client
            .get_prepared_statement(&statement_name)
            .await
        {
            Ok(stmt) => stmt,
            Err(err) => return self.error(err.into_response(Severity::Error)).await,
        };

        let param_types = &stmt.desc().param_types;
        if param_types.len() != raw_params.len() {
            let message = format!(
                "bind message supplies {actual} parameters, \
                 but prepared statement \"{name}\" requires {expected}",
                name = statement_name,
                actual = raw_params.len(),
                expected = param_types.len()
            );
            return self
                .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, message))
                .await;
        }
        let param_formats = match pad_formats(param_formats, raw_params.len()) {
            Ok(param_formats) => param_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };
        if aborted_txn && !is_txn_exit_stmt(stmt.stmt()) {
            return self.aborted_txn_error().await;
        }
        // Arena backing the decoded parameter datums.
        let buf = RowArena::new();
        let mut params = vec![];
        for (raw_param, mz_typ, format) in izip!(raw_params, param_types, param_formats) {
            let pg_typ = mz_pgrepr::Type::from(mz_typ);
            // A missing value is SQL NULL. Otherwise, decode the bytes using
            // the parameter's format and type.
            let datum = match raw_param {
                None => Datum::Null,
                Some(bytes) => match mz_pgrepr::Value::decode(format, &pg_typ, &bytes) {
                    Ok(param) => param.into_datum(&buf, &pg_typ),
                    Err(err) => {
                        let msg = format!("unable to decode parameter: {}", err);
                        return self
                            .error(ErrorResponse::error(SqlState::INVALID_PARAMETER_VALUE, msg))
                            .await;
                    }
                },
            };
            params.push((datum, mz_typ.clone()))
        }

        // One result format per output column (zero if the statement returns
        // no rows).
        let result_formats = match pad_formats(
            result_formats,
            stmt.desc()
                .relation_desc
                .clone()
                .map(|desc| desc.typ().column_types.len())
                .unwrap_or(0),
        ) {
            Ok(result_formats) => result_formats,
            Err(msg) => {
                return self
                    .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
                    .await;
            }
        };

        // Reject binary output for list, map, and aclitem columns. `COPY ... TO`
        // is exempt. NOTE(review): presumably because COPY output is not
        // encoded using these result formats; confirm.
        if !stmt.stmt().map_or(false, |stmt| {
            matches!(
                stmt,
                Statement::Copy(CopyStatement {
                    direction: CopyDirection::To,
                    ..
                })
            )
        }) {
            if let Some(desc) = stmt.desc().relation_desc.clone() {
                for (format, ty) in result_formats.iter().zip(desc.iter_types()) {
                    match (format, &ty.scalar_type) {
                        (Format::Binary, mz_repr::ScalarType::List { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of list types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::ScalarType::Map { .. }) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of map types is not implemented",
                                ))
                                .await;
                        }
                        (Format::Binary, mz_repr::ScalarType::AclItem) => {
                            return self
                                .error(ErrorResponse::error(
                                    SqlState::PROTOCOL_VIOLATION,
                                    "binary encoding of aclitem types does not exist",
                                ))
                                .await;
                        }
                        _ => (),
                    }
                }
            }
        }

        // Copy out what the portal needs from `stmt` before borrowing the
        // session mutably.
        let desc = stmt.desc().clone();
        let revision = stmt.catalog_revision;
        let logging = Arc::clone(stmt.logging());
        let stmt = stmt.stmt().cloned();
        if let Err(err) = self.adapter_client.session().set_portal(
            portal_name,
            desc,
            stmt,
            logging,
            params,
            result_formats,
            revision,
        ) {
            return self.error(err.into_response(Severity::Error)).await;
        }

        self.send(BackendMessage::BindComplete).await?;
        Ok(State::Ready)
    }
1038
    /// Runs or resumes portal `portal_name`, sending at most `max_rows` rows.
    ///
    /// Behavior depends on the portal's state:
    /// - `NotStarted`: opens a transaction if needed, executes the portal
    ///   through the adapter, and sends the response.
    /// - `InProgress`: sends more rows from the in-progress result.
    /// - `Completed(Some(tag))`: replays `tag` as `CommandComplete`.
    /// - `Completed(None)`: reports that the portal cannot be run.
    ///
    /// A missing portal, or a non-exit statement in an aborted transaction, is
    /// reported as an error. If `outer_ctx_extra` is provided, it is retired
    /// with the outcome on every path except `NotStarted`, where it is handed
    /// to the adapter's `execute` instead.
    ///
    /// Returns a boxed future rather than being an `async fn`. NOTE(review):
    /// presumably because it is reached recursively (e.g. via `fetch`); confirm.
    fn execute(
        &mut self,
        portal_name: String,
        max_rows: ExecuteCount,
        get_response: GetResponse,
        fetch_portal_name: Option<String>,
        timeout: ExecuteTimeout,
        outer_ctx_extra: Option<ExecuteContextExtra>,
    ) -> BoxFuture<'_, Result<State, io::Error>> {
        async move {
            let aborted_txn = self.is_aborted_txn();

            let portal = match self
                .adapter_client
                .session()
                .get_portal_unverified_mut(&portal_name)
            {
                Some(portal) => portal,
                None => {
                    let msg = format!("portal {} does not exist", portal_name.quoted());
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored { error: msg.clone() },
                        );
                    }
                    return self
                        .error(ErrorResponse::error(SqlState::INVALID_CURSOR_NAME, msg))
                        .await;
                }
            };

            // In an aborted transaction, only transaction-exit statements run.
            let txn_exit_stmt = is_txn_exit_stmt(portal.stmt.as_deref());
            if aborted_txn && !txn_exit_stmt {
                if let Some(outer_ctx_extra) = outer_ctx_extra {
                    self.adapter_client.retire_execute(
                        outer_ctx_extra,
                        StatementEndedExecutionReason::Errored {
                            error: ABORTED_TXN_MSG.to_string(),
                        },
                    );
                }
                return self.aborted_txn_error().await;
            }

            let row_desc = portal.desc.relation_desc.clone();
            match &mut portal.state {
                PortalState::NotStarted => {
                    self.ensure_transaction(1).await?;
                    match self
                        .adapter_client
                        .execute(
                            portal_name.clone(),
                            self.conn.wait_closed(),
                            outer_ctx_extra,
                        )
                        .await
                    {
                        Ok((response, execute_started)) => {
                            self.send_pending_notices().await?;
                            self.send_execute_response(
                                response,
                                row_desc,
                                portal_name,
                                max_rows,
                                get_response,
                                fetch_portal_name,
                                timeout,
                                execute_started,
                            )
                            .await
                        }
                        Err(e) => {
                            self.send_pending_notices().await?;
                            self.error(e.into_response(Severity::Error)).await
                        }
                    }
                }
                PortalState::InProgress(rows) => {
                    let rows = rows.take().expect("InProgress rows must be populated");
                    let (result, statement_ended_execution_reason) = match self
                        .send_rows(
                            row_desc.expect("portal missing row desc on resumption"),
                            portal_name,
                            rows,
                            max_rows,
                            get_response,
                            fetch_portal_name,
                            timeout,
                        )
                        .await
                    {
                        // An I/O error talking to the client is logged as a
                        // cancellation rather than a statement error.
                        Err(e) => {
                            (Err(e), StatementEndedExecutionReason::Canceled)
                        }
                        Ok((ok, SendRowsEndedReason::Canceled)) => {
                            (Ok(ok), StatementEndedExecutionReason::Canceled)
                        }
                        // The result size and row count are not recorded for a
                        // resumed portal.
                        Ok((
                            ok,
                            SendRowsEndedReason::Success {
                                result_size: _,
                                rows_returned: _,
                            },
                        )) => (
                            Ok(ok),
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        ),
                        Ok((ok, SendRowsEndedReason::Errored { error })) => {
                            (Ok(ok), StatementEndedExecutionReason::Errored { error })
                        }
                    };
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client
                            .retire_execute(outer_ctx_extra, statement_ended_execution_reason);
                    }
                    result
                }
                // The portal already finished; replay its stored command tag.
                PortalState::Completed(Some(tag)) => {
                    let tag = tag.to_string();
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Success {
                                result_size: None,
                                rows_returned: None,
                                execution_strategy: None,
                            },
                        );
                    }
                    self.send(BackendMessage::CommandComplete { tag }).await?;
                    Ok(State::Ready)
                }
                PortalState::Completed(None) => {
                    let error = format!(
                        "portal {} cannot be run",
                        Ident::new_unchecked(portal_name).to_ast_string_stable()
                    );
                    if let Some(outer_ctx_extra) = outer_ctx_extra {
                        self.adapter_client.retire_execute(
                            outer_ctx_extra,
                            StatementEndedExecutionReason::Errored {
                                error: error.clone(),
                            },
                        );
                    }
                    self.error(ErrorResponse::error(
                        SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE,
                        error,
                    ))
                    .await
                }
            }
        }
        .instrument(debug_span!("execute"))
        .boxed()
    }
1223
    /// Handles a Describe message for prepared statement `name`.
    ///
    /// Replies with a `ParameterDescription` followed by the row description.
    /// A missing statement is reported via `error`.
    #[instrument(level = "debug")]
    async fn describe_statement(&mut self, name: &str) -> Result<State, io::Error> {
        self.ensure_transaction(1).await?;

        let stmt = match self.adapter_client.get_prepared_statement(name).await {
            Ok(stmt) => stmt,
            Err(err) => return self.error(err.into_response(Severity::Error)).await,
        };
        // Build the messages from `stmt` before calling `send_all`, which needs
        // `self` mutably.
        let parameter_desc = BackendMessage::ParameterDescription(
            stmt.desc()
                .param_types
                .iter()
                .map(mz_pgrepr::Type::from)
                .collect(),
        );
        // Describe every column as text. The actual result formats are only
        // chosen later, in a Bind message.
        let formats = vec![Format::Text; stmt.desc().arity()];
        let row_desc = describe_rows(stmt.desc(), &formats);
        self.send_all([parameter_desc, row_desc]).await?;
        Ok(State::Ready)
    }
1249
1250 #[instrument(level = "debug")]
1251 async fn describe_portal(&mut self, name: &str) -> Result<State, io::Error> {
1252 self.ensure_transaction(1).await?;
1254
1255 let session = self.adapter_client.session();
1256 let row_desc = session
1257 .get_portal_unverified(name)
1258 .map(|portal| describe_rows(&portal.desc, &portal.result_formats));
1259 match row_desc {
1260 Some(row_desc) => {
1261 self.send(row_desc).await?;
1262 Ok(State::Ready)
1263 }
1264 None => {
1265 self.error(ErrorResponse::error(
1266 SqlState::INVALID_CURSOR_NAME,
1267 format!("portal {} does not exist", name.quoted()),
1268 ))
1269 .await
1270 }
1271 }
1272 }
1273
1274 #[instrument(level = "debug")]
1275 async fn close_statement(&mut self, name: String) -> Result<State, io::Error> {
1276 self.adapter_client
1277 .session()
1278 .remove_prepared_statement(&name);
1279 self.send(BackendMessage::CloseComplete).await?;
1280 Ok(State::Ready)
1281 }
1282
1283 #[instrument(level = "debug")]
1284 async fn close_portal(&mut self, name: String) -> Result<State, io::Error> {
1285 self.adapter_client.session().remove_portal(&name);
1286 self.send(BackendMessage::CloseComplete).await?;
1287 Ok(State::Ready)
1288 }
1289
1290 fn complete_portal(&mut self, name: &str) {
1291 let portal = self
1292 .adapter_client
1293 .session()
1294 .get_portal_unverified_mut(name)
1295 .expect("portal should exist");
1296 portal.state = PortalState::Completed(None);
1297 }
1298
1299 async fn fetch(
1300 &mut self,
1301 name: String,
1302 count: Option<FetchDirection>,
1303 max_rows: ExecuteCount,
1304 fetch_portal_name: Option<String>,
1305 timeout: ExecuteTimeout,
1306 ctx_extra: ExecuteContextExtra,
1307 ) -> Result<State, io::Error> {
1308 let count = count.unwrap_or(FetchDirection::ForwardCount(1));
1311
1312 let count = match (max_rows, count) {
1325 (ExecuteCount::Count(max_rows), FetchDirection::ForwardCount(count)) => {
1326 let count = usize::cast_from(count);
1327 if max_rows < count {
1328 let msg = "Execute with max_rows < a FETCH's count is not supported";
1329 self.adapter_client.retire_execute(
1330 ctx_extra,
1331 StatementEndedExecutionReason::Errored {
1332 error: msg.to_string(),
1333 },
1334 );
1335 return self
1336 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1337 .await;
1338 }
1339 ExecuteCount::Count(count)
1340 }
1341 (ExecuteCount::Count(_), FetchDirection::ForwardAll) => {
1342 let msg = "Execute with max_rows of a FETCH ALL is not supported";
1343 self.adapter_client.retire_execute(
1344 ctx_extra,
1345 StatementEndedExecutionReason::Errored {
1346 error: msg.to_string(),
1347 },
1348 );
1349 return self
1350 .error(ErrorResponse::error(SqlState::FEATURE_NOT_SUPPORTED, msg))
1351 .await;
1352 }
1353 (ExecuteCount::All, FetchDirection::ForwardAll) => ExecuteCount::All,
1354 (ExecuteCount::All, FetchDirection::ForwardCount(count)) => {
1355 ExecuteCount::Count(usize::cast_from(count))
1356 }
1357 };
1358 let cursor_name = name.to_string();
1359 self.execute(
1360 cursor_name,
1361 count,
1362 fetch_message,
1363 fetch_portal_name,
1364 timeout,
1365 Some(ctx_extra),
1366 )
1367 .await
1368 }
1369
1370 async fn flush(&mut self) -> Result<State, io::Error> {
1371 self.conn.flush().await?;
1372 Ok(State::Ready)
1373 }
1374
1375 #[instrument(level = "debug")]
1380 async fn send<M>(&mut self, message: M) -> Result<(), io::Error>
1381 where
1382 M: Into<BackendMessage>,
1383 {
1384 let message: BackendMessage = message.into();
1385 let is_error =
1386 matches!(&message, BackendMessage::ErrorResponse(e) if e.severity.is_error());
1387
1388 self.conn.send(message).await?;
1389
1390 if is_error {
1397 self.conn.flush().await?;
1398 }
1399
1400 Ok(())
1401 }
1402
1403 #[instrument(level = "debug")]
1404 pub async fn send_all(
1405 &mut self,
1406 messages: impl IntoIterator<Item = BackendMessage>,
1407 ) -> Result<(), io::Error> {
1408 for m in messages {
1409 self.send(m).await?;
1410 }
1411 Ok(())
1412 }
1413
1414 #[instrument(level = "debug")]
1415 async fn sync(&mut self) -> Result<State, io::Error> {
1416 if self.adapter_client.session().transaction().is_implicit() {
1418 self.commit_transaction().await?;
1419 }
1420 self.ready().await
1421 }
1422
1423 #[instrument(level = "debug")]
1424 async fn ready(&mut self) -> Result<State, io::Error> {
1425 let txn_state = self.adapter_client.session().transaction().into();
1426 self.send(BackendMessage::ReadyForQuery(txn_state)).await?;
1427 self.flush().await
1428 }
1429
/// Waits for `rows` to resolve and returns the result as a single-item channel.
///
/// While waiting, session notices are forwarded to the client and flushed so
/// they are not held back behind a slow query. If the client connection closes
/// first, the error from `wait_closed` is returned. The work runs inside a
/// `row_future_to_stream` span whose parent is `parent`.
#[instrument(level = "debug")]
async fn row_future_to_stream<'s, 'p>(
    &'s mut self,
    parent: &'p tracing::Span,
    mut rows: RowsFuture,
) -> Result<UnboundedReceiver<PeekResponseUnary>, io::Error>
where
    'p: 's,
{
    let span = tracing::debug_span!(parent: parent, "row_future_to_stream");
    async move {
        loop {
            tokio::select! {
                err = self.conn.wait_closed() => return Err(err),
                rows = &mut rows => {
                    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
                    // Cannot fail: `rx` is still alive in this scope.
                    tx.send(rows).expect("send must succeed");
                    return Ok(rx);
                }
                notice = self.adapter_client.session().recv_notice() => {
                    self.send(notice.into_response())
                        .await?;
                    self.conn.flush().await?;
                }
            }
        }
    }
    .instrument(span)
    .await
}
1463
1464 #[allow(clippy::too_many_arguments)]
1465 #[instrument(level = "debug")]
1466 async fn send_execute_response(
1467 &mut self,
1468 response: ExecuteResponse,
1469 row_desc: Option<RelationDesc>,
1470 portal_name: String,
1471 max_rows: ExecuteCount,
1472 get_response: GetResponse,
1473 fetch_portal_name: Option<String>,
1474 timeout: ExecuteTimeout,
1475 execute_started: Instant,
1476 ) -> Result<State, io::Error> {
1477 let mut tag = response.tag();
1478
1479 macro_rules! command_complete {
1480 () => {{
1481 self.send(BackendMessage::CommandComplete {
1482 tag: tag
1483 .take()
1484 .expect("command_complete only called on tag-generating results"),
1485 })
1486 .await?;
1487 Ok(State::Ready)
1488 }};
1489 }
1490
1491 let r = match response {
1492 ExecuteResponse::ClosedCursor => {
1493 self.complete_portal(&portal_name);
1494 command_complete!()
1495 }
1496 ExecuteResponse::DeclaredCursor => {
1497 self.complete_portal(&portal_name);
1498 command_complete!()
1499 }
1500 ExecuteResponse::EmptyQuery => {
1501 self.send(BackendMessage::EmptyQueryResponse).await?;
1502 Ok(State::Ready)
1503 }
1504 ExecuteResponse::Fetch {
1505 name,
1506 count,
1507 timeout,
1508 ctx_extra,
1509 } => {
1510 self.fetch(
1511 name,
1512 count,
1513 max_rows,
1514 Some(portal_name.to_string()),
1515 timeout,
1516 ctx_extra,
1517 )
1518 .await
1519 }
1520 ExecuteResponse::SendingRows {
1521 future: rx,
1522 instance_id,
1523 strategy,
1524 } => {
1525 let row_desc =
1526 row_desc.expect("missing row description for ExecuteResponse::SendingRows");
1527
1528 let span = tracing::debug_span!("sending_rows");
1529 let rows = self.row_future_to_stream(&span, rx).await?;
1530
1531 self.send_rows(
1532 row_desc,
1533 portal_name,
1534 InProgressRows::new(RecordFirstRowStream::new(
1535 Box::new(UnboundedReceiverStream::new(rows)),
1536 execute_started,
1537 &self.adapter_client,
1538 Some(instance_id),
1539 Some(strategy),
1540 )),
1541 max_rows,
1542 get_response,
1543 fetch_portal_name,
1544 timeout,
1545 )
1546 .instrument(span)
1547 .await
1548 .map(|(state, _)| state)
1549 }
1550 ExecuteResponse::SendingRowsImmediate { rows } => {
1551 let row_desc = row_desc
1552 .expect("missing row description for ExecuteResponse::SendingRowsImmediate");
1553
1554 let span = tracing::debug_span!("sending_rows_immediate");
1555
1556 let stream =
1557 futures::stream::once(futures::future::ready(PeekResponseUnary::Rows(rows)));
1558 self.send_rows(
1559 row_desc,
1560 portal_name,
1561 InProgressRows::new(RecordFirstRowStream::new(
1562 Box::new(stream),
1563 execute_started,
1564 &self.adapter_client,
1565 None,
1566 Some(StatementExecutionStrategy::Constant),
1567 )),
1568 max_rows,
1569 get_response,
1570 fetch_portal_name,
1571 timeout,
1572 )
1573 .instrument(span)
1574 .await
1575 .map(|(state, _)| state)
1576 }
1577 ExecuteResponse::SetVariable { name, .. } => {
1578 let qn = name.to_string();
1581 let msg = if let Some(var) = self
1582 .adapter_client
1583 .session()
1584 .vars_mut()
1585 .notify_set()
1586 .find(|v| v.name() == qn)
1587 {
1588 Some(BackendMessage::ParameterStatus(var.name(), var.value()))
1589 } else {
1590 None
1591 };
1592 if let Some(msg) = msg {
1593 self.send(msg).await?;
1594 }
1595 command_complete!()
1596 }
1597 ExecuteResponse::Subscribing {
1598 rx,
1599 ctx_extra,
1600 instance_id,
1601 } => {
1602 if fetch_portal_name.is_none() {
1603 let mut msg = ErrorResponse::notice(
1604 SqlState::WARNING,
1605 "streaming SUBSCRIBE rows directly requires a client that does not buffer output",
1606 );
1607 if self.adapter_client.session().vars().application_name() == "psql" {
1608 msg.hint = Some(
1609 "Wrap your SUBSCRIBE statement in `COPY (SUBSCRIBE ...) TO STDOUT`."
1610 .into(),
1611 )
1612 }
1613 self.send(msg).await?;
1614 self.conn.flush().await?;
1615 }
1616 let row_desc =
1617 row_desc.expect("missing row description for ExecuteResponse::Subscribing");
1618 let (result, statement_ended_execution_reason) = match self
1619 .send_rows(
1620 row_desc,
1621 portal_name,
1622 InProgressRows::new(RecordFirstRowStream::new(
1623 Box::new(UnboundedReceiverStream::new(rx)),
1624 execute_started,
1625 &self.adapter_client,
1626 Some(instance_id),
1627 None,
1628 )),
1629 max_rows,
1630 get_response,
1631 fetch_portal_name,
1632 timeout,
1633 )
1634 .await
1635 {
1636 Err(e) => {
1637 (Err(e), StatementEndedExecutionReason::Canceled)
1640 }
1641 Ok((ok, SendRowsEndedReason::Canceled)) => {
1642 (Ok(ok), StatementEndedExecutionReason::Canceled)
1643 }
1644 Ok((
1645 ok,
1646 SendRowsEndedReason::Success {
1647 result_size,
1648 rows_returned,
1649 },
1650 )) => (
1651 Ok(ok),
1652 StatementEndedExecutionReason::Success {
1653 result_size: Some(result_size),
1654 rows_returned: Some(rows_returned),
1655 execution_strategy: None,
1656 },
1657 ),
1658 Ok((ok, SendRowsEndedReason::Errored { error })) => {
1659 (Ok(ok), StatementEndedExecutionReason::Errored { error })
1660 }
1661 };
1662 self.adapter_client
1663 .retire_execute(ctx_extra, statement_ended_execution_reason);
1664 return result;
1665 }
1666 ExecuteResponse::CopyTo { format, resp } => {
1667 let row_desc =
1668 row_desc.expect("missing row description for ExecuteResponse::CopyTo");
1669 match *resp {
1670 ExecuteResponse::Subscribing {
1671 rx,
1672 ctx_extra,
1673 instance_id,
1674 } => {
1675 let (result, statement_ended_execution_reason) = match self
1676 .copy_rows(
1677 format,
1678 row_desc,
1679 RecordFirstRowStream::new(
1680 Box::new(UnboundedReceiverStream::new(rx)),
1681 execute_started,
1682 &self.adapter_client,
1683 Some(instance_id),
1684 None,
1685 ),
1686 )
1687 .await
1688 {
1689 Err(e) => {
1690 (Err(e), StatementEndedExecutionReason::Canceled)
1693 }
1694 Ok((
1695 state,
1696 SendRowsEndedReason::Success {
1697 result_size,
1698 rows_returned,
1699 },
1700 )) => (
1701 Ok(state),
1702 StatementEndedExecutionReason::Success {
1703 result_size: Some(result_size),
1704 rows_returned: Some(rows_returned),
1705 execution_strategy: None,
1706 },
1707 ),
1708 Ok((state, SendRowsEndedReason::Errored { error })) => {
1709 (Ok(state), StatementEndedExecutionReason::Errored { error })
1710 }
1711 Ok((state, SendRowsEndedReason::Canceled)) => {
1712 (Ok(state), StatementEndedExecutionReason::Canceled)
1713 }
1714 };
1715 self.adapter_client
1716 .retire_execute(ctx_extra, statement_ended_execution_reason);
1717 return result;
1718 }
1719 ExecuteResponse::SendingRows {
1720 future: rows_rx,
1721 instance_id,
1722 strategy,
1723 } => {
1724 let span = tracing::debug_span!("sending_rows");
1725 let rows = self.row_future_to_stream(&span, rows_rx).await?;
1726 return self
1731 .copy_rows(
1732 format,
1733 row_desc,
1734 RecordFirstRowStream::new(
1735 Box::new(UnboundedReceiverStream::new(rows)),
1736 execute_started,
1737 &self.adapter_client,
1738 Some(instance_id),
1739 Some(strategy),
1740 ),
1741 )
1742 .await
1743 .map(|(state, _)| state);
1744 }
1745 ExecuteResponse::SendingRowsImmediate { rows } => {
1746 let span = tracing::debug_span!("sending_rows_immediate");
1747
1748 let rows = futures::stream::once(futures::future::ready(
1749 PeekResponseUnary::Rows(rows),
1750 ));
1751 return self
1756 .copy_rows(
1757 format,
1758 row_desc,
1759 RecordFirstRowStream::new(
1760 Box::new(rows),
1761 execute_started,
1762 &self.adapter_client,
1763 None,
1764 Some(StatementExecutionStrategy::Constant),
1765 ),
1766 )
1767 .instrument(span)
1768 .await
1769 .map(|(state, _)| state);
1770 }
1771 _ => {
1772 return self
1773 .error(ErrorResponse::error(
1774 SqlState::INTERNAL_ERROR,
1775 "unsupported COPY response type".to_string(),
1776 ))
1777 .await;
1778 }
1779 };
1780 }
1781 ExecuteResponse::CopyFrom {
1782 id,
1783 columns,
1784 params,
1785 ctx_extra,
1786 } => {
1787 let row_desc =
1788 row_desc.expect("missing row description for ExecuteResponse::CopyFrom");
1789 self.copy_from(id, columns, params, row_desc, ctx_extra)
1790 .await
1791 }
1792 ExecuteResponse::TransactionCommitted { params }
1793 | ExecuteResponse::TransactionRolledBack { params } => {
1794 let notify_set: mz_ore::collections::HashSet<String> = self
1795 .adapter_client
1796 .session()
1797 .vars()
1798 .notify_set()
1799 .map(|v| v.name().to_string())
1800 .collect();
1801
1802 for (name, value) in params
1804 .into_iter()
1805 .filter(|(name, _v)| notify_set.contains(*name))
1806 {
1807 let msg = BackendMessage::ParameterStatus(name, value);
1808 self.send(msg).await?;
1809 }
1810 command_complete!()
1811 }
1812
1813 ExecuteResponse::AlteredDefaultPrivileges
1814 | ExecuteResponse::AlteredObject(..)
1815 | ExecuteResponse::AlteredRole
1816 | ExecuteResponse::AlteredSystemConfiguration
1817 | ExecuteResponse::CreatedCluster { .. }
1818 | ExecuteResponse::CreatedClusterReplica { .. }
1819 | ExecuteResponse::CreatedConnection { .. }
1820 | ExecuteResponse::CreatedDatabase { .. }
1821 | ExecuteResponse::CreatedIndex { .. }
1822 | ExecuteResponse::CreatedIntrospectionSubscribe
1823 | ExecuteResponse::CreatedMaterializedView { .. }
1824 | ExecuteResponse::CreatedContinualTask { .. }
1825 | ExecuteResponse::CreatedRole
1826 | ExecuteResponse::CreatedSchema { .. }
1827 | ExecuteResponse::CreatedSecret { .. }
1828 | ExecuteResponse::CreatedSink { .. }
1829 | ExecuteResponse::CreatedSource { .. }
1830 | ExecuteResponse::CreatedTable { .. }
1831 | ExecuteResponse::CreatedType
1832 | ExecuteResponse::CreatedView { .. }
1833 | ExecuteResponse::CreatedViews { .. }
1834 | ExecuteResponse::CreatedNetworkPolicy
1835 | ExecuteResponse::Comment
1836 | ExecuteResponse::Deallocate { .. }
1837 | ExecuteResponse::Deleted(..)
1838 | ExecuteResponse::DiscardedAll
1839 | ExecuteResponse::DiscardedTemp
1840 | ExecuteResponse::DroppedObject(_)
1841 | ExecuteResponse::DroppedOwned
1842 | ExecuteResponse::GrantedPrivilege
1843 | ExecuteResponse::GrantedRole
1844 | ExecuteResponse::Inserted(..)
1845 | ExecuteResponse::Copied(..)
1846 | ExecuteResponse::Prepare
1847 | ExecuteResponse::Raised
1848 | ExecuteResponse::ReassignOwned
1849 | ExecuteResponse::RevokedPrivilege
1850 | ExecuteResponse::RevokedRole
1851 | ExecuteResponse::StartedTransaction { .. }
1852 | ExecuteResponse::Updated(..)
1853 | ExecuteResponse::ValidatedConnection => {
1854 command_complete!()
1855 }
1856 };
1857
1858 assert_none!(tag, "tag created but not consumed: {:?}", tag);
1859 r
1860 }
1861
/// Streams rows from `rows` to the client as `DataRow` messages.
///
/// - Sends at most `max_rows` rows and stops early when `timeout` elapses.
/// - Forwards notices that arrive meanwhile.
/// - Result formats come from the portal `fetch_portal_name` when given,
///   otherwise from `portal_name`.
/// - On normal completion, the (possibly partially consumed) row stream is
///   stored back in `portal_name` as `PortalState::InProgress`. The message
///   built by `get_response` is then sent.
/// - Data errors, query errors, and cancellation are reported to the client
///   via `Self::error`. They return `Ok` with the corresponding
///   `SendRowsEndedReason`.
/// - `Err` is returned only for connection I/O failures.
#[allow(clippy::too_many_arguments)]
#[mz_ore::instrument(level = "debug")]
async fn send_rows(
    &mut self,
    row_desc: RelationDesc,
    portal_name: String,
    mut rows: InProgressRows,
    max_rows: ExecuteCount,
    get_response: GetResponse,
    fetch_portal_name: Option<String>,
    timeout: ExecuteTimeout,
) -> Result<(State, SendRowsEndedReason), io::Error> {
    // A FETCH's result formats live on the portal that issued the FETCH.
    let result_format_portal_name: &str = if let Some(ref name) = fetch_portal_name {
        name
    } else {
        &portal_name
    };
    let result_formats = self
        .adapter_client
        .session()
        .get_portal_unverified(result_format_portal_name)
        .expect("valid fetch portal name for send rows")
        .result_formats
        .clone();

    // `deadline` bounds how long we wait for more batches. `WaitOnce` starts
    // with no deadline and sets one once a non-empty batch arrives (see below).
    let (mut wait_once, mut deadline) = match timeout {
        ExecuteTimeout::None => (false, None),
        ExecuteTimeout::Seconds(t) => (
            false,
            Some(tokio::time::Instant::now() + tokio::time::Duration::from_secs_f64(t)),
        ),
        ExecuteTimeout::WaitOnce => (true, None),
    };

    // Configure per-column encoding (pg type paired with result format) for DataRow messages.
    self.conn.set_encode_state(
        row_desc
            .typ()
            .column_types
            .iter()
            .map(|ty| mz_pgrepr::Type::from(&ty.scalar_type))
            .zip(result_formats)
            .collect(),
    );

    let mut total_sent_rows = 0;
    let mut total_sent_bytes = 0;
    // Remaining row budget; `ExecuteCount::All` is treated as unbounded.
    let mut want_rows = match max_rows {
        ExecuteCount::All => usize::MAX,
        ExecuteCount::Count(count) => count,
    };

    loop {
        // A leftover batch stored from an earlier limited execution is consumed
        // before pulling anything new. An exhausted budget ends the loop.
        let batch = if rows.current.is_some() {
            FetchResult::Rows(rows.current.take())
        } else if want_rows == 0 {
            FetchResult::Rows(None)
        } else {
            let notice_fut = self.adapter_client.session().recv_notice();
            tokio::select! {
                err = self.conn.wait_closed() => return Err(err),
                _ = time::sleep_until(deadline.unwrap_or_else(tokio::time::Instant::now)), if deadline.is_some() => FetchResult::Rows(None),
                notice = notice_fut => {
                    FetchResult::Notice(notice)
                }
                batch = rows.remaining.recv() => match batch {
                    None => FetchResult::Rows(None),
                    Some(PeekResponseUnary::Rows(rows)) => FetchResult::Rows(Some(rows)),
                    Some(PeekResponseUnary::Error(err)) => FetchResult::Error(err),
                    Some(PeekResponseUnary::Canceled) => FetchResult::Canceled,
                },
            }
        };

        match batch {
            // Stream exhausted, budget used up, or deadline reached.
            FetchResult::Rows(None) => break,
            FetchResult::Rows(Some(mut batch_rows)) => {
                if let Err(err) = verify_datum_desc(&row_desc, &mut batch_rows) {
                    let msg = err.to_string();
                    return self
                        .error(err.into_response(Severity::Error))
                        .await
                        .map(|state| (state, SendRowsEndedReason::Errored { error: msg }));
                }

                // `WaitOnce`: after the first non-empty batch, set the deadline
                // to "now" so later waits for more rows end immediately.
                if wait_once && batch_rows.peek().is_some() {
                    deadline = Some(tokio::time::Instant::now());
                    wait_once = false;
                }

                // Encode and send up to `want_rows` rows, counting rows and
                // row bytes as they are pulled through the iterator.
                let mut sent_rows = 0;
                let mut sent_bytes = 0;
                let messages = (&mut batch_rows)
                    .map(|row| {
                        let row_len = row.byte_len();
                        let values = mz_pgrepr::values_from_row(row, row_desc.typ());
                        (row_len, BackendMessage::DataRow(values))
                    })
                    .inspect(|(row_len, _)| {
                        sent_bytes += row_len;
                        sent_rows += 1
                    })
                    .map(|(_row_len, row)| row)
                    .take(want_rows);
                self.send_all(messages).await?;

                total_sent_rows += sent_rows;
                total_sent_bytes += sent_bytes;
                want_rows -= sent_rows;

                // Budget exhausted: keep any unsent rows of this batch so they
                // are stored with the portal below.
                if want_rows == 0 {
                    if batch_rows.peek().is_some() {
                        rows.current = Some(batch_rows);
                    }
                    break;
                }

                self.conn.flush().await?;
            }
            FetchResult::Notice(notice) => {
                self.send(notice.into_response()).await?;
                self.conn.flush().await?;
            }
            FetchResult::Error(text) => {
                return self
                    .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
            }
            FetchResult::Canceled => {
                return self
                    .error(ErrorResponse::error(
                        SqlState::QUERY_CANCELED,
                        "canceling statement due to user request",
                    ))
                    .await
                    .map(|state| (state, SendRowsEndedReason::Canceled));
            }
        }
    }

    let portal = self
        .adapter_client
        .session()
        .get_portal_unverified_mut(&portal_name)
        .expect("valid portal name for send rows");

    // The row stream is stored back in the portal whether or not it has been
    // exhausted.
    portal.state = PortalState::InProgress(Some(rows));

    let fetch_portal = fetch_portal_name.map(|name| {
        self.adapter_client
            .session()
            .get_portal_unverified_mut(&name)
            .expect("valid fetch portal")
    });
    let response_message = get_response(max_rows, total_sent_rows, fetch_portal);
    self.send(response_message).await?;
    Ok((
        State::Ready,
        SendRowsEndedReason::Success {
            result_size: u64::cast_from(total_sent_bytes),
            rows_returned: u64::cast_from(total_sent_rows),
        },
    ))
}
2045
2046 #[mz_ore::instrument(level = "debug")]
2047 async fn copy_rows(
2048 &mut self,
2049 format: CopyFormat,
2050 row_desc: RelationDesc,
2051 mut stream: RecordFirstRowStream,
2052 ) -> Result<(State, SendRowsEndedReason), io::Error> {
2053 let (row_format, encode_format) = match format {
2054 CopyFormat::Text => (
2055 CopyFormatParams::Text(CopyTextFormatParams::default()),
2056 Format::Text,
2057 ),
2058 CopyFormat::Binary => (CopyFormatParams::Binary, Format::Binary),
2059 CopyFormat::Csv => (
2060 CopyFormatParams::Csv(CopyCsvFormatParams::default()),
2061 Format::Text,
2062 ),
2063 CopyFormat::Parquet => {
2064 let text = "Parquet format is not supported".to_string();
2065 return self
2066 .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2067 .await
2068 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2069 }
2070 };
2071
2072 let encode_fn = |row: &RowRef, typ: &RelationType, out: &mut Vec<u8>| {
2073 mz_pgcopy::encode_copy_format(&row_format, row, typ, out)
2074 };
2075
2076 let typ = row_desc.typ();
2077 let column_formats = iter::repeat(encode_format)
2078 .take(typ.column_types.len())
2079 .collect();
2080 self.send(BackendMessage::CopyOutResponse {
2081 overall_format: encode_format,
2082 column_formats,
2083 })
2084 .await?;
2085
2086 let mut out = Vec::new();
2091
2092 if let CopyFormat::Binary = format {
2093 out.extend(b"PGCOPY\n\xFF\r\n\0");
2095 out.extend([0, 0, 0, 0]);
2097 out.extend([0, 0, 0, 0]);
2099 }
2100
2101 let mut count = 0;
2102 let mut total_sent_bytes = 0;
2103 loop {
2104 tokio::select! {
2105 e = self.conn.wait_closed() => return Err(e),
2106 batch = stream.recv() => match batch {
2107 None => break,
2108 Some(PeekResponseUnary::Error(text)) => {
2109 return self
2110 .error(ErrorResponse::error(SqlState::INTERNAL_ERROR, text.clone()))
2111 .await
2112 .map(|state| (state, SendRowsEndedReason::Errored { error: text }));
2113 }
2114 Some(PeekResponseUnary::Canceled) => {
2115 return self.error(ErrorResponse::error(
2116 SqlState::QUERY_CANCELED,
2117 "canceling statement due to user request",
2118 ))
2119 .await.map(|state| (state, SendRowsEndedReason::Canceled));
2120 }
2121 Some(PeekResponseUnary::Rows(mut rows)) => {
2122 count += rows.count();
2123 while let Some(row) = rows.next() {
2124 total_sent_bytes += row.byte_len();
2125 encode_fn(row, typ, &mut out)?;
2126 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2127 .await?;
2128 }
2129 }
2130 },
2131 notice = self.adapter_client.session().recv_notice() => {
2132 self.send(notice.into_response())
2133 .await?;
2134 self.conn.flush().await?;
2135 }
2136 }
2137
2138 self.conn.flush().await?;
2139 }
2140 if let CopyFormat::Binary = format {
2142 let trailer: i16 = -1;
2143 out.extend(trailer.to_be_bytes());
2144 self.send(BackendMessage::CopyData(mem::take(&mut out)))
2145 .await?;
2146 }
2147
2148 let tag = format!("COPY {}", count);
2149 self.send(BackendMessage::CopyDone).await?;
2150 self.send(BackendMessage::CommandComplete { tag }).await?;
2151 Ok((
2152 State::Ready,
2153 SendRowsEndedReason::Success {
2154 result_size: u64::cast_from(total_sent_bytes),
2155 rows_returned: u64::cast_from(count),
2156 },
2157 ))
2158 }
2159
2160 #[instrument(level = "debug")]
2163 async fn copy_from(
2164 &mut self,
2165 id: CatalogItemId,
2166 columns: Vec<ColumnIndex>,
2167 params: CopyFormatParams<'_>,
2168 row_desc: RelationDesc,
2169 mut ctx_extra: ExecuteContextExtra,
2170 ) -> Result<State, io::Error> {
2171 let res = self
2172 .copy_from_inner(id, columns, params, row_desc, &mut ctx_extra)
2173 .await;
2174 match &res {
2175 Ok(State::Done) => {
2176 self.adapter_client
2180 .retire_execute(ctx_extra, StatementEndedExecutionReason::Canceled);
2181 }
2182 Err(e) => {
2183 self.adapter_client.retire_execute(
2184 ctx_extra,
2185 StatementEndedExecutionReason::Errored {
2186 error: format!("{e}"),
2187 },
2188 );
2189 }
2190 other => {
2191 tracing::warn!(?other, "aborting COPY FROM");
2192 self.adapter_client
2193 .retire_execute(ctx_extra, StatementEndedExecutionReason::Aborted);
2194 }
2195 }
2196 res
2197 }
2198
2199 async fn copy_from_inner(
2200 &mut self,
2201 id: CatalogItemId,
2202 columns: Vec<ColumnIndex>,
2203 params: CopyFormatParams<'_>,
2204 row_desc: RelationDesc,
2205 ctx_extra: &mut ExecuteContextExtra,
2206 ) -> Result<State, io::Error> {
2207 let typ = row_desc.typ();
2208 let column_formats = vec![Format::Text; typ.column_types.len()];
2209 self.send(BackendMessage::CopyInResponse {
2210 overall_format: Format::Text,
2211 column_formats,
2212 })
2213 .await?;
2214 self.conn.flush().await?;
2215
2216 let system_vars = self.adapter_client.get_system_vars().await;
2217 let max_size = system_vars
2218 .get(MAX_COPY_FROM_SIZE.name())
2219 .ok()
2220 .and_then(|max_size| max_size.value().parse().ok())
2221 .unwrap_or(usize::MAX);
2222 tracing::debug!("COPY FROM max buffer size: {max_size} bytes");
2223
2224 let mut data = Vec::new();
2225 loop {
2226 let message = self.conn.recv().await?;
2227 match message {
2228 Some(FrontendMessage::CopyData(buf)) => {
2229 if (data.len() + buf.len()) > max_size {
2231 return self
2232 .error(ErrorResponse::error(
2233 SqlState::INSUFFICIENT_RESOURCES,
2234 "COPY FROM STDIN too large",
2235 ))
2236 .await;
2237 }
2238 data.extend(buf)
2239 }
2240 Some(FrontendMessage::CopyDone) => break,
2241 Some(FrontendMessage::CopyFail(err)) => {
2242 self.adapter_client.retire_execute(
2243 std::mem::take(ctx_extra),
2244 StatementEndedExecutionReason::Canceled,
2245 );
2246 return self
2247 .error(ErrorResponse::error(
2248 SqlState::QUERY_CANCELED,
2249 format!("COPY from stdin failed: {}", err),
2250 ))
2251 .await;
2252 }
2253 Some(FrontendMessage::Flush) | Some(FrontendMessage::Sync) => {}
2254 Some(_) => {
2255 let msg = "unexpected message type during COPY from stdin";
2256 self.adapter_client.retire_execute(
2257 std::mem::take(ctx_extra),
2258 StatementEndedExecutionReason::Errored {
2259 error: msg.to_string(),
2260 },
2261 );
2262 return self
2263 .error(ErrorResponse::error(SqlState::PROTOCOL_VIOLATION, msg))
2264 .await;
2265 }
2266 None => {
2267 return Ok(State::Done);
2268 }
2269 }
2270 }
2271
2272 let column_types = typ
2273 .column_types
2274 .iter()
2275 .map(|x| &x.scalar_type)
2276 .map(mz_pgrepr::Type::from)
2277 .collect::<Vec<mz_pgrepr::Type>>();
2278
2279 let rows = match mz_pgcopy::decode_copy_format(&data, &column_types, params) {
2280 Ok(rows) => rows,
2281 Err(e) => {
2282 self.adapter_client.retire_execute(
2283 std::mem::take(ctx_extra),
2284 StatementEndedExecutionReason::Errored {
2285 error: e.to_string(),
2286 },
2287 );
2288 return self
2289 .error(ErrorResponse::error(
2290 SqlState::BAD_COPY_FILE_FORMAT,
2291 format!("{}", e),
2292 ))
2293 .await;
2294 }
2295 };
2296
2297 let count = rows.len();
2298
2299 if let Err(e) = self
2300 .adapter_client
2301 .insert_rows(id, columns, rows, std::mem::take(ctx_extra))
2302 .await
2303 {
2304 self.adapter_client.retire_execute(
2305 std::mem::take(ctx_extra),
2306 StatementEndedExecutionReason::Errored {
2307 error: e.to_string(),
2308 },
2309 );
2310 return self.error(e.into_response(Severity::Error)).await;
2311 }
2312
2313 let tag = format!("COPY {}", count);
2314 self.send(BackendMessage::CommandComplete { tag }).await?;
2315
2316 Ok(State::Ready)
2317 }
2318
2319 #[instrument(level = "debug")]
2320 async fn send_pending_notices(&mut self) -> Result<(), io::Error> {
2321 let notices = self
2322 .adapter_client
2323 .session()
2324 .drain_notices()
2325 .into_iter()
2326 .map(|notice| BackendMessage::ErrorResponse(notice.into_response()));
2327 self.send_all(notices).await?;
2328 Ok(())
2329 }
2330
2331 #[instrument(level = "debug")]
2332 async fn error(&mut self, err: ErrorResponse) -> Result<State, io::Error> {
2333 assert!(err.severity.is_error());
2334 debug!(
2335 "cid={} error code={}",
2336 self.adapter_client.session().conn_id(),
2337 err.code.code()
2338 );
2339 let is_fatal = err.severity.is_fatal();
2340 self.send(BackendMessage::ErrorResponse(err)).await?;
2341
2342 let txn = self.adapter_client.session().transaction();
2343 match txn {
2344 TransactionStatus::Default | TransactionStatus::Failed(_) => {}
2347 TransactionStatus::Started(_) => {
2349 self.rollback_transaction().await?;
2350 }
2351 TransactionStatus::InTransactionImplicit(_) => {
2353 self.rollback_transaction().await?;
2354 }
2355 TransactionStatus::InTransaction(_) => {
2357 self.adapter_client.fail_transaction();
2358 }
2359 };
2360 if is_fatal {
2361 Ok(State::Done)
2362 } else {
2363 Ok(State::Drain)
2364 }
2365 }
2366
2367 #[instrument(level = "debug")]
2368 async fn aborted_txn_error(&mut self) -> Result<State, io::Error> {
2369 self.send(BackendMessage::ErrorResponse(ErrorResponse::error(
2370 SqlState::IN_FAILED_SQL_TRANSACTION,
2371 ABORTED_TXN_MSG,
2372 )))
2373 .await?;
2374 Ok(State::Drain)
2375 }
2376
2377 fn is_aborted_txn(&mut self) -> bool {
2378 matches!(
2379 self.adapter_client.session().transaction(),
2380 TransactionStatus::Failed(_)
2381 )
2382 }
2383}
2384
2385fn pad_formats(formats: Vec<Format>, n: usize) -> Result<Vec<Format>, String> {
2386 match (formats.len(), n) {
2387 (0, e) => Ok(vec![Format::Text; e]),
2388 (1, e) => Ok(iter::repeat(formats[0]).take(e).collect()),
2389 (a, e) if a == e => Ok(formats),
2390 (a, e) => Err(format!(
2391 "expected {} field format specifiers, but got {}",
2392 e, a
2393 )),
2394 }
2395}
2396
2397fn describe_rows(stmt_desc: &StatementDesc, formats: &[Format]) -> BackendMessage {
2398 match &stmt_desc.relation_desc {
2399 Some(desc) if !stmt_desc.is_copy => {
2400 BackendMessage::RowDescription(message::encode_row_description(desc, formats))
2401 }
2402 _ => BackendMessage::NoData,
2403 }
2404}
2405
/// Builds the message that concludes a batch of rows sent by `send_rows`.
///
/// It receives three arguments:
/// - the execution's row limit;
/// - the number of rows actually sent;
/// - for a `FETCH`, the portal through which the `FETCH` was issued (`None`
///   otherwise).
type GetResponse = fn(
    max_rows: ExecuteCount,
    total_sent_rows: usize,
    fetch_portal: Option<&mut Portal>,
) -> BackendMessage;
2411
2412fn portal_exec_message(
2415 max_rows: ExecuteCount,
2416 total_sent_rows: usize,
2417 _fetch_portal: Option<&mut Portal>,
2418) -> BackendMessage {
2419 match max_rows {
2426 ExecuteCount::Count(max_rows) if max_rows <= total_sent_rows => {
2427 BackendMessage::PortalSuspended
2428 }
2429 _ => BackendMessage::CommandComplete {
2430 tag: format!("SELECT {}", total_sent_rows),
2431 },
2432 }
2433}
2434
2435fn fetch_message(
2437 _max_rows: ExecuteCount,
2438 total_sent_rows: usize,
2439 fetch_portal: Option<&mut Portal>,
2440) -> BackendMessage {
2441 let tag = format!("FETCH {}", total_sent_rows);
2442 if let Some(portal) = fetch_portal {
2443 portal.state = PortalState::Completed(Some(tag.clone()));
2444 }
2445 BackendMessage::CommandComplete { tag }
2446}
2447
/// How many rows an execution may return.
#[derive(Debug, Copy, Clone)]
enum ExecuteCount {
    /// No limit: send every row (`send_rows` treats this as `usize::MAX`).
    All,
    /// Send at most this many rows.
    Count(usize),
}
2453
2454fn is_txn_exit_stmt(stmt: Option<&Statement<Raw>>) -> bool {
2456 match stmt {
2457 Some(stmt) => matches!(stmt, Statement::Commit(_) | Statement::Rollback(_)),
2459 None => false,
2460 }
2461}
2462
/// The outcome of one wait for progress in `send_rows`.
#[derive(Debug)]
enum FetchResult {
    /// A batch of rows. `None` means stop sending for now: the stream is
    /// exhausted, the row budget is used up, or the deadline passed.
    Rows(Option<Box<dyn RowIterator + Send + Sync>>),
    /// The query was canceled.
    Canceled,
    /// The query failed with this message.
    Error(String),
    /// A notice to forward to the client before continuing.
    Notice(AdapterNotice),
}
2470
#[cfg(test)]
mod test {
    use super::*;

    /// `parse_options` accepts `--key=val` and `-c key=val` forms, with
    /// backslash-escaped spaces, and rejects options lacking a value.
    #[mz_ore::test]
    fn test_parse_options() {
        struct TestCase {
            input: &'static str,
            expect: Result<Vec<(&'static str, &'static str)>, ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Ok(vec![]),
            },
            TestCase {
                input: "--key",
                expect: Err(()),
            },
            TestCase {
                input: "--key=val",
                expect: Ok(vec![("key", "val")]),
            },
            TestCase {
                input: r#"--key=val -ckey2=val2 -c key3=val3 -c key4=val4 -ckey5=val5"#,
                expect: Ok(vec![
                    ("key", "val"),
                    ("key2", "val2"),
                    ("key3", "val3"),
                    ("key4", "val4"),
                    ("key5", "val5"),
                ]),
            },
            TestCase {
                input: r#"-c\ key=val"#,
                expect: Ok(vec![(" key", "val")]),
            },
            TestCase {
                input: "--key=val -ckey2 val2",
                expect: Err(()),
            },
            // An empty value is allowed.
            TestCase {
                input: "--key=",
                expect: Ok(vec![("key", "")]),
            },
        ];
        for test in tests {
            let got = parse_options(test.input);
            let expect = test.expect.map(|r| {
                r.into_iter()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect()
            });
            assert_eq!(got, expect, "input: {}", test.input);
        }
    }

    /// `parse_option` parses a single `--key=val` or `-ckey=val` token.
    #[mz_ore::test]
    fn test_parse_option() {
        struct TestCase {
            input: &'static str,
            expect: Result<(&'static str, &'static str), ()>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: Err(()),
            },
            TestCase {
                input: "--",
                expect: Err(()),
            },
            TestCase {
                input: "--c",
                expect: Err(()),
            },
            TestCase {
                input: "a=b",
                expect: Err(()),
            },
            TestCase {
                input: "--a=b",
                expect: Ok(("a", "b")),
            },
            TestCase {
                input: "--ca=b",
                expect: Ok(("ca", "b")),
            },
            TestCase {
                input: "-ca=b",
                expect: Ok(("a", "b")),
            },
            // Empty key and value are both allowed.
            TestCase {
                input: "--=",
                expect: Ok(("", "")),
            },
        ];
        for test in tests {
            let got = parse_option(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }

    /// `split_options` splits on unescaped whitespace. A backslash escapes the
    /// next character (including a space or another backslash), and a
    /// trailing lone backslash is dropped.
    #[mz_ore::test]
    fn test_split_options() {
        struct TestCase {
            input: &'static str,
            expect: Vec<&'static str>,
        }
        let tests = vec![
            TestCase {
                input: "",
                expect: vec![],
            },
            TestCase {
                input: " ",
                expect: vec![],
            },
            TestCase {
                input: " a ",
                expect: vec!["a"],
            },
            TestCase {
                input: " ab cd ",
                expect: vec!["ab", "cd"],
            },
            // Escaped space ends the token only because an unescaped space follows.
            TestCase {
                input: r#" ab\  cd "#,
                expect: vec!["ab ", "cd"],
            },
            TestCase {
                input: r#" ab\\ cd "#,
                expect: vec![r#"ab\"#, "cd"],
            },
            // Escaped backslash, escaped space, then an unescaped separator.
            TestCase {
                input: r#" ab\\\  cd "#,
                expect: vec![r#"ab\ "#, "cd"],
            },
            // Escaped backslash and escaped space with no separator: one token.
            TestCase {
                input: r#" ab\\\ cd "#,
                expect: vec![r#"ab\ cd"#],
            },
            TestCase {
                input: r#" ab\\\cd "#,
                expect: vec![r#"ab\cd"#],
            },
            TestCase {
                input: r#"a\"#,
                expect: vec!["a"],
            },
            TestCase {
                input: r#"a\ "#,
                expect: vec!["a "],
            },
            TestCase {
                input: r#"\"#,
                expect: vec![],
            },
            TestCase {
                input: r#"\ "#,
                expect: vec![r#" "#],
            },
            TestCase {
                input: r#" \ "#,
                expect: vec![r#" "#],
            },
            // Escaped space followed by a trailing unescaped space.
            TestCase {
                input: r#"\  "#,
                expect: vec![r#" "#],
            },
        ];
        for test in tests {
            let got = split_options(test.input);
            assert_eq!(got, test.expect, "input: {}", test.input);
        }
    }
}