1use std::sync::Arc;
27use std::time::Duration;
28
29use anyhow::anyhow;
30use axum::Extension;
31use axum::Json;
32use axum::response::IntoResponse;
33use http::{HeaderMap, HeaderValue, StatusCode};
34use mz_adapter_types::dyncfgs::{
35 ENABLE_MCP_AGENT, ENABLE_MCP_AGENT_QUERY_TOOL, ENABLE_MCP_DEVELOPER, MCP_MAX_RESPONSE_SIZE,
36};
37use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
38use mz_sql::parse::parse;
39use mz_sql::session::metadata::SessionMetadata;
40use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
41use mz_sql_parser::ast::visit::{self, Visit};
42use mz_sql_parser::ast::{Raw, RawItemName};
43use mz_sql_parser::parser::parse_item_name;
44use serde::{Deserialize, Serialize};
45use serde_json::json;
46use thiserror::Error;
47use tracing::{debug, warn};
48
49use crate::http::AuthedClient;
50use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
51
/// JSON-RPC version string; requests are checked against it and responses echo it.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol revision reported in the `initialize` result.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Wall-clock budget for a single MCP request before a timeout error is returned.
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Query that lists every data product (used by `get_data_products`).
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
/// Query prefix for a single data product's details; the caller appends an
/// escaped string literal holding the object name.
const DETAILS_QUERY_PREFIX: &str =
    "SELECT * FROM mz_internal.mz_mcp_data_product_details WHERE object_name = ";
72
73#[derive(Debug, Error)]
75enum McpRequestError {
76 #[error("Invalid JSON-RPC version: expected 2.0")]
77 InvalidJsonRpcVersion,
78 #[error("Method not found: {0}")]
79 #[allow(dead_code)] MethodNotFound(String),
81 #[error("Tool not found: {0}")]
82 ToolNotFound(String),
83 #[error("Data product not found: {0}")]
84 DataProductNotFound(String),
85 #[error("Query validation failed: {0}")]
86 QueryValidationFailed(String),
87 #[error("Query execution failed: {0}")]
88 QueryExecutionFailed(String),
89 #[error("Internal error: {0}")]
90 Internal(#[from] anyhow::Error),
91}
92
93impl McpRequestError {
94 fn error_code(&self) -> i32 {
95 match self {
96 Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
97 Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
98 Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
99 Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
100 Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
101 Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
102 }
103 }
104
105 fn error_type(&self) -> &'static str {
106 match self {
107 Self::InvalidJsonRpcVersion => "InvalidRequest",
108 Self::MethodNotFound(_) => "MethodNotFound",
109 Self::ToolNotFound(_) => "ToolNotFound",
110 Self::DataProductNotFound(_) => "DataProductNotFound",
111 Self::QueryValidationFailed(_) => "ValidationError",
112 Self::QueryExecutionFailed(_) => "ExecutionError",
113 Self::Internal(_) => "InternalError",
114 }
115 }
116}
117
118#[derive(Debug, Deserialize)]
120pub(crate) struct McpRequest {
121 jsonrpc: String,
122 id: Option<serde_json::Value>,
123 #[serde(flatten)]
124 method: McpMethod,
125}
126
127#[derive(Debug, Deserialize)]
129#[serde(tag = "method", content = "params")]
130enum McpMethod {
131 #[serde(rename = "initialize")]
133 Initialize(#[allow(dead_code)] InitializeParams),
134 #[serde(rename = "tools/list")]
135 ToolsList,
136 #[serde(rename = "tools/call")]
137 ToolsCall(ToolsCallParams),
138 #[serde(other)]
140 Unknown,
141}
142
143impl std::fmt::Display for McpMethod {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match self {
146 McpMethod::Initialize(_) => write!(f, "initialize"),
147 McpMethod::ToolsList => write!(f, "tools/list"),
148 McpMethod::ToolsCall(_) => write!(f, "tools/call"),
149 McpMethod::Unknown => write!(f, "unknown"),
150 }
151 }
152}
153
154#[derive(Debug, Deserialize)]
155struct InitializeParams {
156 #[serde(rename = "protocolVersion")]
158 #[allow(dead_code)]
159 protocol_version: String,
160 #[serde(default)]
162 #[allow(dead_code)]
163 capabilities: serde_json::Value,
164 #[serde(rename = "clientInfo")]
166 #[allow(dead_code)]
167 client_info: Option<ClientInfo>,
168}
169
170#[derive(Debug, Deserialize)]
171struct ClientInfo {
172 #[allow(dead_code)]
173 name: String,
174 #[allow(dead_code)]
175 version: String,
176}
177
178#[derive(Debug, Deserialize)]
181#[serde(tag = "name", content = "arguments")]
182#[serde(rename_all = "snake_case")]
183enum ToolsCallParams {
184 GetDataProducts(#[serde(default)] ()),
187 GetDataProductDetails(GetDataProductDetailsParams),
188 ReadDataProduct(ReadDataProductParams),
189 Query(QueryParams),
190 QuerySystemCatalog(QuerySystemCatalogParams),
192}
193
194impl std::fmt::Display for ToolsCallParams {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 match self {
197 ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
198 ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
199 ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
200 ToolsCallParams::Query(_) => write!(f, "query"),
201 ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
202 }
203 }
204}
205
206#[derive(Debug, Deserialize)]
207struct GetDataProductDetailsParams {
208 name: String,
209}
210
211#[derive(Debug, Deserialize)]
212struct ReadDataProductParams {
213 name: String,
214 #[serde(default = "default_read_limit")]
215 limit: u32,
216 cluster: Option<String>,
217}
218
/// Serde default for `ReadDataProductParams::limit`.
fn default_read_limit() -> u32 {
    const DEFAULT_LIMIT: u32 = 500;
    DEFAULT_LIMIT
}
222
/// Hard ceiling applied to the `limit` argument of `read_data_product`.
const MAX_READ_LIMIT: u32 = 1_000;
225
226#[derive(Debug, Deserialize)]
227struct QueryParams {
228 cluster: String,
229 sql_query: String,
230}
231
232#[derive(Debug, Deserialize)]
233struct QuerySystemCatalogParams {
234 sql_query: String,
235}
236
237#[derive(Debug, Serialize)]
238struct McpResponse {
239 jsonrpc: String,
240 id: serde_json::Value,
241 #[serde(skip_serializing_if = "Option::is_none")]
242 result: Option<McpResult>,
243 #[serde(skip_serializing_if = "Option::is_none")]
244 error: Option<McpError>,
245}
246
247#[derive(Debug, Serialize)]
249#[serde(untagged)]
250enum McpResult {
251 Initialize(InitializeResult),
252 ToolsList(ToolsListResult),
253 ToolContent(ToolContentResult),
254}
255
256#[derive(Debug, Serialize)]
257struct InitializeResult {
258 #[serde(rename = "protocolVersion")]
259 protocol_version: String,
260 capabilities: Capabilities,
261 #[serde(rename = "serverInfo")]
262 server_info: ServerInfo,
263 #[serde(skip_serializing_if = "Option::is_none")]
264 instructions: Option<String>,
265}
266
267#[derive(Debug, Serialize)]
268struct Capabilities {
269 tools: serde_json::Value,
270}
271
272#[derive(Debug, Serialize)]
273struct ServerInfo {
274 name: String,
275 version: String,
276}
277
278#[derive(Debug, Serialize)]
279struct ToolsListResult {
280 tools: Vec<ToolDefinition>,
281}
282
283#[derive(Debug, Serialize)]
284struct ToolDefinition {
285 name: String,
286 #[serde(skip_serializing_if = "Option::is_none")]
287 title: Option<String>,
288 description: String,
289 #[serde(rename = "inputSchema")]
290 input_schema: serde_json::Value,
291 #[serde(skip_serializing_if = "Option::is_none")]
292 annotations: Option<ToolAnnotations>,
293}
294
295#[derive(Debug, Serialize)]
298struct ToolAnnotations {
299 #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
300 read_only_hint: Option<bool>,
301 #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
302 destructive_hint: Option<bool>,
303 #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
304 idempotent_hint: Option<bool>,
305 #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
306 open_world_hint: Option<bool>,
307}
308
309const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
311 read_only_hint: Some(true),
312 destructive_hint: Some(false),
313 idempotent_hint: Some(true),
314 open_world_hint: Some(false),
315};
316
317#[derive(Debug, Serialize)]
318struct ToolContentResult {
319 content: Vec<ContentBlock>,
320 #[serde(rename = "isError")]
321 is_error: bool,
322}
323
324#[derive(Debug, Serialize)]
325struct ContentBlock {
326 #[serde(rename = "type")]
327 content_type: String,
328 text: String,
329}
330
/// Numeric JSON-RPC error codes used in error responses.
mod error_codes {
    /// The request envelope itself is invalid.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// The server failed while processing the request.
    pub const INTERNAL_ERROR: i32 = -32603;
}
338
339#[derive(Debug, Serialize)]
340struct McpError {
341 code: i32,
342 message: String,
343 #[serde(skip_serializing_if = "Option::is_none")]
344 data: Option<serde_json::Value>,
345}
346
347impl From<McpRequestError> for McpError {
348 fn from(err: McpRequestError) -> Self {
349 McpError {
350 code: err.error_code(),
351 message: err.to_string(),
352 data: Some(json!({
353 "error_type": err.error_type(),
354 })),
355 }
356 }
357}
358
/// Which of the two MCP endpoints a request arrived on.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    Agent,
    Developer,
}

impl std::fmt::Display for McpEndpointType {
    /// Writes the lowercase endpoint name (`agent` or `developer`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Agent => "agent",
            Self::Developer => "developer",
        };
        f.write_str(label)
    }
}
373
374pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
377 StatusCode::METHOD_NOT_ALLOWED
378}
379
380pub async fn handle_mcp_agent(
382 headers: HeaderMap,
383 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
384 client: AuthedClient,
385 Json(body): Json<McpRequest>,
386) -> axum::response::Response {
387 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
388 return resp;
389 }
390 handle_mcp_request(client, body, McpEndpointType::Agent)
391 .await
392 .into_response()
393}
394
395pub async fn handle_mcp_developer(
397 headers: HeaderMap,
398 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
399 client: AuthedClient,
400 Json(body): Json<McpRequest>,
401) -> axum::response::Response {
402 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
403 return resp;
404 }
405 handle_mcp_request(client, body, McpEndpointType::Developer)
406 .await
407 .into_response()
408}
409
410fn validate_origin(
419 headers: &HeaderMap,
420 allowed: &[HeaderValue],
421) -> Option<axum::response::Response> {
422 let origin = headers.get(http::header::ORIGIN)?;
423 if mz_http_util::origin_is_allowed(origin, allowed) {
424 return None;
425 }
426 warn!(
427 origin = ?origin,
428 "MCP request rejected: origin not in allowlist",
429 );
430 Some(StatusCode::FORBIDDEN.into_response())
431}
432
433async fn handle_mcp_request(
434 mut client: AuthedClient,
435 request: McpRequest,
436 endpoint_type: McpEndpointType,
437) -> impl IntoResponse {
438 let catalog = client.client.catalog_snapshot("mcp").await;
440 let dyncfgs = catalog.system_config().dyncfgs();
441 let enabled = match endpoint_type {
442 McpEndpointType::Agent => ENABLE_MCP_AGENT.get(dyncfgs),
443 McpEndpointType::Developer => ENABLE_MCP_DEVELOPER.get(dyncfgs),
444 };
445 if !enabled {
446 debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
447 return StatusCode::SERVICE_UNAVAILABLE.into_response();
448 }
449
450 let query_tool_enabled = ENABLE_MCP_AGENT_QUERY_TOOL.get(dyncfgs);
451 let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
452
453 let user = client.client.session().user().name.clone();
454 let is_notification = request.id.is_none();
455
456 debug!(
457 method = %request.method,
458 endpoint = %endpoint_type,
459 user = %user,
460 is_notification = is_notification,
461 "MCP request received"
462 );
463
464 if is_notification {
466 debug!(method = %request.method, "Received notification (no response will be sent)");
467 return StatusCode::OK.into_response();
468 }
469
470 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
471
472 let result = tokio::time::timeout(
474 MCP_REQUEST_TIMEOUT,
475 mz_ore::task::spawn(|| "mcp_request", async move {
476 handle_mcp_request_inner(
477 &mut client,
478 request,
479 endpoint_type,
480 query_tool_enabled,
481 max_response_size,
482 )
483 .await
484 }),
485 )
486 .await;
487
488 let response = match result {
489 Ok(inner) => inner,
490 Err(_elapsed) => {
491 warn!(
492 endpoint = %endpoint_type,
493 timeout = ?MCP_REQUEST_TIMEOUT,
494 "MCP request timed out",
495 );
496 McpResponse {
497 jsonrpc: JSONRPC_VERSION.to_string(),
498 id: request_id,
499 result: None,
500 error: Some(
501 McpRequestError::QueryExecutionFailed(format!(
502 "Request timed out after {} seconds.",
503 MCP_REQUEST_TIMEOUT.as_secs(),
504 ))
505 .into(),
506 ),
507 }
508 }
509 };
510
511 (StatusCode::OK, Json(response)).into_response()
512}
513
514async fn handle_mcp_request_inner(
515 client: &mut AuthedClient,
516 request: McpRequest,
517 endpoint_type: McpEndpointType,
518 query_tool_enabled: bool,
519 max_response_size: usize,
520) -> McpResponse {
521 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
523
524 let result = handle_mcp_method(
525 client,
526 &request,
527 endpoint_type,
528 query_tool_enabled,
529 max_response_size,
530 )
531 .await;
532
533 match result {
534 Ok(result_value) => McpResponse {
535 jsonrpc: JSONRPC_VERSION.to_string(),
536 id: request_id,
537 result: Some(result_value),
538 error: None,
539 },
540 Err(e) => {
541 if !matches!(
543 e,
544 McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
545 ) {
546 warn!(error = %e, method = %request.method, "MCP method execution failed");
547 }
548 McpResponse {
549 jsonrpc: JSONRPC_VERSION.to_string(),
550 id: request_id,
551 result: None,
552 error: Some(e.into()),
553 }
554 }
555 }
556}
557
558async fn handle_mcp_method(
559 client: &mut AuthedClient,
560 request: &McpRequest,
561 endpoint_type: McpEndpointType,
562 query_tool_enabled: bool,
563 max_response_size: usize,
564) -> Result<McpResult, McpRequestError> {
565 if request.jsonrpc != JSONRPC_VERSION {
567 return Err(McpRequestError::InvalidJsonRpcVersion);
568 }
569
570 match &request.method {
572 McpMethod::Initialize(_) => {
573 debug!(endpoint = %endpoint_type, "Processing initialize");
574 handle_initialize(endpoint_type).await
575 }
576 McpMethod::ToolsList => {
577 debug!(endpoint = %endpoint_type, "Processing tools/list");
578 handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
579 }
580 McpMethod::ToolsCall(params) => {
581 debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
582 handle_tools_call(
583 client,
584 params,
585 endpoint_type,
586 query_tool_enabled,
587 max_response_size,
588 )
589 .await
590 }
591 McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
592 "unknown method".to_string(),
593 )),
594 }
595}
596
597fn endpoint_instructions(endpoint_type: McpEndpointType) -> Option<String> {
600 match endpoint_type {
601 McpEndpointType::Agent => None,
602 McpEndpointType::Developer => Some(concat!(
603 "You are connected to the Materialize developer MCP server. ",
604 "You have read-only access to system catalog tables (mz_*, pg_catalog, information_schema) ",
605 "for troubleshooting and observability.\n\n",
606 "IMPORTANT: Before writing queries, discover table schemas using the mz_ontology tables:\n",
607 "- mz_internal.mz_ontology_entity_types: what catalog entities exist and which tables they map to\n",
608 "- mz_internal.mz_ontology_link_types: relationships between entities (foreign keys, metrics, etc.)\n",
609 "- mz_internal.mz_ontology_properties: column names, types, and descriptions for each entity\n",
610 "- mz_internal.mz_ontology_semantic_types: typed ID domains (CatalogItemId, ReplicaId, etc.)\n\n",
611 "Use these to find the correct tables, join paths, and column names instead of guessing.\n\n",
612 "Key rules:\n",
613 "- mz_source_statuses and mz_sink_statuses use `last_status_change_at` (NOT `updated_at`)\n",
614 "- mz_cluster_replica_utilization only has `replica_id` — JOIN with mz_cluster_replicas and mz_clusters to get names\n",
615 "- Do NOT query mz_introspection.mz_dataflow_arrangement_sizes — it is cluster-scoped and has uint8/text type mismatches\n",
616 "- Use SHOW COLUMNS FROM <table> to verify column names if unsure",
617 ).to_string()),
618 }
619}
620
621async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
622 Ok(McpResult::Initialize(InitializeResult {
623 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
624 capabilities: Capabilities { tools: json!({}) },
625 server_info: ServerInfo {
626 name: format!("materialize-mcp-{}", endpoint_type),
627 version: env!("CARGO_PKG_VERSION").to_string(),
628 },
629 instructions: endpoint_instructions(endpoint_type),
630 }))
631}
632
633async fn handle_tools_list(
634 endpoint_type: McpEndpointType,
635 query_tool_enabled: bool,
636 max_response_size: usize,
637) -> Result<McpResult, McpRequestError> {
638 let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
639
640 let tools = match endpoint_type {
641 McpEndpointType::Agent => {
642 let mut tools = vec![
643 ToolDefinition {
644 name: "get_data_products".to_string(),
645 title: Some("List Data Products".to_string()),
646 description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
647 input_schema: json!({
648 "type": "object",
649 "properties": {},
650 "required": []
651 }),
652 annotations: Some(READ_ONLY_ANNOTATIONS),
653 },
654 ToolDefinition {
655 name: "get_data_product_details".to_string(),
656 title: Some("Get Data Product Details".to_string()),
657 description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
658 input_schema: json!({
659 "type": "object",
660 "properties": {
661 "name": {
662 "type": "string",
663 "description": "Exact name of the data product from get_data_products() list"
664 }
665 },
666 "required": ["name"]
667 }),
668 annotations: Some(READ_ONLY_ANNOTATIONS),
669 },
670 ToolDefinition {
671 name: "read_data_product".to_string(),
672 title: Some("Read Data Product".to_string()),
673 description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500, max 1000). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
674 input_schema: json!({
675 "type": "object",
676 "properties": {
677 "name": {
678 "type": "string",
679 "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
680 },
681 "limit": {
682 "type": "integer",
683 "description": "Maximum number of rows to return (default 500, max 1000)",
684 "default": 500
685 },
686 "cluster": {
687 "type": "string",
688 "description": "Optional cluster override. If omitted, uses the cluster from the data product catalog."
689 }
690 },
691 "required": ["name"]
692 }),
693 annotations: Some(READ_ONLY_ANNOTATIONS),
694 },
695 ];
696 if query_tool_enabled {
697 tools.push(ToolDefinition {
698 name: "query".to_string(),
699 title: Some("Query Data Products".to_string()),
700 description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
701 input_schema: json!({
702 "type": "object",
703 "properties": {
704 "cluster": {
705 "type": "string",
706 "description": "Exact cluster name from the data product details - required for query execution"
707 },
708 "sql_query": {
709 "type": "string",
710 "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
711 }
712 },
713 "required": ["cluster", "sql_query"]
714 }),
715 annotations: Some(READ_ONLY_ANNOTATIONS),
716 });
717 }
718 tools
719 }
720 McpEndpointType::Developer => {
721 vec![ToolDefinition {
722 name: "query_system_catalog".to_string(),
723 title: Some("Query System Catalog".to_string()),
724 description: concat!(
725 "Query Materialize system catalog tables for troubleshooting and observability. ",
726 "Only mz_*, pg_catalog, and information_schema tables are accessible. ",
727 "Use the mz_internal.mz_ontology_* tables to discover tables, columns, and join paths before writing queries.",
728 ).to_owned() + &format!(" {size_hint}"),
729 input_schema: json!({
730 "type": "object",
731 "properties": {
732 "sql_query": {
733 "type": "string",
734 "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
735 }
736 },
737 "required": ["sql_query"]
738 }),
739 annotations: Some(READ_ONLY_ANNOTATIONS),
740 }]
741 }
742 };
743
744 Ok(McpResult::ToolsList(ToolsListResult { tools }))
745}
746
747async fn handle_tools_call(
748 client: &mut AuthedClient,
749 params: &ToolsCallParams,
750 endpoint_type: McpEndpointType,
751 query_tool_enabled: bool,
752 max_response_size: usize,
753) -> Result<McpResult, McpRequestError> {
754 match (endpoint_type, params) {
755 (McpEndpointType::Agent, ToolsCallParams::GetDataProducts(_)) => {
756 get_data_products(client, max_response_size).await
757 }
758 (McpEndpointType::Agent, ToolsCallParams::GetDataProductDetails(p)) => {
759 get_data_product_details(client, &p.name, max_response_size).await
760 }
761 (McpEndpointType::Agent, ToolsCallParams::ReadDataProduct(p)) => {
762 read_data_product(
763 client,
764 &p.name,
765 p.limit,
766 p.cluster.as_deref(),
767 max_response_size,
768 )
769 .await
770 }
771 (McpEndpointType::Agent, ToolsCallParams::Query(_)) if !query_tool_enabled => {
772 Err(McpRequestError::ToolNotFound(
773 "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
774 ))
775 }
776 (McpEndpointType::Agent, ToolsCallParams::Query(p)) => {
777 execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
778 }
779 (McpEndpointType::Developer, ToolsCallParams::QuerySystemCatalog(p)) => {
780 query_system_catalog(client, &p.sql_query, max_response_size).await
781 }
782 (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
784 "{} is not available on {} endpoint",
785 tool, endpoint
786 ))),
787 }
788}
789
790async fn execute_sql(
792 client: &mut AuthedClient,
793 query: &str,
794) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
795 let mut response = SqlResponse::new();
796
797 execute_request(
798 client,
799 SqlRequest::Simple {
800 query: query.to_string(),
801 },
802 &mut response,
803 )
804 .await
805 .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
806
807 for result in response.results {
810 match result {
811 SqlResult::Rows { rows, .. } => return Ok(rows),
812 SqlResult::Err { error, .. } => {
813 return Err(McpRequestError::QueryExecutionFailed(error.message));
814 }
815 SqlResult::Ok { .. } => continue,
816 }
817 }
818
819 Err(McpRequestError::QueryExecutionFailed(
820 "Query did not return any results".to_string(),
821 ))
822}
823
824fn format_rows_response(
831 rows: Vec<Vec<serde_json::Value>>,
832 max_size: usize,
833) -> Result<McpResult, McpRequestError> {
834 let text =
835 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
836
837 if text.len() > max_size {
838 return Err(McpRequestError::QueryExecutionFailed(format!(
839 "Response size ({} bytes) exceeds the {} byte limit. \
840 Use LIMIT or WHERE to narrow your query.",
841 text.len(),
842 max_size,
843 )));
844 }
845
846 Ok(McpResult::ToolContent(ToolContentResult {
847 content: vec![ContentBlock {
848 content_type: "text".to_string(),
849 text,
850 }],
851 is_error: false,
852 }))
853}
854
855async fn get_data_products(
856 client: &mut AuthedClient,
857 max_response_size: usize,
858) -> Result<McpResult, McpRequestError> {
859 debug!("Executing get_data_products");
860 let rows = execute_sql(client, DISCOVERY_QUERY).await?;
861 debug!("get_data_products returned {} rows", rows.len());
862
863 format_rows_response(rows, max_response_size)
864}
865
866async fn get_data_product_details(
867 client: &mut AuthedClient,
868 name: &str,
869 max_response_size: usize,
870) -> Result<McpResult, McpRequestError> {
871 debug!(name = %name, "Executing get_data_product_details");
872
873 let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
874
875 let rows = execute_sql(client, &query).await?;
876
877 if rows.is_empty() {
878 return Err(McpRequestError::DataProductNotFound(name.to_string()));
879 }
880
881 format_rows_response(rows, max_response_size)
882}
883
884fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
891 let name = name.trim();
892 if name.is_empty() {
893 return Err(McpRequestError::QueryValidationFailed(
894 "Data product name cannot be empty".to_string(),
895 ));
896 }
897
898 let parsed = parse_item_name(name).map_err(|_| {
899 McpRequestError::QueryValidationFailed(format!(
900 "Invalid data product name: {}. Expected a valid object name, \
901 e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
902 name
903 ))
904 })?;
905
906 Ok(parsed.to_ast_string_stable())
909}
910
911async fn read_data_product(
919 client: &mut AuthedClient,
920 name: &str,
921 limit: u32,
922 cluster_override: Option<&str>,
923 max_response_size: usize,
924) -> Result<McpResult, McpRequestError> {
925 debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
926
927 let safe_name = safe_data_product_name(name)?;
929
930 let lookup_query = format!(
938 "SELECT 1 FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
939 escaped_string_literal(name)
940 );
941 let lookup_rows = execute_sql(client, &lookup_query).await?;
942 if lookup_rows.is_empty() {
943 return Err(McpRequestError::DataProductNotFound(name.to_string()));
944 }
945
946 let clamped_limit = limit.min(MAX_READ_LIMIT);
947
948 let read_query = match cluster_override {
949 Some(cluster) => format!(
950 "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}\n; COMMIT;",
951 escaped_string_literal(cluster),
952 safe_name,
953 clamped_limit,
954 ),
955 None => format!("SELECT * FROM {} LIMIT {}", safe_name, clamped_limit),
957 };
958
959 let rows = execute_sql(client, &read_query).await?;
960
961 format_rows_response(rows, max_response_size)
962}
963
964fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
966 let sql = sql.trim();
967 if sql.is_empty() {
968 return Err(McpRequestError::QueryValidationFailed(
969 "Empty query".to_string(),
970 ));
971 }
972
973 let stmts = parse(sql).map_err(|e| {
975 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
976 })?;
977
978 if stmts.len() != 1 {
980 return Err(McpRequestError::QueryValidationFailed(format!(
981 "Only one query allowed at a time. Found {} statements.",
982 stmts.len()
983 )));
984 }
985
986 let stmt = &stmts[0];
988 use mz_sql_parser::ast::Statement;
989
990 match &stmt.ast {
991 Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
992 Ok(())
994 }
995 _ => Err(McpRequestError::QueryValidationFailed(
996 "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
997 )),
998 }
999}
1000
1001async fn execute_query(
1002 client: &mut AuthedClient,
1003 cluster: &str,
1004 sql_query: &str,
1005 max_response_size: usize,
1006) -> Result<McpResult, McpRequestError> {
1007 debug!(cluster = %cluster, "Executing user query");
1008
1009 validate_readonly_query(sql_query)?;
1010
1011 let combined_query = format!(
1014 "BEGIN READ ONLY; SET CLUSTER = {}; {}\n; COMMIT;",
1015 escaped_string_literal(cluster),
1016 sql_query
1017 );
1018
1019 let rows = execute_sql(client, &combined_query).await?;
1020
1021 format_rows_response(rows, max_response_size)
1022}
1023
1024async fn query_system_catalog(
1025 client: &mut AuthedClient,
1026 sql_query: &str,
1027 max_response_size: usize,
1028) -> Result<McpResult, McpRequestError> {
1029 debug!("Executing query_system_catalog");
1030
1031 validate_readonly_query(sql_query)?;
1033
1034 validate_system_catalog_query(sql_query)?;
1036
1037 let combined_query = format!(
1043 "BEGIN READ ONLY; SET search_path = mz_catalog, mz_internal, pg_catalog, information_schema; {}; COMMIT;",
1044 sql_query
1045 );
1046
1047 let rows = execute_sql(client, &combined_query).await?;
1048
1049 format_rows_response(rows, max_response_size)
1050}
1051
/// AST visitor state that records which tables a statement references.
struct TableReferenceCollector {
    /// Referenced tables as `(schema, table)` pairs; `schema` is `None` for
    /// unqualified names.
    tables: Vec<(Option<String>, String)>,
    /// Names of CTEs seen so far, lowercased.
    cte_names: std::collections::BTreeSet<String>,
}

impl TableReferenceCollector {
    /// Creates a collector with no recorded tables or CTE names.
    fn new() -> Self {
        let tables = Vec::new();
        let cte_names = std::collections::BTreeSet::new();
        Self { tables, cte_names }
    }
}
1068
1069impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1070 fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1071 self.cte_names
1073 .insert(cte.alias.name.as_str().to_lowercase());
1074 visit::visit_cte(self, cte);
1075 }
1076
1077 fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1078 if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1080 match name {
1081 RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1082 let parts = &n.0;
1083 if !parts.is_empty() {
1084 let table_name = parts.last().unwrap().as_str().to_lowercase();
1085
1086 if self.cte_names.contains(&table_name) {
1088 visit::visit_table_factor(self, table_factor);
1089 return;
1090 }
1091
1092 let schema = if parts.len() >= 2 {
1094 Some(parts[parts.len() - 2].as_str().to_lowercase())
1095 } else {
1096 None
1097 };
1098 self.tables.push((schema, table_name));
1099 }
1100 }
1101 }
1102 }
1103 visit::visit_table_factor(self, table_factor);
1104 }
1105}
1106
1107fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1115 let stmts = parse(sql).map_err(|e| {
1117 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1118 })?;
1119
1120 if stmts.is_empty() {
1121 return Err(McpRequestError::QueryValidationFailed(
1122 "Empty query".to_string(),
1123 ));
1124 }
1125
1126 let mut collector = TableReferenceCollector::new();
1128 for stmt in &stmts {
1129 collector.visit_statement(&stmt.ast);
1130 }
1131
1132 let is_allowed_schema =
1135 |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1136
1137 let is_system_table = |(schema, table_name): &(Option<String>, String)| match schema {
1142 Some(s) => is_allowed_schema(s.as_str()),
1143 None => table_name.starts_with("mz_"),
1144 };
1145
1146 let non_system_tables: Vec<String> = collector
1148 .tables
1149 .iter()
1150 .filter(|t| !is_system_table(t))
1151 .map(|(schema, table)| match schema {
1152 Some(s) => format!("{}.{}", s, table),
1153 None => table.clone(),
1154 })
1155 .collect();
1156
1157 if !non_system_tables.is_empty() {
1158 return Err(McpRequestError::QueryValidationFailed(format!(
1159 "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1160 non_system_tables.join(", ")
1161 )));
1162 }
1163
1164 use mz_sql_parser::ast::Statement;
1167 let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1168
1169 if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1170 return Err(McpRequestError::QueryValidationFailed(
1171 "Query must reference at least one system catalog table".to_string(),
1172 ));
1173 }
1174
1175 Ok(())
1176}
1177
1178#[cfg(test)]
1179mod tests {
1180 use super::*;
1181
1182 #[mz_ore::test]
1183 fn test_validate_readonly_query_select() {
1184 assert!(validate_readonly_query("SELECT * FROM mz_tables").is_ok());
1185 assert!(validate_readonly_query("SELECT 1 + 2").is_ok());
1186 assert!(validate_readonly_query(" SELECT 1 ").is_ok());
1187 }
1188
1189 #[mz_ore::test]
1190 fn test_validate_readonly_query_subqueries() {
1191 assert!(
1193 validate_readonly_query(
1194 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1195 )
1196 .is_ok()
1197 );
1198
1199 assert!(
1201 validate_readonly_query(
1202 "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1203 )
1204 .is_ok()
1205 );
1206
1207 assert!(validate_readonly_query(
1209 "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)"
1210 )
1211 .is_ok());
1212
1213 assert!(validate_readonly_query(
1215 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))"
1216 )
1217 .is_ok());
1218
1219 assert!(
1221 validate_readonly_query(
1222 "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)"
1223 )
1224 .is_ok()
1225 );
1226 }
1227
1228 #[mz_ore::test]
1229 fn test_validate_readonly_query_show() {
1230 assert!(validate_readonly_query("SHOW CLUSTERS").is_ok());
1231 assert!(validate_readonly_query("SHOW TABLES").is_ok());
1232 }
1233
1234 #[mz_ore::test]
1235 fn test_validate_readonly_query_explain() {
1236 assert!(validate_readonly_query("EXPLAIN SELECT 1").is_ok());
1237 }
1238
1239 #[mz_ore::test]
1240 fn test_validate_readonly_query_rejects_writes() {
1241 assert!(validate_readonly_query("INSERT INTO t VALUES (1)").is_err());
1242 assert!(validate_readonly_query("UPDATE t SET a = 1").is_err());
1243 assert!(validate_readonly_query("DELETE FROM t").is_err());
1244 assert!(validate_readonly_query("CREATE TABLE t (a INT)").is_err());
1245 assert!(validate_readonly_query("DROP TABLE t").is_err());
1246 }
1247
1248 #[mz_ore::test]
1249 fn test_validate_readonly_query_rejects_multiple() {
1250 assert!(validate_readonly_query("SELECT 1; SELECT 2").is_err());
1251 }
1252
1253 #[mz_ore::test]
1254 fn test_validate_readonly_query_rejects_empty() {
1255 assert!(validate_readonly_query("").is_err());
1256 assert!(validate_readonly_query(" ").is_err());
1257 }
1258
1259 #[mz_ore::test]
1260 fn test_validate_system_catalog_query_accepts_mz_tables() {
1261 assert!(validate_system_catalog_query("SELECT * FROM mz_tables").is_ok());
1262 assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_comments").is_ok());
1263 assert!(
1264 validate_system_catalog_query(
1265 "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id"
1266 )
1267 .is_ok()
1268 );
1269 }
1270
1271 #[mz_ore::test]
1272 fn test_validate_system_catalog_query_subqueries() {
1273 assert!(
1275 validate_system_catalog_query(
1276 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1277 )
1278 .is_ok()
1279 );
1280
1281 assert!(validate_system_catalog_query(
1283 "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))"
1284 )
1285 .is_ok());
1286
1287 assert!(
1289 validate_system_catalog_query(
1290 "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1291 )
1292 .is_ok()
1293 );
1294
1295 assert!(
1297 validate_system_catalog_query(
1298 "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)"
1299 )
1300 .is_err()
1301 );
1302
1303 assert!(validate_system_catalog_query(
1305 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)"
1306 )
1307 .is_err());
1308 }
1309
1310 #[mz_ore::test]
1311 fn test_validate_system_catalog_query_rejects_user_tables() {
1312 assert!(validate_system_catalog_query("SELECT * FROM user_data").is_err());
1313 assert!(validate_system_catalog_query("SELECT * FROM my_table").is_err());
1314 assert!(
1316 validate_system_catalog_query("SELECT * FROM user_data WHERE 'mz_' IS NOT NULL")
1317 .is_err()
1318 );
1319 }
1320
1321 #[mz_ore::test]
1322 fn test_validate_system_catalog_query_allows_functions() {
1323 assert!(
1325 validate_system_catalog_query(
1326 "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1"
1327 )
1328 .is_ok()
1329 );
1330 assert!(validate_system_catalog_query("SELECT length(name) FROM mz_tables").is_ok());
1331 assert!(
1332 validate_system_catalog_query(
1333 "SELECT count(*) FROM mz_sources WHERE now() > created_at"
1334 )
1335 .is_ok()
1336 );
1337 }
1338
1339 #[mz_ore::test]
1340 fn test_validate_system_catalog_query_schema_qualified() {
1341 assert!(validate_system_catalog_query("SELECT * FROM mz_catalog.mz_tables").is_ok());
1343 assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_sessions").is_ok());
1344 assert!(validate_system_catalog_query("SELECT * FROM pg_catalog.pg_type").is_ok());
1345 assert!(validate_system_catalog_query("SELECT * FROM information_schema.tables").is_ok());
1346
1347 assert!(validate_system_catalog_query("SELECT * FROM public.user_table").is_err());
1349 assert!(validate_system_catalog_query("SELECT * FROM myschema.mytable").is_err());
1350
1351 assert!(
1353 validate_system_catalog_query("SELECT * FROM mz_unsafe.mz_some_table").is_err(),
1354 "mz_unsafe schema should be blocked even though it is a system schema"
1355 );
1356
1357 assert!(
1359 validate_system_catalog_query(
1360 "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"
1361 )
1362 .is_err()
1363 );
1364 }
1365
1366 #[mz_ore::test]
1367 fn test_validate_system_catalog_query_adversarial_cases() {
1368 assert!(
1370 validate_system_catalog_query(
1371 "WITH user_cte AS (SELECT * FROM user_data) \
1372 SELECT * FROM mz_tables, user_cte"
1373 )
1374 .is_err(),
1375 "Should reject CTE referencing user table"
1376 );
1377
1378 assert!(
1380 validate_system_catalog_query(
1381 "WITH \
1382 cte1 AS (SELECT * FROM mz_tables), \
1383 cte2 AS (SELECT * FROM cte1), \
1384 cte3 AS (SELECT * FROM user_data) \
1385 SELECT * FROM cte2"
1386 )
1387 .is_err(),
1388 "Should reject CTE chain with user table"
1389 );
1390
1391 assert!(
1393 validate_system_catalog_query(
1394 "SELECT * FROM mz_tables t1 \
1395 JOIN user_data u ON t1.id = u.id \
1396 JOIN mz_sources s ON t1.id = s.id"
1397 )
1398 .is_err(),
1399 "Should reject multi-join with user table"
1400 );
1401
1402 assert!(
1404 validate_system_catalog_query(
1405 "SELECT * FROM mz_tables t \
1406 LEFT JOIN user_data u ON t.id = u.table_id \
1407 WHERE u.id IS NULL"
1408 )
1409 .is_err(),
1410 "Should reject LEFT JOIN with user table"
1411 );
1412
1413 assert!(
1415 validate_system_catalog_query(
1416 "SELECT * FROM mz_tables WHERE id IN \
1417 (SELECT table_id FROM (SELECT * FROM user_data) AS u)"
1418 )
1419 .is_err(),
1420 "Should reject nested subquery with user table"
1421 );
1422
1423 assert!(
1425 validate_system_catalog_query(
1426 "SELECT name FROM mz_tables \
1427 UNION \
1428 SELECT name FROM user_data"
1429 )
1430 .is_err(),
1431 "Should reject UNION with user table"
1432 );
1433
1434 assert!(
1436 validate_system_catalog_query(
1437 "SELECT id FROM mz_sources \
1438 UNION ALL \
1439 SELECT id FROM products"
1440 )
1441 .is_err(),
1442 "Should reject UNION ALL with user table"
1443 );
1444
1445 assert!(
1447 validate_system_catalog_query("SELECT * FROM mz_tables CROSS JOIN user_data").is_err(),
1448 "Should reject CROSS JOIN with user table"
1449 );
1450
1451 assert!(
1453 validate_system_catalog_query(
1454 "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t"
1455 )
1456 .is_err(),
1457 "Should reject subquery in SELECT with user table"
1458 );
1459
1460 assert!(
1462 validate_system_catalog_query("SELECT * FROM mz_catalogg.fake_table").is_err(),
1463 "Should reject typo-squatting schema name"
1464 );
1465 assert!(
1466 validate_system_catalog_query("SELECT * FROM mz_catalog_hack.fake_table").is_err(),
1467 "Should reject fake schema with mz_catalog prefix"
1468 );
1469
1470 assert!(
1472 validate_system_catalog_query(
1473 "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u"
1474 )
1475 .is_err(),
1476 "Should reject LATERAL join with user table"
1477 );
1478
1479 assert!(
1481 validate_system_catalog_query(
1482 "WITH \
1483 tables AS (SELECT * FROM mz_tables), \
1484 sources AS (SELECT * FROM mz_sources) \
1485 SELECT t.name, s.name \
1486 FROM tables t \
1487 JOIN sources s ON t.id = s.id \
1488 WHERE t.id IN (SELECT id FROM mz_columns)"
1489 )
1490 .is_ok(),
1491 "Should allow complex query with only system tables"
1492 );
1493
1494 assert!(
1496 validate_system_catalog_query(
1497 "SELECT name FROM mz_tables \
1498 UNION \
1499 SELECT name FROM mz_sources"
1500 )
1501 .is_ok(),
1502 "Should allow UNION of system tables"
1503 );
1504 }
1505
1506 #[mz_ore::test]
1507 fn test_validate_system_catalog_query_rejects_constant_queries() {
1508 assert!(
1511 validate_system_catalog_query("SELECT 1").is_err(),
1512 "Should reject constant SELECT with no table references"
1513 );
1514 assert!(
1515 validate_system_catalog_query("SELECT 1 + 2, 'hello'").is_err(),
1516 "Should reject constant expression SELECT"
1517 );
1518 assert!(
1519 validate_system_catalog_query("SELECT now()").is_err(),
1520 "Should reject function-only SELECT with no table references"
1521 );
1522 }
1523
1524 #[mz_ore::test]
1525 fn test_validate_system_catalog_query_rejects_mixed_tables() {
1526 assert!(
1527 validate_system_catalog_query(
1528 "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"
1529 )
1530 .is_err()
1531 );
1532 }
1533
1534 #[mz_ore::test]
1535 fn test_validate_system_catalog_query_allows_show() {
1536 assert!(
1538 validate_system_catalog_query("SHOW TABLES FROM mz_internal").is_ok(),
1539 "SHOW TABLES FROM mz_internal should be allowed"
1540 );
1541 assert!(
1542 validate_system_catalog_query("SHOW TABLES FROM mz_catalog").is_ok(),
1543 "SHOW TABLES FROM mz_catalog should be allowed"
1544 );
1545 assert!(
1546 validate_system_catalog_query("SHOW CLUSTERS").is_ok(),
1547 "SHOW CLUSTERS should be allowed"
1548 );
1549 assert!(
1550 validate_system_catalog_query("SHOW SOURCES").is_ok(),
1551 "SHOW SOURCES should be allowed"
1552 );
1553 assert!(
1554 validate_system_catalog_query("SHOW TABLES").is_ok(),
1555 "SHOW TABLES should be allowed"
1556 );
1557 }
1558
1559 #[mz_ore::test]
1560 fn test_validate_system_catalog_query_allows_explain() {
1561 assert!(
1562 validate_system_catalog_query("EXPLAIN SELECT * FROM mz_tables").is_ok(),
1563 "EXPLAIN of system table query should be allowed"
1564 );
1565 assert!(
1566 validate_system_catalog_query("EXPLAIN SELECT 1").is_ok(),
1567 "EXPLAIN SELECT 1 should be allowed"
1568 );
1569 }
1570
1571 #[mz_ore::test(tokio::test)]
1574 async fn test_tools_list_agent_query_tool_disabled() {
1575 let result = handle_tools_list(McpEndpointType::Agent, false, 1_000_000)
1576 .await
1577 .unwrap();
1578 let McpResult::ToolsList(list) = result else {
1579 panic!("Expected ToolsList result");
1580 };
1581 let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1582 assert!(
1583 tool_names.contains(&"get_data_products"),
1584 "get_data_products should always be present"
1585 );
1586 assert!(
1587 tool_names.contains(&"get_data_product_details"),
1588 "get_data_product_details should always be present"
1589 );
1590 assert!(
1591 tool_names.contains(&"read_data_product"),
1592 "read_data_product should always be present"
1593 );
1594 assert!(
1595 !tool_names.contains(&"query"),
1596 "query tool should be hidden when disabled"
1597 );
1598 }
1599
1600 #[mz_ore::test(tokio::test)]
1601 async fn test_tools_list_agent_query_tool_enabled() {
1602 let result = handle_tools_list(McpEndpointType::Agent, true, 1_000_000)
1603 .await
1604 .unwrap();
1605 let McpResult::ToolsList(list) = result else {
1606 panic!("Expected ToolsList result");
1607 };
1608 let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1609 assert!(
1610 tool_names.contains(&"get_data_products"),
1611 "get_data_products should always be present"
1612 );
1613 assert!(
1614 tool_names.contains(&"get_data_product_details"),
1615 "get_data_product_details should always be present"
1616 );
1617 assert!(
1618 tool_names.contains(&"read_data_product"),
1619 "read_data_product should always be present"
1620 );
1621 assert!(
1622 tool_names.contains(&"query"),
1623 "query tool should be present when enabled"
1624 );
1625 }
1626
1627 #[mz_ore::test(tokio::test)]
1628 async fn test_tools_list_developer_unaffected_by_query_flag() {
1629 for flag in [true, false] {
1631 let result = handle_tools_list(McpEndpointType::Developer, flag, 1_000_000)
1632 .await
1633 .unwrap();
1634 let McpResult::ToolsList(list) = result else {
1635 panic!("Expected ToolsList result");
1636 };
1637 let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1638 assert!(
1639 tool_names.contains(&"query_system_catalog"),
1640 "query_system_catalog should always be present on developer"
1641 );
1642 assert!(
1643 !tool_names.contains(&"query"),
1644 "query tool should never appear on developer"
1645 );
1646 }
1647 }
1648
1649 #[mz_ore::test]
1652 fn test_format_rows_response_within_limit() {
1653 let rows = vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
1654 let result = format_rows_response(rows, 1_000_000).unwrap();
1655 let McpResult::ToolContent(content) = result else {
1656 panic!("Expected ToolContent");
1657 };
1658 assert_eq!(content.content.len(), 1);
1659 assert!(content.content[0].text.contains("\"a\""));
1660 assert!(content.content[0].text.contains("\"b\""));
1661 }
1662
1663 #[mz_ore::test]
1664 fn test_format_rows_response_errors_when_over_limit() {
1665 let rows: Vec<Vec<serde_json::Value>> = (0..100)
1666 .map(|i| vec![json!(format!("row_{}", i)), json!(i)])
1667 .collect();
1668 let err = format_rows_response(rows, 500).unwrap_err();
1669 let msg = err.to_string();
1670 assert!(
1671 msg.contains("exceeds the 500 byte limit"),
1672 "Error should mention the size limit, got: {msg}"
1673 );
1674 assert!(
1675 msg.contains("Use LIMIT or WHERE"),
1676 "Error should suggest narrowing the query, got: {msg}"
1677 );
1678 }
1679
1680 #[mz_ore::test]
1681 fn test_format_rows_response_empty_rows() {
1682 let rows: Vec<Vec<serde_json::Value>> = vec![];
1683 let result = format_rows_response(rows, 1000).unwrap();
1684 let McpResult::ToolContent(content) = result else {
1685 panic!("Expected ToolContent");
1686 };
1687 assert_eq!(content.content.len(), 1);
1688 assert_eq!(content.content[0].text, "[]");
1689 }
1690
1691 #[mz_ore::test]
1694 fn test_safe_data_product_name_valid() {
1695 assert_eq!(
1697 safe_data_product_name(r#""materialize"."public"."my_view""#).unwrap(),
1698 r#""materialize"."public"."my_view""#
1699 );
1700 assert_eq!(
1702 safe_data_product_name(r#""public"."my_view""#).unwrap(),
1703 r#""public"."my_view""#
1704 );
1705 assert_eq!(safe_data_product_name("my_view").unwrap(), r#""my_view""#);
1707 }
1708
1709 #[mz_ore::test]
1710 fn test_safe_data_product_name_rejects_empty() {
1711 assert!(safe_data_product_name("").is_err());
1712 assert!(safe_data_product_name(" ").is_err());
1713 }
1714
1715 #[mz_ore::test]
1716 fn test_safe_data_product_name_rejects_sql_injection() {
1717 assert!(safe_data_product_name("my_view; DROP TABLE users").is_err());
1719 assert!(safe_data_product_name("my_view UNION SELECT * FROM secrets").is_err());
1721 assert!(safe_data_product_name("my_view, secrets").is_err());
1723 assert!(safe_data_product_name("my_view WHERE 1=1 --").is_err());
1725 }
1726
1727 #[mz_ore::test]
1728 fn test_mcp_error_codes() {
1729 assert_eq!(
1730 McpRequestError::InvalidJsonRpcVersion.error_code(),
1731 error_codes::INVALID_REQUEST
1732 );
1733 assert_eq!(
1734 McpRequestError::MethodNotFound("test".to_string()).error_code(),
1735 error_codes::METHOD_NOT_FOUND
1736 );
1737 assert_eq!(
1738 McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
1739 error_codes::INTERNAL_ERROR
1740 );
1741 }
1742}