mz_pgwire_common/message.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11
12use tokio_postgres::error::SqlState;
13
14use crate::{Format, Severity};
15
// A pgwire protocol version is a 32-bit integer whose upper 16 bits hold the
// major version and whose lower 16 bits hold the minor version.
//
// Only three protocol versions have ever been released: v1.0, v2.0, and v3.0.
// The protocol rarely changes; the newest of these, v3.0, shipped with
// Postgres v7.4 back in 2003.
//
// Somewhat unfortunately, the protocol also reuses the version field to flag
// special kinds of connections, namely SSL connections and cancellation
// connections. These pseudo-versions were picked so that they can never be
// mistaken for a real protocol version.

/// Protocol version 1.0 (major 1, minor 0).
pub const VERSION_1: i32 = 1 << 16;
/// Protocol version 2.0 (major 2, minor 0).
pub const VERSION_2: i32 = 2 << 16;
/// Protocol version 3.0 (major 3, minor 0).
pub const VERSION_3: i32 = 3 << 16;
/// Pseudo-version that identifies a cancellation request.
pub const VERSION_CANCEL: i32 = (1234 << 16) | 5678;
/// Pseudo-version that identifies an SSL encryption request.
pub const VERSION_SSL: i32 = (1234 << 16) | 5679;
/// Pseudo-version that identifies a GSSAPI encryption request.
pub const VERSION_GSSENC: i32 = (1234 << 16) | 5680;

/// Every version value listed above: the three real protocol versions
/// followed by the three pseudo-versions.
pub const VERSIONS: &[i32] = &[
    VERSION_1,
    VERSION_2,
    VERSION_3,
    VERSION_CANCEL,
    VERSION_SSL,
    VERSION_GSSENC,
];
44
/// The subset of [`FrontendMessage`]-style messages that a client may send
/// while startup protocol negotiation is still in progress.
#[derive(Debug)]
pub enum FrontendStartupMessage {
    /// Opens a connection with the given protocol `version` and key/value
    /// startup `params`.
    Startup {
        version: i32,
        params: BTreeMap<String, String>,
    },

    /// Asks for the connection to be encrypted with SSL.
    SslRequest,

    /// Asks for the connection to be encrypted with GSSAPI.
    GssEncRequest,

    /// Asks for a query running on a different connection to be cancelled.
    CancelRequest {
        /// Identifies the connection whose query should be cancelled.
        conn_id: u32,
        /// The secret key belonging to that target connection.
        secret_key: u32,
    },
}
69
70/// A decoded frontend pgwire [message], representing instructions for the
71/// backend.
72///
73/// [message]: https://www.postgresql.org/docs/11/protocol-message-formats.html
74#[derive(Debug)]
75pub enum FrontendMessage {
76 /// Execute the specified SQL.
77 ///
78 /// This is issued as part of the simple query flow.
79 Query {
80 /// The SQL to execute.
81 sql: String,
82 },
83
84 /// Parse the specified SQL into a prepared statement.
85 ///
86 /// This starts the extended query flow.
87 Parse {
88 /// The name of the prepared statement to create. An empty string
89 /// specifies the unnamed prepared statement.
90 name: String,
91 /// The SQL to parse.
92 sql: String,
93 /// The OID of each parameter data type for which the client wants to
94 /// prespecify types. A zero OID is equivalent to leaving the type
95 /// unspecified.
96 ///
97 /// The number of specified parameter data types can be less than the
98 /// number of parameters specified in the query.
99 param_types: Vec<u32>,
100 },
101
102 /// Describe an existing prepared statement.
103 ///
104 /// This command is part of the extended query flow.
105 DescribeStatement {
106 /// The name of the prepared statement to describe.
107 name: String,
108 },
109
110 /// Describe an existing portal.
111 ///
112 /// This command is part of the extended query flow.
113 DescribePortal {
114 /// The name of the portal to describe.
115 name: String,
116 },
117
118 /// Bind an existing prepared statement to a portal.
119 ///
120 /// This command is part of the extended query flow.
121 Bind {
122 /// The destination portal. An empty string selects the unnamed
123 /// portal. The portal can later be executed with the `Execute` command.
124 portal_name: String,
125 /// The source prepared statement. An empty string selects the unnamed
126 /// prepared statement.
127 statement_name: String,
128 /// The formats used to encode the parameters in `raw_parameters`.
129 param_formats: Vec<Format>,
130 /// The value of each parameter, encoded using the formats described
131 /// by `parameter_formats`.
132 raw_params: Vec<Option<Vec<u8>>>,
133 /// The desired formats for the columns in the result set.
134 result_formats: Vec<Format>,
135 },
136
137 /// Execute a bound portal.
138 ///
139 /// This command is part of the extended query flow.
140 Execute {
141 /// The name of the portal to execute.
142 portal_name: String,
143 /// The maximum number number of rows to return before suspending.
144 ///
145 /// 0 or negative means infinite.
146 max_rows: i32,
147 },
148
149 /// Flush any pending output.
150 ///
151 /// This command is part of the extended query flow.
152 Flush,
153
154 /// Finish an extended query.
155 ///
156 /// This command is part of the extended query flow.
157 Sync,
158
159 /// Close the named statement.
160 ///
161 /// This command is part of the extended query flow.
162 CloseStatement {
163 name: String,
164 },
165
166 /// Close the named portal.
167 ///
168 // This command is part of the extended query flow.
169 ClosePortal {
170 name: String,
171 },
172
173 /// Terminate a connection.
174 Terminate,
175
176 CopyData(Vec<u8>),
177
178 CopyDone,
179
180 CopyFail(String),
181
182 RawAuthentication(Vec<u8>),
183
184 Password {
185 password: String,
186 },
187
188 SASLInitialResponse {
189 gs2_header: GS2Header,
190 mechanism: String,
191 initial_response: SASLInitialResponse,
192 },
193
194 SASLResponse(SASLClientFinalResponse),
195}
196
/// The channel-binding flag carried in a SASL GS2 header, describing the
/// client's stance on channel binding.
#[derive(Debug, Clone)]
pub enum ChannelBinding {
    /// The client has no channel-binding support.
    None,
    /// The client can do channel binding but believes the server cannot.
    ClientSupported,
    /// The client insists on channel binding.
    ///
    /// NOTE(review): the payload presumably names the requested
    /// channel-binding type (e.g. `tls-server-end-point`) — confirm against
    /// the parser.
    Required(String),
}
206
207#[derive(Debug, Clone)]
208pub struct GS2Header {
209 pub cbind_flag: ChannelBinding,
210 pub authzid: Option<String>,
211}
212
213impl GS2Header {
214 pub fn channel_binding_enabled(&self) -> bool {
215 matches!(self.cbind_flag, ChannelBinding::Required(_))
216 }
217}
218
219#[derive(Debug)]
220pub struct SASLInitialResponse {
221 pub gs2_header: GS2Header,
222 pub nonce: String,
223 pub extensions: Vec<String>,
224 pub reserved_mext: Option<String>,
225 pub client_first_message_bare_raw: String,
226}
227
/// The parsed client-final-message of a SASL exchange.
#[derive(Debug)]
pub struct SASLClientFinalResponse {
    /// The channel-binding (`c=`) attribute value sent by the client.
    pub channel_binding: String,
    /// The nonce (`r=`) attribute value sent by the client.
    pub nonce: String,
    /// Any extension attributes sent by the client.
    pub extensions: Vec<String>,
    /// The client proof (`p=`) attribute value.
    pub proof: String,
    /// The raw text of the message without its proof.
    ///
    /// NOTE(review): presumably the RFC 5802 client-final-message-without-proof,
    /// retained for computing the SCRAM `AuthMessage` — confirm at the use site.
    pub client_final_message_bare_raw: String,
}
236
237impl FrontendMessage {
238 pub fn name(&self) -> &'static str {
239 match self {
240 FrontendMessage::Query { .. } => "query",
241 FrontendMessage::Parse { .. } => "parse",
242 FrontendMessage::DescribeStatement { .. } => "describe_statement",
243 FrontendMessage::DescribePortal { .. } => "describe_portal",
244 FrontendMessage::Bind { .. } => "bind",
245 FrontendMessage::Execute { .. } => "execute",
246 FrontendMessage::Flush => "flush",
247 FrontendMessage::Sync => "sync",
248 FrontendMessage::CloseStatement { .. } => "close_statement",
249 FrontendMessage::ClosePortal { .. } => "close_portal",
250 FrontendMessage::Terminate => "terminate",
251 FrontendMessage::CopyData(_) => "copy_data",
252 FrontendMessage::CopyDone => "copy_done",
253 FrontendMessage::CopyFail(_) => "copy_fail",
254 FrontendMessage::RawAuthentication(_) => "raw_authentication",
255 FrontendMessage::Password { .. } => "password",
256 FrontendMessage::SASLInitialResponse { .. } => "sasl_initial_response",
257 FrontendMessage::SASLResponse(_) => "sasl_response",
258 }
259 }
260}
261
262#[derive(Debug)]
263pub struct ErrorResponse {
264 pub severity: Severity,
265 pub code: SqlState,
266 pub message: String,
267 pub detail: Option<String>,
268 pub hint: Option<String>,
269 pub position: Option<usize>,
270}
271
272impl ErrorResponse {
273 pub fn fatal<S>(code: SqlState, message: S) -> ErrorResponse
274 where
275 S: Into<String>,
276 {
277 ErrorResponse::new(Severity::Fatal, code, message)
278 }
279
280 pub fn error<S>(code: SqlState, message: S) -> ErrorResponse
281 where
282 S: Into<String>,
283 {
284 ErrorResponse::new(Severity::Error, code, message)
285 }
286
287 pub fn notice<S>(code: SqlState, message: S) -> ErrorResponse
288 where
289 S: Into<String>,
290 {
291 ErrorResponse::new(Severity::Notice, code, message)
292 }
293
294 fn new<S>(severity: Severity, code: SqlState, message: S) -> ErrorResponse
295 where
296 S: Into<String>,
297 {
298 ErrorResponse {
299 severity,
300 code,
301 message: message.into(),
302 detail: None,
303 hint: None,
304 position: None,
305 }
306 }
307
308 pub fn with_position(mut self, position: usize) -> ErrorResponse {
309 self.position = Some(position);
310 self
311 }
312}