1use std::sync::Arc;
27use std::time::Duration;
28
29use anyhow::anyhow;
30use axum::Extension;
31use axum::Json;
32use axum::response::IntoResponse;
33use http::{HeaderMap, HeaderValue, StatusCode};
34use mz_adapter_types::dyncfgs::{
35 ENABLE_MCP_AGENT, ENABLE_MCP_AGENT_QUERY_TOOL, ENABLE_MCP_DEVELOPER, MCP_MAX_RESPONSE_SIZE,
36};
37use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
38use mz_sql::parse::parse;
39use mz_sql::session::metadata::SessionMetadata;
40use mz_sql::session::vars::{APPLICATION_NAME, Var, VarInput};
41use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
42use mz_sql_parser::ast::visit::{self, Visit};
43use mz_sql_parser::ast::{Raw, RawItemName};
44use mz_sql_parser::parser::parse_item_name;
45use serde::{Deserialize, Serialize};
46use serde_json::json;
47use thiserror::Error;
48use tracing::{debug, warn};
49
50use crate::http::AuthedClient;
51use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
52
/// JSON-RPC version required in requests and sent in every response.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol revision advertised in the `initialize` result.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Maximum time spent handling one request; when it elapses the client receives a JSON-RPC
/// error instead of a result.
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Query backing the `get_data_products` tool.
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
/// Query backing the `get_data_product_details` tool; the requested name is appended as an
/// escaped string literal.
const DETAILS_QUERY_PREFIX: &str =
    "SELECT * FROM mz_internal.mz_mcp_data_product_details WHERE object_name = ";
73
74#[derive(Debug, Error)]
76enum McpRequestError {
77 #[error("Invalid JSON-RPC version: expected 2.0")]
78 InvalidJsonRpcVersion,
79 #[error("Method not found: {0}")]
80 #[allow(dead_code)] MethodNotFound(String),
82 #[error("Tool not found: {0}")]
83 ToolNotFound(String),
84 #[error("Data product not found: {0}")]
85 DataProductNotFound(String),
86 #[error("Query validation failed: {0}")]
87 QueryValidationFailed(String),
88 #[error("Query execution failed: {0}")]
89 QueryExecutionFailed(String),
90 #[error("Internal error: {0}")]
91 Internal(#[from] anyhow::Error),
92}
93
94impl McpRequestError {
95 fn error_code(&self) -> i32 {
96 match self {
97 Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
98 Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
99 Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
100 Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
101 Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
102 Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
103 }
104 }
105
106 fn error_type(&self) -> &'static str {
107 match self {
108 Self::InvalidJsonRpcVersion => "InvalidRequest",
109 Self::MethodNotFound(_) => "MethodNotFound",
110 Self::ToolNotFound(_) => "ToolNotFound",
111 Self::DataProductNotFound(_) => "DataProductNotFound",
112 Self::QueryValidationFailed(_) => "ValidationError",
113 Self::QueryExecutionFailed(_) => "ExecutionError",
114 Self::Internal(_) => "InternalError",
115 }
116 }
117}
118
/// An incoming JSON-RPC 2.0 message.
///
/// A message without an `id` is a notification and receives no JSON-RPC response.
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    /// Must equal [`JSONRPC_VERSION`]; checked in `handle_mcp_method`.
    jsonrpc: String,
    /// Request identifier echoed back in the response; `None` for notifications.
    id: Option<serde_json::Value>,
    /// The `method` and `params` members, decoded together.
    #[serde(flatten)]
    method: McpMethod,
}

/// JSON-RPC methods handled by the MCP endpoints.
///
/// Adjacently tagged: the `method` member selects the variant and `params` carries its payload.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// Session handshake. The params are deserialized but not read by the handler.
    #[serde(rename = "initialize")]
    Initialize(#[allow(dead_code)] InitializeParams),
    /// Lists the tools available on the endpoint.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// Invokes one tool.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Any other method name; answered with a "method not found" error.
    #[serde(other)]
    Unknown,
}
143
144impl std::fmt::Display for McpMethod {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 McpMethod::Initialize(_) => write!(f, "initialize"),
148 McpMethod::ToolsList => write!(f, "tools/list"),
149 McpMethod::ToolsCall(_) => write!(f, "tools/call"),
150 McpMethod::Unknown => write!(f, "unknown"),
151 }
152 }
153}
154
/// Parameters of the `initialize` request.
///
/// They are deserialized but never read by the server (the handler matches `Initialize(_)`),
/// hence the `dead_code` allowances.
#[derive(Debug, Deserialize)]
struct InitializeParams {
    /// MCP protocol revision requested by the client (JSON key `protocolVersion`).
    #[serde(rename = "protocolVersion")]
    #[allow(dead_code)]
    protocol_version: String,
    /// Client capabilities object; JSON `null` when omitted.
    #[serde(default)]
    #[allow(dead_code)]
    capabilities: serde_json::Value,
    /// Optional client identification (JSON key `clientInfo`).
    #[serde(rename = "clientInfo")]
    #[allow(dead_code)]
    client_info: Option<ClientInfo>,
}

/// Client name and version sent with `initialize`; unused beyond deserialization.
#[derive(Debug, Deserialize)]
struct ClientInfo {
    #[allow(dead_code)]
    name: String,
    #[allow(dead_code)]
    version: String,
}
178
/// A `tools/call` request: the tool to invoke and its arguments.
///
/// Adjacently tagged: `name` selects the variant (the snake_case form of the variant name,
/// e.g. `read_data_product`) and `arguments` holds that tool's parameters.
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    /// Agent tool that takes no arguments; `#[serde(default)]` presumably lets clients omit
    /// `arguments` entirely — TODO confirm.
    GetDataProducts(#[serde(default)] ()),
    /// Agent tool: schema of one data product.
    GetDataProductDetails(GetDataProductDetailsParams),
    /// Agent tool: read rows from one data product.
    ReadDataProduct(ReadDataProductParams),
    /// Agent tool: read-only SQL on a chosen cluster (gated by `ENABLE_MCP_AGENT_QUERY_TOOL`).
    Query(QueryParams),
    /// Developer tool: read-only SQL restricted to system catalog relations.
    QuerySystemCatalog(QuerySystemCatalogParams),
}
194
195impl std::fmt::Display for ToolsCallParams {
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 match self {
198 ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
199 ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
200 ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
201 ToolsCallParams::Query(_) => write!(f, "query"),
202 ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
203 }
204 }
205}
206
/// Arguments of the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    /// Data product name, compared against `object_name` in
    /// `mz_internal.mz_mcp_data_product_details`.
    name: String,
}

/// Arguments of the `read_data_product` tool.
#[derive(Debug, Deserialize)]
struct ReadDataProductParams {
    /// Name of the data product to read; validated and canonicalized before use.
    name: String,
    /// Requested row count; defaults to [`default_read_limit`] and is clamped to
    /// [`MAX_READ_LIMIT`] by `read_data_product`.
    #[serde(default = "default_read_limit")]
    limit: u32,
    /// Optional cluster to run the read on (applied with `SET CLUSTER`).
    cluster: Option<String>,
}

/// Default `limit` for `read_data_product` when the caller omits it.
fn default_read_limit() -> u32 {
    500
}

/// Upper bound on the rows returned by a single `read_data_product` call.
const MAX_READ_LIMIT: u32 = 1000;

/// Arguments of the agent `query` tool.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Cluster the query runs on (applied with `SET CLUSTER`).
    cluster: String,
    /// A single read-only statement; checked by `validate_readonly_query`.
    sql_query: String,
}

/// Arguments of the developer `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    /// A single read-only statement that may only reference system catalog relations.
    sql_query: String,
}
237
/// A JSON-RPC 2.0 response envelope.
///
/// Every construction site in this file sets exactly one of `result` / `error`; the other is
/// `None` and is omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct McpResponse {
    /// Always [`JSONRPC_VERSION`].
    jsonrpc: String,
    /// The request's `id`, echoed back (JSON `null` if the request had none).
    id: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}

/// Successful result payloads. `untagged`, so only the inner value is serialized.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    /// Result of `initialize`.
    Initialize(InitializeResult),
    /// Result of `tools/list`.
    ToolsList(ToolsListResult),
    /// Result of a `tools/call` invocation.
    ToolContent(ToolContentResult),
}
256
/// Result of the `initialize` method.
#[derive(Debug, Serialize)]
struct InitializeResult {
    /// MCP protocol revision offered by the server (JSON key `protocolVersion`).
    #[serde(rename = "protocolVersion")]
    protocol_version: String,
    capabilities: Capabilities,
    /// Server identity (JSON key `serverInfo`).
    #[serde(rename = "serverInfo")]
    server_info: ServerInfo,
    /// Optional usage guidance for the client; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}

/// Server capabilities advertised during `initialize`.
#[derive(Debug, Serialize)]
struct Capabilities {
    /// Tools capability object; `handle_initialize` sends an empty object.
    tools: serde_json::Value,
}

/// Server name and version reported during `initialize`.
#[derive(Debug, Serialize)]
struct ServerInfo {
    name: String,
    version: String,
}

/// Result of the `tools/list` method.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    tools: Vec<ToolDefinition>,
}
283
/// Description of one tool, as returned by `tools/list`.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    /// Tool name, matching the `name` clients send in `tools/call`.
    name: String,
    /// Human-readable display title; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    title: Option<String>,
    description: String,
    /// JSON Schema describing the tool's `arguments` (JSON key `inputSchema`).
    #[serde(rename = "inputSchema")]
    input_schema: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    annotations: Option<ToolAnnotations>,
}

/// Behavioral hints about a tool; each hint is omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
struct ToolAnnotations {
    #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
    read_only_hint: Option<bool>,
    #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
    destructive_hint: Option<bool>,
    #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
    idempotent_hint: Option<bool>,
    #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
    open_world_hint: Option<bool>,
}

/// Annotations attached to every tool listed by this server: read-only, non-destructive,
/// idempotent, and `openWorldHint: false`.
const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
    read_only_hint: Some(true),
    destructive_hint: Some(false),
    idempotent_hint: Some(true),
    open_world_hint: Some(false),
};
317
/// Result of a `tools/call` invocation.
#[derive(Debug, Serialize)]
struct ToolContentResult {
    content: Vec<ContentBlock>,
    /// Serialized as `isError`; this file only constructs results with `false`.
    #[serde(rename = "isError")]
    is_error: bool,
}

/// One content item of a tool result.
#[derive(Debug, Serialize)]
struct ContentBlock {
    /// Serialized as `type`; this file only produces `"text"` blocks.
    #[serde(rename = "type")]
    content_type: String,
    text: String,
}

/// Standard JSON-RPC 2.0 error codes used for `McpError::code`.
mod error_codes {
    /// The JSON sent is not a valid Request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// Invalid method parameter(s).
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: i32 = -32603;
}

/// The `error` member of a JSON-RPC response.
#[derive(Debug, Serialize)]
struct McpError {
    code: i32,
    message: String,
    /// Extra details; the conversion from `McpRequestError` stores `{"error_type": ...}` here.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
347
348impl From<McpRequestError> for McpError {
349 fn from(err: McpRequestError) -> Self {
350 McpError {
351 code: err.error_code(),
352 message: err.to_string(),
353 data: Some(json!({
354 "error_type": err.error_type(),
355 })),
356 }
357 }
358}
359
/// Which MCP endpoint a request arrived on; selects the feature flag and tool set.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// Endpoint exposing the data-product tools.
    Agent,
    /// Endpoint exposing the `query_system_catalog` tool.
    Developer,
}

impl std::fmt::Display for McpEndpointType {
    /// Writes the lowercase endpoint name (`agent` or `developer`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Agent => "agent",
            Self::Developer => "developer",
        })
    }
}
374
375pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
378 StatusCode::METHOD_NOT_ALLOWED
379}
380
381pub async fn handle_mcp_agent(
383 headers: HeaderMap,
384 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
385 client: AuthedClient,
386 Json(body): Json<McpRequest>,
387) -> axum::response::Response {
388 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
389 return resp;
390 }
391 handle_mcp_request(client, body, McpEndpointType::Agent)
392 .await
393 .into_response()
394}
395
396pub async fn handle_mcp_developer(
398 headers: HeaderMap,
399 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
400 client: AuthedClient,
401 Json(body): Json<McpRequest>,
402) -> axum::response::Response {
403 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
404 return resp;
405 }
406 handle_mcp_request(client, body, McpEndpointType::Developer)
407 .await
408 .into_response()
409}
410
411fn validate_origin(
420 headers: &HeaderMap,
421 allowed: &[HeaderValue],
422) -> Option<axum::response::Response> {
423 let origin = headers.get(http::header::ORIGIN)?;
424 if mz_http_util::origin_is_allowed(origin, allowed) {
425 return None;
426 }
427 warn!(
428 origin = ?origin,
429 "MCP request rejected: origin not in allowlist",
430 );
431 Some(StatusCode::FORBIDDEN.into_response())
432}
433
434async fn handle_mcp_request(
435 mut client: AuthedClient,
436 request: McpRequest,
437 endpoint_type: McpEndpointType,
438) -> impl IntoResponse {
439 let catalog = client.client.catalog_snapshot("mcp").await;
441 let dyncfgs = catalog.system_config().dyncfgs();
442 let enabled = match endpoint_type {
443 McpEndpointType::Agent => ENABLE_MCP_AGENT.get(dyncfgs),
444 McpEndpointType::Developer => ENABLE_MCP_DEVELOPER.get(dyncfgs),
445 };
446 if !enabled {
447 debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
448 return StatusCode::SERVICE_UNAVAILABLE.into_response();
449 }
450
451 let query_tool_enabled = ENABLE_MCP_AGENT_QUERY_TOOL.get(dyncfgs);
452 let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
453
454 let app_name = match endpoint_type {
458 McpEndpointType::Agent => "mz_mcp_agents",
459 McpEndpointType::Developer => "mz_mcp_developer",
460 };
461 client
462 .client
463 .session()
464 .vars_mut()
465 .set_default(APPLICATION_NAME.name(), VarInput::Flat(app_name))
466 .expect("application_name is a known session var");
467
468 let user = client.client.session().user().name.clone();
469 let is_notification = request.id.is_none();
470
471 debug!(
472 method = %request.method,
473 endpoint = %endpoint_type,
474 user = %user,
475 is_notification = is_notification,
476 "MCP request received"
477 );
478
479 if is_notification {
481 debug!(method = %request.method, "Received notification (no response will be sent)");
482 return StatusCode::OK.into_response();
483 }
484
485 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
486
487 let result = tokio::time::timeout(
489 MCP_REQUEST_TIMEOUT,
490 mz_ore::task::spawn(|| "mcp_request", async move {
491 handle_mcp_request_inner(
492 &mut client,
493 request,
494 endpoint_type,
495 query_tool_enabled,
496 max_response_size,
497 )
498 .await
499 }),
500 )
501 .await;
502
503 let response = match result {
504 Ok(inner) => inner,
505 Err(_elapsed) => {
506 warn!(
507 endpoint = %endpoint_type,
508 timeout = ?MCP_REQUEST_TIMEOUT,
509 "MCP request timed out",
510 );
511 McpResponse {
512 jsonrpc: JSONRPC_VERSION.to_string(),
513 id: request_id,
514 result: None,
515 error: Some(
516 McpRequestError::QueryExecutionFailed(format!(
517 "Request timed out after {} seconds.",
518 MCP_REQUEST_TIMEOUT.as_secs(),
519 ))
520 .into(),
521 ),
522 }
523 }
524 };
525
526 (StatusCode::OK, Json(response)).into_response()
527}
528
529async fn handle_mcp_request_inner(
530 client: &mut AuthedClient,
531 request: McpRequest,
532 endpoint_type: McpEndpointType,
533 query_tool_enabled: bool,
534 max_response_size: usize,
535) -> McpResponse {
536 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
538
539 let result = handle_mcp_method(
540 client,
541 &request,
542 endpoint_type,
543 query_tool_enabled,
544 max_response_size,
545 )
546 .await;
547
548 match result {
549 Ok(result_value) => McpResponse {
550 jsonrpc: JSONRPC_VERSION.to_string(),
551 id: request_id,
552 result: Some(result_value),
553 error: None,
554 },
555 Err(e) => {
556 if !matches!(
558 e,
559 McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
560 ) {
561 warn!(error = %e, method = %request.method, "MCP method execution failed");
562 }
563 McpResponse {
564 jsonrpc: JSONRPC_VERSION.to_string(),
565 id: request_id,
566 result: None,
567 error: Some(e.into()),
568 }
569 }
570 }
571}
572
573async fn handle_mcp_method(
574 client: &mut AuthedClient,
575 request: &McpRequest,
576 endpoint_type: McpEndpointType,
577 query_tool_enabled: bool,
578 max_response_size: usize,
579) -> Result<McpResult, McpRequestError> {
580 if request.jsonrpc != JSONRPC_VERSION {
582 return Err(McpRequestError::InvalidJsonRpcVersion);
583 }
584
585 match &request.method {
587 McpMethod::Initialize(_) => {
588 debug!(endpoint = %endpoint_type, "Processing initialize");
589 handle_initialize(endpoint_type).await
590 }
591 McpMethod::ToolsList => {
592 debug!(endpoint = %endpoint_type, "Processing tools/list");
593 handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
594 }
595 McpMethod::ToolsCall(params) => {
596 debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
597 handle_tools_call(
598 client,
599 params,
600 endpoint_type,
601 query_tool_enabled,
602 max_response_size,
603 )
604 .await
605 }
606 McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
607 "unknown method".to_string(),
608 )),
609 }
610}
611
612fn endpoint_instructions(endpoint_type: McpEndpointType) -> Option<String> {
615 match endpoint_type {
616 McpEndpointType::Agent => None,
617 McpEndpointType::Developer => Some(concat!(
618 "You are connected to the Materialize developer MCP server. ",
619 "You have read-only access to system catalog tables (mz_*, pg_catalog, information_schema) ",
620 "for troubleshooting and observability.\n\n",
621 "IMPORTANT: Before writing queries, discover table schemas using the mz_ontology tables:\n",
622 "- mz_internal.mz_ontology_entity_types: what catalog entities exist and which tables they map to\n",
623 "- mz_internal.mz_ontology_link_types: relationships between entities (foreign keys, metrics, etc.)\n",
624 "- mz_internal.mz_ontology_properties: column names, types, and descriptions for each entity\n",
625 "- mz_internal.mz_ontology_semantic_types: typed ID domains (CatalogItemId, ReplicaId, etc.)\n\n",
626 "Use these to find the correct tables, join paths, and column names instead of guessing.\n\n",
627 "Key rules:\n",
628 "- mz_source_statuses and mz_sink_statuses use `last_status_change_at` (NOT `updated_at`)\n",
629 "- mz_cluster_replica_utilization only has `replica_id` — JOIN with mz_cluster_replicas and mz_clusters to get names\n",
630 "- Do NOT query mz_introspection.mz_dataflow_arrangement_sizes — it is cluster-scoped and has uint8/text type mismatches\n",
631 "- Use SHOW COLUMNS FROM <table> to verify column names if unsure",
632 ).to_string()),
633 }
634}
635
636async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
637 Ok(McpResult::Initialize(InitializeResult {
638 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
639 capabilities: Capabilities { tools: json!({}) },
640 server_info: ServerInfo {
641 name: format!("materialize-mcp-{}", endpoint_type),
642 version: env!("CARGO_PKG_VERSION").to_string(),
643 },
644 instructions: endpoint_instructions(endpoint_type),
645 }))
646}
647
648async fn handle_tools_list(
649 endpoint_type: McpEndpointType,
650 query_tool_enabled: bool,
651 max_response_size: usize,
652) -> Result<McpResult, McpRequestError> {
653 let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
654
655 let tools = match endpoint_type {
656 McpEndpointType::Agent => {
657 let mut tools = vec![
658 ToolDefinition {
659 name: "get_data_products".to_string(),
660 title: Some("List Data Products".to_string()),
661 description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
662 input_schema: json!({
663 "type": "object",
664 "properties": {},
665 "required": []
666 }),
667 annotations: Some(READ_ONLY_ANNOTATIONS),
668 },
669 ToolDefinition {
670 name: "get_data_product_details".to_string(),
671 title: Some("Get Data Product Details".to_string()),
672 description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
673 input_schema: json!({
674 "type": "object",
675 "properties": {
676 "name": {
677 "type": "string",
678 "description": "Exact name of the data product from get_data_products() list"
679 }
680 },
681 "required": ["name"]
682 }),
683 annotations: Some(READ_ONLY_ANNOTATIONS),
684 },
685 ToolDefinition {
686 name: "read_data_product".to_string(),
687 title: Some("Read Data Product".to_string()),
688 description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500, max 1000). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
689 input_schema: json!({
690 "type": "object",
691 "properties": {
692 "name": {
693 "type": "string",
694 "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
695 },
696 "limit": {
697 "type": "integer",
698 "description": "Maximum number of rows to return (default 500, max 1000)",
699 "default": 500
700 },
701 "cluster": {
702 "type": "string",
703 "description": "Optional cluster override. If omitted, uses the cluster from the data product catalog."
704 }
705 },
706 "required": ["name"]
707 }),
708 annotations: Some(READ_ONLY_ANNOTATIONS),
709 },
710 ];
711 if query_tool_enabled {
712 tools.push(ToolDefinition {
713 name: "query".to_string(),
714 title: Some("Query Data Products".to_string()),
715 description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
716 input_schema: json!({
717 "type": "object",
718 "properties": {
719 "cluster": {
720 "type": "string",
721 "description": "Exact cluster name from the data product details - required for query execution"
722 },
723 "sql_query": {
724 "type": "string",
725 "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
726 }
727 },
728 "required": ["cluster", "sql_query"]
729 }),
730 annotations: Some(READ_ONLY_ANNOTATIONS),
731 });
732 }
733 tools
734 }
735 McpEndpointType::Developer => {
736 vec![ToolDefinition {
737 name: "query_system_catalog".to_string(),
738 title: Some("Query System Catalog".to_string()),
739 description: concat!(
740 "Query Materialize system catalog tables for troubleshooting and observability. ",
741 "Only mz_*, pg_catalog, and information_schema tables are accessible. ",
742 "Use the mz_internal.mz_ontology_* tables to discover tables, columns, and join paths before writing queries.",
743 ).to_owned() + &format!(" {size_hint}"),
744 input_schema: json!({
745 "type": "object",
746 "properties": {
747 "sql_query": {
748 "type": "string",
749 "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
750 }
751 },
752 "required": ["sql_query"]
753 }),
754 annotations: Some(READ_ONLY_ANNOTATIONS),
755 }]
756 }
757 };
758
759 Ok(McpResult::ToolsList(ToolsListResult { tools }))
760}
761
762async fn handle_tools_call(
763 client: &mut AuthedClient,
764 params: &ToolsCallParams,
765 endpoint_type: McpEndpointType,
766 query_tool_enabled: bool,
767 max_response_size: usize,
768) -> Result<McpResult, McpRequestError> {
769 match (endpoint_type, params) {
770 (McpEndpointType::Agent, ToolsCallParams::GetDataProducts(_)) => {
771 get_data_products(client, max_response_size).await
772 }
773 (McpEndpointType::Agent, ToolsCallParams::GetDataProductDetails(p)) => {
774 get_data_product_details(client, &p.name, max_response_size).await
775 }
776 (McpEndpointType::Agent, ToolsCallParams::ReadDataProduct(p)) => {
777 read_data_product(
778 client,
779 &p.name,
780 p.limit,
781 p.cluster.as_deref(),
782 max_response_size,
783 )
784 .await
785 }
786 (McpEndpointType::Agent, ToolsCallParams::Query(_)) if !query_tool_enabled => {
787 Err(McpRequestError::ToolNotFound(
788 "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
789 ))
790 }
791 (McpEndpointType::Agent, ToolsCallParams::Query(p)) => {
792 execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
793 }
794 (McpEndpointType::Developer, ToolsCallParams::QuerySystemCatalog(p)) => {
795 query_system_catalog(client, &p.sql_query, max_response_size).await
796 }
797 (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
799 "{} is not available on {} endpoint",
800 tool, endpoint
801 ))),
802 }
803}
804
805async fn execute_sql(
807 client: &mut AuthedClient,
808 query: &str,
809) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
810 let mut response = SqlResponse::new();
811
812 execute_request(
813 client,
814 SqlRequest::Simple {
815 query: query.to_string(),
816 },
817 &mut response,
818 )
819 .await
820 .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
821
822 for result in response.results {
825 match result {
826 SqlResult::Rows { rows, .. } => return Ok(rows),
827 SqlResult::Err { error, .. } => {
828 return Err(McpRequestError::QueryExecutionFailed(error.message));
829 }
830 SqlResult::Ok { .. } => continue,
831 }
832 }
833
834 Err(McpRequestError::QueryExecutionFailed(
835 "Query did not return any results".to_string(),
836 ))
837}
838
839fn format_rows_response(
846 rows: Vec<Vec<serde_json::Value>>,
847 max_size: usize,
848) -> Result<McpResult, McpRequestError> {
849 let text =
850 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
851
852 if text.len() > max_size {
853 return Err(McpRequestError::QueryExecutionFailed(format!(
854 "Response size ({} bytes) exceeds the {} byte limit. \
855 Use LIMIT or WHERE to narrow your query.",
856 text.len(),
857 max_size,
858 )));
859 }
860
861 Ok(McpResult::ToolContent(ToolContentResult {
862 content: vec![ContentBlock {
863 content_type: "text".to_string(),
864 text,
865 }],
866 is_error: false,
867 }))
868}
869
870async fn get_data_products(
871 client: &mut AuthedClient,
872 max_response_size: usize,
873) -> Result<McpResult, McpRequestError> {
874 debug!("Executing get_data_products");
875 let rows = execute_sql(client, DISCOVERY_QUERY).await?;
876 debug!("get_data_products returned {} rows", rows.len());
877
878 format_rows_response(rows, max_response_size)
879}
880
881async fn get_data_product_details(
882 client: &mut AuthedClient,
883 name: &str,
884 max_response_size: usize,
885) -> Result<McpResult, McpRequestError> {
886 debug!(name = %name, "Executing get_data_product_details");
887
888 let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
889
890 let rows = execute_sql(client, &query).await?;
891
892 if rows.is_empty() {
893 return Err(McpRequestError::DataProductNotFound(name.to_string()));
894 }
895
896 format_rows_response(rows, max_response_size)
897}
898
899fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
906 let name = name.trim();
907 if name.is_empty() {
908 return Err(McpRequestError::QueryValidationFailed(
909 "Data product name cannot be empty".to_string(),
910 ));
911 }
912
913 let parsed = parse_item_name(name).map_err(|_| {
914 McpRequestError::QueryValidationFailed(format!(
915 "Invalid data product name: {}. Expected a valid object name, \
916 e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
917 name
918 ))
919 })?;
920
921 Ok(parsed.to_ast_string_stable())
924}
925
926async fn read_data_product(
934 client: &mut AuthedClient,
935 name: &str,
936 limit: u32,
937 cluster_override: Option<&str>,
938 max_response_size: usize,
939) -> Result<McpResult, McpRequestError> {
940 debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
941
942 let safe_name = safe_data_product_name(name)?;
944
945 let lookup_query = format!(
953 "SELECT 1 FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
954 escaped_string_literal(name)
955 );
956 let lookup_rows = execute_sql(client, &lookup_query).await?;
957 if lookup_rows.is_empty() {
958 return Err(McpRequestError::DataProductNotFound(name.to_string()));
959 }
960
961 let clamped_limit = limit.min(MAX_READ_LIMIT);
962
963 let read_query = match cluster_override {
964 Some(cluster) => format!(
965 "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}\n; COMMIT;",
966 escaped_string_literal(cluster),
967 safe_name,
968 clamped_limit,
969 ),
970 None => format!("SELECT * FROM {} LIMIT {}", safe_name, clamped_limit),
972 };
973
974 let rows = execute_sql(client, &read_query).await?;
975
976 format_rows_response(rows, max_response_size)
977}
978
979fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
981 let sql = sql.trim();
982 if sql.is_empty() {
983 return Err(McpRequestError::QueryValidationFailed(
984 "Empty query".to_string(),
985 ));
986 }
987
988 let stmts = parse(sql).map_err(|e| {
990 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
991 })?;
992
993 if stmts.len() != 1 {
995 return Err(McpRequestError::QueryValidationFailed(format!(
996 "Only one query allowed at a time. Found {} statements.",
997 stmts.len()
998 )));
999 }
1000
1001 let stmt = &stmts[0];
1003 use mz_sql_parser::ast::Statement;
1004
1005 match &stmt.ast {
1006 Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
1007 Ok(())
1009 }
1010 _ => Err(McpRequestError::QueryValidationFailed(
1011 "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
1012 )),
1013 }
1014}
1015
1016async fn execute_query(
1017 client: &mut AuthedClient,
1018 cluster: &str,
1019 sql_query: &str,
1020 max_response_size: usize,
1021) -> Result<McpResult, McpRequestError> {
1022 debug!(cluster = %cluster, "Executing user query");
1023
1024 validate_readonly_query(sql_query)?;
1025
1026 let combined_query = format!(
1029 "BEGIN READ ONLY; SET CLUSTER = {}; {}\n; COMMIT;",
1030 escaped_string_literal(cluster),
1031 sql_query
1032 );
1033
1034 let rows = execute_sql(client, &combined_query).await?;
1035
1036 format_rows_response(rows, max_response_size)
1037}
1038
1039async fn query_system_catalog(
1040 client: &mut AuthedClient,
1041 sql_query: &str,
1042 max_response_size: usize,
1043) -> Result<McpResult, McpRequestError> {
1044 debug!("Executing query_system_catalog");
1045
1046 validate_readonly_query(sql_query)?;
1048
1049 validate_system_catalog_query(sql_query)?;
1051
1052 let combined_query = format!(
1058 "BEGIN READ ONLY; SET search_path = mz_catalog, mz_internal, pg_catalog, information_schema; {}; COMMIT;",
1059 sql_query
1060 );
1061
1062 let rows = execute_sql(client, &combined_query).await?;
1063
1064 format_rows_response(rows, max_response_size)
1065}
1066
1067struct TableReferenceCollector {
1069 tables: Vec<(Option<String>, String)>,
1071 cte_names: std::collections::BTreeSet<String>,
1073}
1074
1075impl TableReferenceCollector {
1076 fn new() -> Self {
1077 Self {
1078 tables: Vec::new(),
1079 cte_names: std::collections::BTreeSet::new(),
1080 }
1081 }
1082}
1083
1084impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1085 fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1086 self.cte_names
1088 .insert(cte.alias.name.as_str().to_lowercase());
1089 visit::visit_cte(self, cte);
1090 }
1091
1092 fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1093 if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1095 match name {
1096 RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1097 let parts = &n.0;
1098 if !parts.is_empty() {
1099 let table_name = parts.last().unwrap().as_str().to_lowercase();
1100
1101 if self.cte_names.contains(&table_name) {
1103 visit::visit_table_factor(self, table_factor);
1104 return;
1105 }
1106
1107 let schema = if parts.len() >= 2 {
1109 Some(parts[parts.len() - 2].as_str().to_lowercase())
1110 } else {
1111 None
1112 };
1113 self.tables.push((schema, table_name));
1114 }
1115 }
1116 }
1117 }
1118 visit::visit_table_factor(self, table_factor);
1119 }
1120}
1121
1122fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1130 let stmts = parse(sql).map_err(|e| {
1132 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1133 })?;
1134
1135 if stmts.is_empty() {
1136 return Err(McpRequestError::QueryValidationFailed(
1137 "Empty query".to_string(),
1138 ));
1139 }
1140
1141 let mut collector = TableReferenceCollector::new();
1143 for stmt in &stmts {
1144 collector.visit_statement(&stmt.ast);
1145 }
1146
1147 let is_allowed_schema =
1150 |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1151
1152 let is_system_table = |(schema, table_name): &(Option<String>, String)| match schema {
1157 Some(s) => is_allowed_schema(s.as_str()),
1158 None => table_name.starts_with("mz_"),
1159 };
1160
1161 let non_system_tables: Vec<String> = collector
1163 .tables
1164 .iter()
1165 .filter(|t| !is_system_table(t))
1166 .map(|(schema, table)| match schema {
1167 Some(s) => format!("{}.{}", s, table),
1168 None => table.clone(),
1169 })
1170 .collect();
1171
1172 if !non_system_tables.is_empty() {
1173 return Err(McpRequestError::QueryValidationFailed(format!(
1174 "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1175 non_system_tables.join(", ")
1176 )));
1177 }
1178
1179 use mz_sql_parser::ast::Statement;
1182 let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1183
1184 if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1185 return Err(McpRequestError::QueryValidationFailed(
1186 "Query must reference at least one system catalog table".to_string(),
1187 ));
1188 }
1189
1190 Ok(())
1191}
1192
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `validate_readonly_query` accepts every query in `queries`.
    /// The failure message names the offending SQL.
    #[track_caller]
    fn assert_readonly_accepts(queries: &[&str]) {
        for &sql in queries {
            assert!(
                validate_readonly_query(sql).is_ok(),
                "expected read-only validation to accept: {sql:?}"
            );
        }
    }

    /// Asserts that `validate_readonly_query` rejects every query in `queries`.
    #[track_caller]
    fn assert_readonly_rejects(queries: &[&str]) {
        for &sql in queries {
            assert!(
                validate_readonly_query(sql).is_err(),
                "expected read-only validation to reject: {sql:?}"
            );
        }
    }

    /// Asserts that `validate_system_catalog_query` accepts `sql`. On failure
    /// it reports the validator's error message.
    #[track_caller]
    fn assert_catalog_accepts(sql: &str) {
        if let Err(e) = validate_system_catalog_query(sql) {
            panic!("expected system catalog validation to accept {sql:?}, got: {e}");
        }
    }

    /// Asserts that `validate_system_catalog_query` rejects `sql`. `why`
    /// describes the case in the failure message.
    #[track_caller]
    fn assert_catalog_rejects(sql: &str, why: &str) {
        assert!(validate_system_catalog_query(sql).is_err(), "{why}: {sql:?}");
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_select() {
        assert_readonly_accepts(&["SELECT * FROM mz_tables", "SELECT 1 + 2", " SELECT 1 "]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_subqueries() {
        assert_readonly_accepts(&[
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
            "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))",
            "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_show() {
        assert_readonly_accepts(&["SHOW CLUSTERS", "SHOW TABLES"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_explain() {
        assert_readonly_accepts(&["EXPLAIN SELECT 1"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_writes() {
        assert_readonly_rejects(&[
            "INSERT INTO t VALUES (1)",
            "UPDATE t SET a = 1",
            "DELETE FROM t",
            "CREATE TABLE t (a INT)",
            "DROP TABLE t",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_multiple() {
        assert_readonly_rejects(&["SELECT 1; SELECT 2"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_empty() {
        assert_readonly_rejects(&["", " "]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_accepts_mz_tables() {
        for sql in [
            "SELECT * FROM mz_tables",
            "SELECT * FROM mz_internal.mz_comments",
            "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id",
        ] {
            assert_catalog_accepts(sql);
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_subqueries() {
        for sql in [
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))",
            "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
        ] {
            assert_catalog_accepts(sql);
        }
        for sql in [
            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)",
        ] {
            assert_catalog_rejects(sql, "Should reject subquery over a user table");
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_user_tables() {
        for sql in [
            "SELECT * FROM user_data",
            "SELECT * FROM my_table",
            "SELECT * FROM user_data WHERE 'mz_' IS NOT NULL",
        ] {
            assert_catalog_rejects(sql, "Should reject user table");
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_functions() {
        for sql in [
            "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1",
            "SELECT length(name) FROM mz_tables",
            "SELECT count(*) FROM mz_sources WHERE now() > created_at",
        ] {
            assert_catalog_accepts(sql);
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_schema_qualified() {
        for sql in [
            "SELECT * FROM mz_catalog.mz_tables",
            "SELECT * FROM mz_internal.mz_sessions",
            "SELECT * FROM pg_catalog.pg_type",
            "SELECT * FROM information_schema.tables",
        ] {
            assert_catalog_accepts(sql);
        }
        for (sql, why) in [
            ("SELECT * FROM public.user_table", "Should reject user schema"),
            ("SELECT * FROM myschema.mytable", "Should reject user schema"),
            (
                "SELECT * FROM mz_unsafe.mz_some_table",
                "mz_unsafe schema should be blocked even though it is a system schema",
            ),
            (
                "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true",
                "Should reject join with a user-schema table",
            ),
        ] {
            assert_catalog_rejects(sql, why);
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_adversarial_cases() {
        let rejected = [
            (
                "WITH user_cte AS (SELECT * FROM user_data) \
                 SELECT * FROM mz_tables, user_cte",
                "Should reject CTE referencing user table",
            ),
            (
                "WITH \
                 cte1 AS (SELECT * FROM mz_tables), \
                 cte2 AS (SELECT * FROM cte1), \
                 cte3 AS (SELECT * FROM user_data) \
                 SELECT * FROM cte2",
                "Should reject CTE chain with user table",
            ),
            (
                "SELECT * FROM mz_tables t1 \
                 JOIN user_data u ON t1.id = u.id \
                 JOIN mz_sources s ON t1.id = s.id",
                "Should reject multi-join with user table",
            ),
            (
                "SELECT * FROM mz_tables t \
                 LEFT JOIN user_data u ON t.id = u.table_id \
                 WHERE u.id IS NULL",
                "Should reject LEFT JOIN with user table",
            ),
            (
                "SELECT * FROM mz_tables WHERE id IN \
                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)",
                "Should reject nested subquery with user table",
            ),
            (
                "SELECT name FROM mz_tables \
                 UNION \
                 SELECT name FROM user_data",
                "Should reject UNION with user table",
            ),
            (
                "SELECT id FROM mz_sources \
                 UNION ALL \
                 SELECT id FROM products",
                "Should reject UNION ALL with user table",
            ),
            (
                "SELECT * FROM mz_tables CROSS JOIN user_data",
                "Should reject CROSS JOIN with user table",
            ),
            (
                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t",
                "Should reject subquery in SELECT with user table",
            ),
            (
                "SELECT * FROM mz_catalogg.fake_table",
                "Should reject typo-squatting schema name",
            ),
            (
                "SELECT * FROM mz_catalog_hack.fake_table",
                "Should reject fake schema with mz_catalog prefix",
            ),
            (
                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u",
                "Should reject LATERAL join with user table",
            ),
        ];
        for (sql, why) in rejected {
            assert_catalog_rejects(sql, why);
        }

        // Complex query over system tables only.
        assert_catalog_accepts(
            "WITH \
             tables AS (SELECT * FROM mz_tables), \
             sources AS (SELECT * FROM mz_sources) \
             SELECT t.name, s.name \
             FROM tables t \
             JOIN sources s ON t.id = s.id \
             WHERE t.id IN (SELECT id FROM mz_columns)",
        );
        // UNION of system tables.
        assert_catalog_accepts(
            "SELECT name FROM mz_tables \
             UNION \
             SELECT name FROM mz_sources",
        );
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_constant_queries() {
        for (sql, why) in [
            ("SELECT 1", "Should reject constant SELECT with no table references"),
            ("SELECT 1 + 2, 'hello'", "Should reject constant expression SELECT"),
            (
                "SELECT now()",
                "Should reject function-only SELECT with no table references",
            ),
        ] {
            assert_catalog_rejects(sql, why);
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_mixed_tables() {
        assert_catalog_rejects(
            "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id",
            "Should reject join mixing system and user tables",
        );
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_show() {
        for sql in [
            "SHOW TABLES FROM mz_internal",
            "SHOW TABLES FROM mz_catalog",
            "SHOW CLUSTERS",
            "SHOW SOURCES",
            "SHOW TABLES",
        ] {
            assert_catalog_accepts(sql);
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_explain() {
        for sql in ["EXPLAIN SELECT * FROM mz_tables", "EXPLAIN SELECT 1"] {
            assert_catalog_accepts(sql);
        }
    }

    /// Tools the agent endpoint advertises regardless of the query-tool flag.
    const AGENT_CORE_TOOLS: [&str; 3] = [
        "get_data_products",
        "get_data_product_details",
        "read_data_product",
    ];

    /// Returns the names of the tools `handle_tools_list` advertises for
    /// `endpoint` with the given query-tool flag.
    async fn listed_tool_names(endpoint: McpEndpointType, query_tool_enabled: bool) -> Vec<String> {
        let result = handle_tools_list(endpoint, query_tool_enabled, 1_000_000)
            .await
            .unwrap();
        let McpResult::ToolsList(list) = result else {
            panic!("Expected ToolsList result");
        };
        list.tools.iter().map(|t| t.name.as_str().to_owned()).collect()
    }

    /// Whether `tool` appears in `names`.
    fn lists(names: &[String], tool: &str) -> bool {
        names.iter().any(|name| name == tool)
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_disabled() {
        let names = listed_tool_names(McpEndpointType::Agent, false).await;
        for tool in AGENT_CORE_TOOLS {
            assert!(lists(&names, tool), "{tool} should always be present");
        }
        assert!(
            !lists(&names, "query"),
            "query tool should be hidden when disabled"
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_enabled() {
        let names = listed_tool_names(McpEndpointType::Agent, true).await;
        for tool in AGENT_CORE_TOOLS {
            assert!(lists(&names, tool), "{tool} should always be present");
        }
        assert!(
            lists(&names, "query"),
            "query tool should be present when enabled"
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_developer_unaffected_by_query_flag() {
        for flag in [true, false] {
            let names = listed_tool_names(McpEndpointType::Developer, flag).await;
            assert!(
                lists(&names, "query_system_catalog"),
                "query_system_catalog should always be present on developer"
            );
            assert!(
                !lists(&names, "query"),
                "query tool should never appear on developer"
            );
        }
    }

    #[mz_ore::test]
    fn test_format_rows_response_within_limit() {
        let rows = vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
        let result = format_rows_response(rows, 1_000_000).unwrap();
        let McpResult::ToolContent(content) = result else {
            panic!("Expected ToolContent");
        };
        assert_eq!(content.content.len(), 1);
        assert!(content.content[0].text.contains("\"a\""));
        assert!(content.content[0].text.contains("\"b\""));
    }

    #[mz_ore::test]
    fn test_format_rows_response_errors_when_over_limit() {
        let rows: Vec<Vec<serde_json::Value>> = (0..100)
            .map(|i| vec![json!(format!("row_{}", i)), json!(i)])
            .collect();
        let err = format_rows_response(rows, 500).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds the 500 byte limit"),
            "Error should mention the size limit, got: {msg}"
        );
        assert!(
            msg.contains("Use LIMIT or WHERE"),
            "Error should suggest narrowing the query, got: {msg}"
        );
    }

    #[mz_ore::test]
    fn test_format_rows_response_empty_rows() {
        let rows: Vec<Vec<serde_json::Value>> = vec![];
        let result = format_rows_response(rows, 1000).unwrap();
        let McpResult::ToolContent(content) = result else {
            panic!("Expected ToolContent");
        };
        assert_eq!(content.content.len(), 1);
        assert_eq!(content.content[0].text, "[]");
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_valid() {
        for (input, expected) in [
            (
                r#""materialize"."public"."my_view""#,
                r#""materialize"."public"."my_view""#,
            ),
            (r#""public"."my_view""#, r#""public"."my_view""#),
            ("my_view", r#""my_view""#),
        ] {
            assert_eq!(safe_data_product_name(input).unwrap(), expected, "input: {input:?}");
        }
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_empty() {
        for input in ["", " "] {
            assert!(safe_data_product_name(input).is_err(), "input: {input:?}");
        }
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_sql_injection() {
        for input in [
            "my_view; DROP TABLE users",
            "my_view UNION SELECT * FROM secrets",
            "my_view, secrets",
            "my_view WHERE 1=1 --",
        ] {
            assert!(safe_data_product_name(input).is_err(), "input: {input:?}");
        }
    }

    #[mz_ore::test]
    fn test_mcp_error_codes() {
        assert_eq!(
            McpRequestError::InvalidJsonRpcVersion.error_code(),
            error_codes::INVALID_REQUEST
        );
        assert_eq!(
            McpRequestError::MethodNotFound("test".to_string()).error_code(),
            error_codes::METHOD_NOT_FOUND
        );
        assert_eq!(
            McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
            error_codes::INTERNAL_ERROR
        );
    }
}