1use anyhow::anyhow;
27use axum::Json;
28use axum::response::IntoResponse;
29use http::StatusCode;
30use mz_adapter_types::dyncfgs::{ENABLE_MCP_AGENTS, ENABLE_MCP_OBSERVATORY};
31use mz_sql::parse::parse;
32use mz_sql::session::metadata::SessionMetadata;
33use mz_sql_parser::ast::display::escaped_string_literal;
34use mz_sql_parser::ast::visit::{self, Visit};
35use mz_sql_parser::ast::{Raw, RawItemName};
36use serde::{Deserialize, Serialize};
37use serde_json::json;
38use thiserror::Error;
39use tracing::{debug, warn};
40
41use crate::http::AuthedClient;
42use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
43
/// SQL run by the `get_data_products` tool: lists every row of
/// `mz_internal.mz_mcp_data_products`.
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
46
47#[derive(Debug, Error)]
49enum McpRequestError {
50 #[error("Invalid JSON-RPC version: expected 2.0")]
51 InvalidJsonRpcVersion,
52 #[error("Method not found: {0}")]
53 #[allow(dead_code)] MethodNotFound(String),
55 #[error("Tool not found: {0}")]
56 ToolNotFound(String),
57 #[error("Data product not found: {0}")]
58 DataProductNotFound(String),
59 #[error("Query validation failed: {0}")]
60 QueryValidationFailed(String),
61 #[error("Query execution failed: {0}")]
62 QueryExecutionFailed(String),
63 #[error("Internal error: {0}")]
64 Internal(#[from] anyhow::Error),
65}
66
67impl McpRequestError {
68 fn error_code(&self) -> i32 {
69 match self {
70 Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
71 Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
72 Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
73 Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
74 Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
75 Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
76 }
77 }
78
79 fn error_type(&self) -> &'static str {
80 match self {
81 Self::InvalidJsonRpcVersion => "InvalidRequest",
82 Self::MethodNotFound(_) => "MethodNotFound",
83 Self::ToolNotFound(_) => "ToolNotFound",
84 Self::DataProductNotFound(_) => "DataProductNotFound",
85 Self::QueryValidationFailed(_) => "ValidationError",
86 Self::QueryExecutionFailed(_) => "ExecutionError",
87 Self::Internal(_) => "InternalError",
88 }
89 }
90}
91
/// A JSON-RPC 2.0 request (or notification) posted to an MCP endpoint.
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    /// Protocol version string; `handle_mcp_method` rejects anything but `"2.0"`.
    jsonrpc: String,
    /// Request ID, echoed back in the response. `None` marks a notification,
    /// for which no JSON-RPC response is produced.
    id: Option<serde_json::Value>,
    /// The `method` and `params` members, decoded together into one enum.
    #[serde(flatten)]
    method: McpMethod,
}
100
/// The JSON-RPC methods this server understands. The variant is selected by
/// the request's `method` member, and its payload is decoded from `params`.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// MCP handshake. The client's parameters are parsed but not used by the
    /// handler.
    #[serde(rename = "initialize")]
    Initialize(#[allow(dead_code)] InitializeParams),
    /// Lists the tools offered by the endpoint.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// Invokes a single tool.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Any other method name; answered with a "method not found" error.
    #[serde(other)]
    Unknown,
}
116
117impl std::fmt::Display for McpMethod {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 match self {
120 McpMethod::Initialize(_) => write!(f, "initialize"),
121 McpMethod::ToolsList => write!(f, "tools/list"),
122 McpMethod::ToolsCall(_) => write!(f, "tools/call"),
123 McpMethod::Unknown => write!(f, "unknown"),
124 }
125 }
126}
127
/// Parameters of the MCP `initialize` request.
///
/// The fields are deserialized but never read by this server (hence the
/// `dead_code` allowances).
#[derive(Debug, Deserialize)]
struct InitializeParams {
    /// Protocol version requested by the client (JSON key `protocolVersion`).
    #[serde(rename = "protocolVersion")]
    #[allow(dead_code)]
    protocol_version: String,
    /// Client capabilities; `null` when the key is absent.
    #[serde(default)]
    #[allow(dead_code)]
    capabilities: serde_json::Value,
    /// Optional client identification (JSON key `clientInfo`).
    #[serde(rename = "clientInfo")]
    #[allow(dead_code)]
    client_info: Option<ClientInfo>,
}

/// Name and version that an MCP client reports during `initialize`.
#[derive(Debug, Deserialize)]
struct ClientInfo {
    #[allow(dead_code)]
    name: String,
    #[allow(dead_code)]
    version: String,
}
151
/// A `tools/call` request. The variant is selected by the tool `name` (the
/// snake_case form of the variant name), and its payload is decoded from
/// `arguments`.
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    /// `get_data_products` (agents endpoint); takes no arguments.
    GetDataProducts(#[serde(default)] ()),
    /// `get_data_product_details` (agents endpoint).
    GetDataProductDetails(GetDataProductDetailsParams),
    /// `query` (agents endpoint).
    Query(QueryParams),
    /// `query_system_catalog` (observatory endpoint).
    QuerySystemCatalog(QuerySystemCatalogParams),
}
166
167impl std::fmt::Display for ToolsCallParams {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 match self {
170 ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
171 ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
172 ToolsCallParams::Query(_) => write!(f, "query"),
173 ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
174 }
175 }
176}
177
/// Arguments of the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    /// Data product name, matched against `object_name`.
    name: String,
}

/// Arguments of the `query` tool.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Cluster to run the query on (applied with `SET CLUSTER`).
    cluster: String,
    /// A single read-only SQL statement supplied by the caller.
    sql_query: String,
}

/// Arguments of the `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    /// A single read-only SQL statement over system catalog relations.
    sql_query: String,
}
193
/// A JSON-RPC 2.0 response. `handle_mcp_request_inner` sets exactly one of
/// `result` / `error`; the unset one is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct McpResponse {
    jsonrpc: String,
    /// Copied from the request (`null` if the request had none).
    id: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}

/// Successful result payloads. Untagged: each serializes as its inner struct.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    Initialize(InitializeResult),
    ToolsList(ToolsListResult),
    ToolContent(ToolContentResult),
}
212
/// Result of the `initialize` handshake.
#[derive(Debug, Serialize)]
struct InitializeResult {
    /// Protocol version this server speaks (JSON key `protocolVersion`).
    #[serde(rename = "protocolVersion")]
    protocol_version: String,
    capabilities: Capabilities,
    /// Server identification (JSON key `serverInfo`).
    #[serde(rename = "serverInfo")]
    server_info: ServerInfo,
}

/// Capabilities advertised by the server; only `tools` is present.
#[derive(Debug, Serialize)]
struct Capabilities {
    tools: serde_json::Value,
}

/// Server name and version reported during `initialize`.
#[derive(Debug, Serialize)]
struct ServerInfo {
    name: String,
    version: String,
}
232
/// Result of `tools/list`.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    tools: Vec<ToolDefinition>,
}

/// Description of one tool as advertised to MCP clients.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    name: String,
    description: String,
    /// JSON Schema for the tool's `arguments` (JSON key `inputSchema`).
    #[serde(rename = "inputSchema")]
    input_schema: serde_json::Value,
}

/// Result of a successful `tools/call`.
#[derive(Debug, Serialize)]
struct ToolContentResult {
    content: Vec<ContentBlock>,
}

/// One piece of tool output. This file only produces `"text"` blocks.
#[derive(Debug, Serialize)]
struct ContentBlock {
    /// Content kind (JSON key `type`).
    #[serde(rename = "type")]
    content_type: String,
    text: String,
}
257
/// Standard JSON-RPC 2.0 error codes used by this handler.
mod error_codes {
    /// The request object is not a valid JSON-RPC request.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// Invalid method parameters.
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: i32 = -32603;
}

/// The `error` member of a JSON-RPC response.
#[derive(Debug, Serialize)]
struct McpError {
    code: i32,
    message: String,
    /// Extra structured detail; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
273
274impl From<McpRequestError> for McpError {
275 fn from(err: McpRequestError) -> Self {
276 McpError {
277 code: err.error_code(),
278 message: err.to_string(),
279 data: Some(json!({
280 "error_type": err.error_type(),
281 })),
282 }
283 }
284}
285
/// Which MCP endpoint a request arrived on. The endpoint selects the feature
/// flag that gates it and the set of tools it offers.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// Data-product tools (`get_data_products`, `get_data_product_details`, `query`).
    Agents,
    /// System-catalog tool (`query_system_catalog`).
    Observatory,
}

impl std::fmt::Display for McpEndpointType {
    /// Writes the lowercase endpoint name used in logs and the server name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = if matches!(self, Self::Agents) {
            "agents"
        } else {
            "observatory"
        };
        f.write_str(label)
    }
}
300
301pub async fn handle_mcp_agents(
303 client: AuthedClient,
304 Json(request): Json<McpRequest>,
305) -> impl IntoResponse {
306 handle_mcp_request(client, request, McpEndpointType::Agents).await
307}
308
309pub async fn handle_mcp_observatory(
311 client: AuthedClient,
312 Json(request): Json<McpRequest>,
313) -> impl IntoResponse {
314 handle_mcp_request(client, request, McpEndpointType::Observatory).await
315}
316
317async fn handle_mcp_request(
318 mut client: AuthedClient,
319 request: McpRequest,
320 endpoint_type: McpEndpointType,
321) -> impl IntoResponse {
322 let catalog = client.client.catalog_snapshot("mcp").await;
324 let dyncfgs = catalog.system_config().dyncfgs();
325 let enabled = match endpoint_type {
326 McpEndpointType::Agents => ENABLE_MCP_AGENTS.get(dyncfgs),
327 McpEndpointType::Observatory => ENABLE_MCP_OBSERVATORY.get(dyncfgs),
328 };
329 if !enabled {
330 debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
331 return StatusCode::SERVICE_UNAVAILABLE.into_response();
332 }
333
334 let user = client.client.session().user().name.clone();
335 let is_notification = request.id.is_none();
336
337 debug!(
338 method = %request.method,
339 endpoint = %endpoint_type,
340 user = %user,
341 is_notification = is_notification,
342 "MCP request received"
343 );
344
345 if is_notification {
347 debug!(method = %request.method, "Received notification (no response will be sent)");
348 return StatusCode::OK.into_response();
349 }
350
351 let response = mz_ore::task::spawn(|| "mcp_request", async move {
353 handle_mcp_request_inner(&mut client, request, endpoint_type).await
354 })
355 .await;
356
357 (StatusCode::OK, Json(response)).into_response()
358}
359
360async fn handle_mcp_request_inner(
361 client: &mut AuthedClient,
362 request: McpRequest,
363 endpoint_type: McpEndpointType,
364) -> McpResponse {
365 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
367
368 let result = handle_mcp_method(client, &request, endpoint_type).await;
369
370 match result {
371 Ok(result_value) => McpResponse {
372 jsonrpc: "2.0".to_string(),
373 id: request_id,
374 result: Some(result_value),
375 error: None,
376 },
377 Err(e) => {
378 if !matches!(
380 e,
381 McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
382 ) {
383 warn!(error = %e, method = %request.method, "MCP method execution failed");
384 }
385 McpResponse {
386 jsonrpc: "2.0".to_string(),
387 id: request_id,
388 result: None,
389 error: Some(e.into()),
390 }
391 }
392 }
393}
394
395async fn handle_mcp_method(
396 client: &mut AuthedClient,
397 request: &McpRequest,
398 endpoint_type: McpEndpointType,
399) -> Result<McpResult, McpRequestError> {
400 if request.jsonrpc != "2.0" {
402 return Err(McpRequestError::InvalidJsonRpcVersion);
403 }
404
405 match &request.method {
407 McpMethod::Initialize(_) => {
408 debug!(endpoint = %endpoint_type, "Processing initialize");
409 handle_initialize(endpoint_type).await
410 }
411 McpMethod::ToolsList => {
412 debug!(endpoint = %endpoint_type, "Processing tools/list");
413 handle_tools_list(endpoint_type).await
414 }
415 McpMethod::ToolsCall(params) => {
416 debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
417 handle_tools_call(client, params, endpoint_type).await
418 }
419 McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
420 "unknown method".to_string(),
421 )),
422 }
423}
424
425async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
426 Ok(McpResult::Initialize(InitializeResult {
427 protocol_version: "2024-11-05".to_string(),
428 capabilities: Capabilities { tools: json!({}) },
429 server_info: ServerInfo {
430 name: format!("materialize-mcp-{}", endpoint_type),
431 version: env!("CARGO_PKG_VERSION").to_string(),
432 },
433 }))
434}
435
436async fn handle_tools_list(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
437 let tools = match endpoint_type {
438 McpEndpointType::Agents => {
439 vec![
440 ToolDefinition {
441 name: "get_data_products".to_string(),
442 description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
443 input_schema: json!({
444 "type": "object",
445 "properties": {},
446 "required": []
447 }),
448 },
449 ToolDefinition {
450 name: "get_data_product_details".to_string(),
451 description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
452 input_schema: json!({
453 "type": "object",
454 "properties": {
455 "name": {
456 "type": "string",
457 "description": "Exact name of the data product from get_data_products() list"
458 }
459 },
460 "required": ["name"]
461 }),
462 },
463 ToolDefinition {
464 name: "query".to_string(),
465 description: "Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views.".to_string(),
466 input_schema: json!({
467 "type": "object",
468 "properties": {
469 "cluster": {
470 "type": "string",
471 "description": "Exact cluster name from the data product details - required for query execution"
472 },
473 "sql_query": {
474 "type": "string",
475 "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
476 }
477 },
478 "required": ["cluster", "sql_query"]
479 }),
480 },
481 ]
482 }
483 McpEndpointType::Observatory => {
484 vec![
485 ToolDefinition {
486 name: "query_system_catalog".to_string(),
487 description: "Query Materialize system catalog tables (mz_*) for troubleshooting and observability. Only mz_* tables are accessible.".to_string(),
488 input_schema: json!({
489 "type": "object",
490 "properties": {
491 "sql_query": {
492 "type": "string",
493 "description": "SQL query restricted to mz_* system tables"
494 }
495 },
496 "required": ["sql_query"]
497 }),
498 },
499 ]
500 }
501 };
502
503 Ok(McpResult::ToolsList(ToolsListResult { tools }))
504}
505
506async fn handle_tools_call(
507 client: &mut AuthedClient,
508 params: &ToolsCallParams,
509 endpoint_type: McpEndpointType,
510) -> Result<McpResult, McpRequestError> {
511 match (endpoint_type, params) {
512 (McpEndpointType::Agents, ToolsCallParams::GetDataProducts(_)) => {
513 get_data_products(client).await
514 }
515 (McpEndpointType::Agents, ToolsCallParams::GetDataProductDetails(p)) => {
516 get_data_product_details(client, &p.name).await
517 }
518 (McpEndpointType::Agents, ToolsCallParams::Query(p)) => {
519 execute_query(client, &p.cluster, &p.sql_query).await
520 }
521 (McpEndpointType::Observatory, ToolsCallParams::QuerySystemCatalog(p)) => {
522 query_system_catalog(client, &p.sql_query).await
523 }
524 (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
526 "{} is not available on {} endpoint",
527 tool, endpoint
528 ))),
529 }
530}
531
532async fn execute_sql(
534 client: &mut AuthedClient,
535 query: &str,
536) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
537 let mut response = SqlResponse::new();
538
539 execute_request(
540 client,
541 SqlRequest::Simple {
542 query: query.to_string(),
543 },
544 &mut response,
545 )
546 .await
547 .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
548
549 for result in response.results {
552 match result {
553 SqlResult::Rows { rows, .. } => return Ok(rows),
554 SqlResult::Err { error, .. } => {
555 return Err(McpRequestError::QueryExecutionFailed(error.message));
556 }
557 SqlResult::Ok { .. } => continue,
558 }
559 }
560
561 Err(McpRequestError::QueryExecutionFailed(
562 "Query did not return any results".to_string(),
563 ))
564}
565
566async fn get_data_products(client: &mut AuthedClient) -> Result<McpResult, McpRequestError> {
567 debug!("Executing get_data_products");
568 let rows = execute_sql(client, DISCOVERY_QUERY).await?;
569 debug!("get_data_products returned {} rows", rows.len());
570 if rows.is_empty() {
571 warn!("No data products found - indexes must have comments");
572 }
573
574 let text =
575 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
576
577 Ok(McpResult::ToolContent(ToolContentResult {
578 content: vec![ContentBlock {
579 content_type: "text".to_string(),
580 text,
581 }],
582 }))
583}
584
585async fn get_data_product_details(
586 client: &mut AuthedClient,
587 name: &str,
588) -> Result<McpResult, McpRequestError> {
589 debug!(name = %name, "Executing get_data_product_details");
590
591 let query = format!(
592 "SELECT * FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
593 escaped_string_literal(name)
594 );
595
596 let rows = execute_sql(client, &query).await?;
597
598 if rows.is_empty() {
599 return Err(McpRequestError::DataProductNotFound(name.to_string()));
600 }
601
602 let text =
603 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
604
605 Ok(McpResult::ToolContent(ToolContentResult {
606 content: vec![ContentBlock {
607 content_type: "text".to_string(),
608 text,
609 }],
610 }))
611}
612
613fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
615 let sql = sql.trim();
616 if sql.is_empty() {
617 return Err(McpRequestError::QueryValidationFailed(
618 "Empty query".to_string(),
619 ));
620 }
621
622 let stmts = parse(sql).map_err(|e| {
624 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
625 })?;
626
627 if stmts.len() != 1 {
629 return Err(McpRequestError::QueryValidationFailed(format!(
630 "Only one query allowed at a time. Found {} statements.",
631 stmts.len()
632 )));
633 }
634
635 let stmt = &stmts[0];
637 use mz_sql_parser::ast::Statement;
638
639 match &stmt.ast {
640 Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
641 Ok(())
643 }
644 _ => Err(McpRequestError::QueryValidationFailed(
645 "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
646 )),
647 }
648}
649
650async fn execute_query(
651 client: &mut AuthedClient,
652 cluster: &str,
653 sql_query: &str,
654) -> Result<McpResult, McpRequestError> {
655 debug!(cluster = %cluster, "Executing user query");
656
657 validate_readonly_query(sql_query)?;
658
659 let combined_query = format!(
662 "BEGIN READ ONLY; SET CLUSTER = {}; {}; COMMIT;",
663 escaped_string_literal(cluster),
664 sql_query
665 );
666
667 let rows = execute_sql(client, &combined_query).await?;
668
669 let text =
670 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
671
672 Ok(McpResult::ToolContent(ToolContentResult {
673 content: vec![ContentBlock {
674 content_type: "text".to_string(),
675 text,
676 }],
677 }))
678}
679
680async fn query_system_catalog(
681 client: &mut AuthedClient,
682 sql_query: &str,
683) -> Result<McpResult, McpRequestError> {
684 debug!("Executing query_system_catalog");
685
686 validate_readonly_query(sql_query)?;
688
689 validate_system_catalog_query(sql_query)?;
691
692 let wrapped_query = format!("BEGIN READ ONLY; {}; COMMIT;", sql_query);
694 let rows = execute_sql(client, &wrapped_query).await?;
695
696 let text =
697 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
698
699 Ok(McpResult::ToolContent(ToolContentResult {
700 content: vec![ContentBlock {
701 content_type: "text".to_string(),
702 text,
703 }],
704 }))
705}
706
707struct TableReferenceCollector {
709 tables: Vec<(Option<String>, String)>,
711 cte_names: std::collections::BTreeSet<String>,
713}
714
715impl TableReferenceCollector {
716 fn new() -> Self {
717 Self {
718 tables: Vec::new(),
719 cte_names: std::collections::BTreeSet::new(),
720 }
721 }
722}
723
724impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
725 fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
726 self.cte_names
728 .insert(cte.alias.name.as_str().to_lowercase());
729 visit::visit_cte(self, cte);
730 }
731
732 fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
733 if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
735 match name {
736 RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
737 let parts = &n.0;
738 if !parts.is_empty() {
739 let table_name = parts.last().unwrap().as_str().to_lowercase();
740
741 if self.cte_names.contains(&table_name) {
743 visit::visit_table_factor(self, table_factor);
744 return;
745 }
746
747 let schema = if parts.len() >= 2 {
749 Some(parts[parts.len() - 2].as_str().to_lowercase())
750 } else {
751 None
752 };
753 self.tables.push((schema, table_name));
754 }
755 }
756 }
757 }
758 visit::visit_table_factor(self, table_factor);
759 }
760}
761
762fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
764 let stmts = parse(sql).map_err(|e| {
766 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
767 })?;
768
769 if stmts.is_empty() {
770 return Err(McpRequestError::QueryValidationFailed(
771 "Empty query".to_string(),
772 ));
773 }
774
775 let mut collector = TableReferenceCollector::new();
777 for stmt in &stmts {
778 collector.visit_statement(&stmt.ast);
779 }
780
781 const ALLOWED_SCHEMAS: &[&str] = &[
783 "mz_catalog",
784 "mz_internal",
785 "pg_catalog",
786 "information_schema",
787 ];
788
789 let is_system_table = |(schema, table_name): &(Option<String>, String)| {
791 match schema {
792 Some(s) => ALLOWED_SCHEMAS.contains(&s.as_str()),
794 None => table_name.starts_with("mz_"),
796 }
797 };
798
799 let non_system_tables: Vec<String> = collector
801 .tables
802 .iter()
803 .filter(|t| !is_system_table(t))
804 .map(|(schema, table)| match schema {
805 Some(s) => format!("{}.{}", s, table),
806 None => table.clone(),
807 })
808 .collect();
809
810 if !non_system_tables.is_empty() {
811 return Err(McpRequestError::QueryValidationFailed(format!(
812 "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
813 non_system_tables.join(", ")
814 )));
815 }
816
817 if collector.tables.is_empty() || !collector.tables.iter().any(is_system_table) {
819 return Err(McpRequestError::QueryValidationFailed(
820 "Query must reference at least one system catalog table".to_string(),
821 ));
822 }
823
824 Ok(())
825}
826
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature shared by both SQL validators under test.
    type Validator = fn(&str) -> Result<(), McpRequestError>;

    /// Asserts that `validate` accepts every query in `queries`.
    fn assert_accepts(validate: Validator, queries: &[&str]) {
        for &query in queries {
            assert!(validate(query).is_ok(), "expected query to be accepted: {query}");
        }
    }

    /// Asserts that `validate` rejects every query in `queries`.
    fn assert_rejects(validate: Validator, queries: &[&str]) {
        for &query in queries {
            assert!(validate(query).is_err(), "expected query to be rejected: {query}");
        }
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_select() {
        assert_accepts(
            validate_readonly_query,
            &["SELECT * FROM mz_tables", "SELECT 1 + 2", "  SELECT 1  "],
        );
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_subqueries() {
        assert_accepts(
            validate_readonly_query,
            &[
                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
                "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)",
                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))",
                "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)",
            ],
        );
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_show() {
        assert_accepts(validate_readonly_query, &["SHOW CLUSTERS", "SHOW TABLES"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_explain() {
        assert_accepts(validate_readonly_query, &["EXPLAIN SELECT 1"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_writes() {
        assert_rejects(
            validate_readonly_query,
            &[
                "INSERT INTO t VALUES (1)",
                "UPDATE t SET a = 1",
                "DELETE FROM t",
                "CREATE TABLE t (a INT)",
                "DROP TABLE t",
            ],
        );
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_multiple() {
        assert_rejects(validate_readonly_query, &["SELECT 1; SELECT 2"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_empty() {
        assert_rejects(validate_readonly_query, &["", "   "]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_accepts_mz_tables() {
        assert_accepts(
            validate_system_catalog_query,
            &[
                "SELECT * FROM mz_tables",
                "SELECT * FROM mz_internal.mz_comments",
                "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id",
            ],
        );
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_subqueries() {
        assert_accepts(
            validate_system_catalog_query,
            &[
                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))",
                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
            ],
        );
        assert_rejects(
            validate_system_catalog_query,
            &[
                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)",
                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)",
            ],
        );
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_user_tables() {
        assert_rejects(
            validate_system_catalog_query,
            &[
                "SELECT * FROM user_data",
                "SELECT * FROM my_table",
                // A string literal mentioning `mz_` must not satisfy the check.
                "SELECT * FROM user_data WHERE 'mz_' IS NOT NULL",
            ],
        );
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_functions() {
        assert_accepts(
            validate_system_catalog_query,
            &[
                "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1",
                "SELECT length(name) FROM mz_tables",
                "SELECT count(*) FROM mz_sources WHERE now() > created_at",
            ],
        );
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_schema_qualified() {
        assert_accepts(
            validate_system_catalog_query,
            &[
                "SELECT * FROM mz_catalog.mz_tables",
                "SELECT * FROM mz_internal.mz_sessions",
                "SELECT * FROM pg_catalog.pg_type",
                "SELECT * FROM information_schema.tables",
            ],
        );
        assert_rejects(
            validate_system_catalog_query,
            &[
                "SELECT * FROM public.user_table",
                "SELECT * FROM myschema.mytable",
                "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true",
            ],
        );
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_adversarial_cases() {
        let rejected: &[(&str, &str)] = &[
            (
                "WITH user_cte AS (SELECT * FROM user_data) SELECT * FROM mz_tables, user_cte",
                "Should reject CTE referencing user table",
            ),
            (
                "WITH cte1 AS (SELECT * FROM mz_tables), cte2 AS (SELECT * FROM cte1), cte3 AS (SELECT * FROM user_data) SELECT * FROM cte2",
                "Should reject CTE chain with user table",
            ),
            (
                "SELECT * FROM mz_tables t1 JOIN user_data u ON t1.id = u.id JOIN mz_sources s ON t1.id = s.id",
                "Should reject multi-join with user table",
            ),
            (
                "SELECT * FROM mz_tables t LEFT JOIN user_data u ON t.id = u.table_id WHERE u.id IS NULL",
                "Should reject LEFT JOIN with user table",
            ),
            (
                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM (SELECT * FROM user_data) AS u)",
                "Should reject nested subquery with user table",
            ),
            (
                "SELECT name FROM mz_tables UNION SELECT name FROM user_data",
                "Should reject UNION with user table",
            ),
            (
                "SELECT id FROM mz_sources UNION ALL SELECT id FROM products",
                "Should reject UNION ALL with user table",
            ),
            (
                "SELECT * FROM mz_tables CROSS JOIN user_data",
                "Should reject CROSS JOIN with user table",
            ),
            (
                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t",
                "Should reject subquery in SELECT with user table",
            ),
            (
                "SELECT * FROM mz_catalogg.fake_table",
                "Should reject typo-squatting schema name",
            ),
            (
                "SELECT * FROM mz_catalog_hack.fake_table",
                "Should reject fake schema with mz_catalog prefix",
            ),
            (
                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u",
                "Should reject LATERAL join with user table",
            ),
        ];
        for &(query, reason) in rejected {
            assert!(validate_system_catalog_query(query).is_err(), "{reason}");
        }

        let accepted: &[(&str, &str)] = &[
            (
                "WITH tables AS (SELECT * FROM mz_tables), sources AS (SELECT * FROM mz_sources) SELECT t.name, s.name FROM tables t JOIN sources s ON t.id = s.id WHERE t.id IN (SELECT id FROM mz_columns)",
                "Should allow complex query with only system tables",
            ),
            (
                "SELECT name FROM mz_tables UNION SELECT name FROM mz_sources",
                "Should allow UNION of system tables",
            ),
        ];
        for &(query, reason) in accepted {
            assert!(validate_system_catalog_query(query).is_ok(), "{reason}");
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_mixed_tables() {
        assert_rejects(
            validate_system_catalog_query,
            &["SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"],
        );
    }

    #[mz_ore::test]
    fn test_mcp_error_codes() {
        let cases = [
            (
                McpRequestError::InvalidJsonRpcVersion,
                error_codes::INVALID_REQUEST,
            ),
            (
                McpRequestError::MethodNotFound("test".to_string()),
                error_codes::METHOD_NOT_FOUND,
            ),
            (
                McpRequestError::QueryExecutionFailed("test".to_string()),
                error_codes::INTERNAL_ERROR,
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.error_code(), expected);
        }
    }
}