1use std::sync::Arc;
27use std::time::Duration;
28
29use anyhow::anyhow;
30use axum::Extension;
31use axum::Json;
32use axum::response::IntoResponse;
33use http::{HeaderMap, HeaderValue, StatusCode};
34use mz_adapter_types::dyncfgs::{
35 ENABLE_MCP_AGENT, ENABLE_MCP_AGENT_QUERY_TOOL, ENABLE_MCP_DEVELOPER, MCP_MAX_RESPONSE_SIZE,
36};
37use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
38use mz_sql::parse::parse;
39use mz_sql::session::metadata::SessionMetadata;
40use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
41use mz_sql_parser::ast::visit::{self, Visit};
42use mz_sql_parser::ast::{Raw, RawItemName};
43use mz_sql_parser::parser::parse_item_name;
44use serde::{Deserialize, Serialize};
45use serde_json::json;
46use thiserror::Error;
47use tracing::{debug, warn};
48
49use crate::http::AuthedClient;
50use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
51
/// JSON-RPC protocol version every request must declare and every response echoes.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol revision advertised in the `initialize` response.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Upper bound on the wall-clock time spent serving a single MCP request.
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Query run by the `get_data_products` tool.
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";

/// Prefix of the `get_data_product_details` query; callers append an escaped string literal
/// holding the object name.
const DETAILS_QUERY_PREFIX: &str = concat!(
    "SELECT * FROM mz_internal.mz_mcp_data_product_details ",
    "WHERE object_name = ",
);
72
73#[derive(Debug, Error)]
75enum McpRequestError {
76 #[error("Invalid JSON-RPC version: expected 2.0")]
77 InvalidJsonRpcVersion,
78 #[error("Method not found: {0}")]
79 #[allow(dead_code)] MethodNotFound(String),
81 #[error("Tool not found: {0}")]
82 ToolNotFound(String),
83 #[error("Data product not found: {0}")]
84 DataProductNotFound(String),
85 #[error("Query validation failed: {0}")]
86 QueryValidationFailed(String),
87 #[error("Query execution failed: {0}")]
88 QueryExecutionFailed(String),
89 #[error("Internal error: {0}")]
90 Internal(#[from] anyhow::Error),
91}
92
93impl McpRequestError {
94 fn error_code(&self) -> i32 {
95 match self {
96 Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
97 Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
98 Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
99 Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
100 Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
101 Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
102 }
103 }
104
105 fn error_type(&self) -> &'static str {
106 match self {
107 Self::InvalidJsonRpcVersion => "InvalidRequest",
108 Self::MethodNotFound(_) => "MethodNotFound",
109 Self::ToolNotFound(_) => "ToolNotFound",
110 Self::DataProductNotFound(_) => "DataProductNotFound",
111 Self::QueryValidationFailed(_) => "ValidationError",
112 Self::QueryExecutionFailed(_) => "ExecutionError",
113 Self::Internal(_) => "InternalError",
114 }
115 }
116}
117
/// A JSON-RPC 2.0 request envelope received on an MCP endpoint.
///
/// The `method` and `params` members are flattened into [`McpMethod`]. An unrecognized
/// method name maps to `McpMethod::Unknown` via `#[serde(other)]`.
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    // Protocol version string; compared against `JSONRPC_VERSION` before dispatch.
    jsonrpc: String,
    // Request id. `None` marks a notification, which gets no JSON-RPC response body.
    id: Option<serde_json::Value>,
    #[serde(flatten)]
    method: McpMethod,
}
126
/// The MCP methods this server understands.
///
/// Adjacently tagged: the JSON-RPC `method` member selects the variant and `params` carries
/// its payload.
// NOTE(review): under serde's adjacent tagging, unit variants (`ToolsList`, `Unknown`)
// presumably reject a non-empty `params` object (e.g. `{"cursor": ...}` on `tools/list`,
// or params on an unknown notification) — confirm such requests are handled as intended.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// `initialize`: handshake.
    #[serde(rename = "initialize")]
    // Params are deserialized but never read here, hence the `dead_code` allowance.
    Initialize(#[allow(dead_code)] InitializeParams),
    /// `tools/list`: enumerate the tools available on this endpoint.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// `tools/call`: invoke one tool with its arguments.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Any other method name; the dispatcher answers it with `MethodNotFound`.
    #[serde(other)]
    Unknown,
}
142
143impl std::fmt::Display for McpMethod {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 match self {
146 McpMethod::Initialize(_) => write!(f, "initialize"),
147 McpMethod::ToolsList => write!(f, "tools/list"),
148 McpMethod::ToolsCall(_) => write!(f, "tools/call"),
149 McpMethod::Unknown => write!(f, "unknown"),
150 }
151 }
152}
153
/// Parameters of an `initialize` request.
///
/// The fields are deserialized but not read by this module, hence the `dead_code`
/// allowances on each of them.
#[derive(Debug, Deserialize)]
struct InitializeParams {
    // Protocol revision requested by the client (JSON key `protocolVersion`).
    #[serde(rename = "protocolVersion")]
    #[allow(dead_code)]
    protocol_version: String,
    // Client capabilities object; defaults to JSON `null` when absent.
    #[serde(default)]
    #[allow(dead_code)]
    capabilities: serde_json::Value,
    // Optional client identification (JSON key `clientInfo`).
    #[serde(rename = "clientInfo")]
    #[allow(dead_code)]
    client_info: Option<ClientInfo>,
}

/// Client name/version reported in `initialize`; deserialized but not read here.
#[derive(Debug, Deserialize)]
struct ClientInfo {
    #[allow(dead_code)]
    name: String,
    #[allow(dead_code)]
    version: String,
}
177
/// Parameters of a `tools/call` request.
///
/// `name` selects the tool (the snake_case form of the variant name). `arguments` carries the
/// tool-specific input.
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    /// `get_data_products`; takes no input, and `#[serde(default)]` lets `arguments` be omitted.
    GetDataProducts(#[serde(default)] ()),
    /// `get_data_product_details`.
    GetDataProductDetails(GetDataProductDetailsParams),
    /// `read_data_product`.
    ReadDataProduct(ReadDataProductParams),
    /// `query` (agent endpoint, gated by a feature flag).
    Query(QueryParams),
    /// `query_system_catalog` (developer endpoint).
    QuerySystemCatalog(QuerySystemCatalogParams),
}
193
194impl std::fmt::Display for ToolsCallParams {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 match self {
197 ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
198 ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
199 ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
200 ToolsCallParams::Query(_) => write!(f, "query"),
201 ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
202 }
203 }
204}
205
/// Arguments of the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    // Data product name, matched against `object_name` in the details view.
    name: String,
}

/// Arguments of the `read_data_product` tool.
#[derive(Debug, Deserialize)]
struct ReadDataProductParams {
    // Data product name; validated and re-quoted before use in SQL.
    name: String,
    // Requested row count; defaults via `default_read_limit` and is clamped to
    // `MAX_READ_LIMIT` by the tool implementation.
    #[serde(default = "default_read_limit")]
    limit: u32,
    // Optional cluster to run the read on.
    cluster: Option<String>,
}
218
/// Row count used by `read_data_product` when the caller omits `limit`.
fn default_read_limit() -> u32 {
    const DEFAULT_READ_LIMIT: u32 = 500;
    DEFAULT_READ_LIMIT
}

/// Upper bound on rows returned by `read_data_product`; larger requests are clamped to it.
const MAX_READ_LIMIT: u32 = 1_000;
225
/// Arguments of the agent `query` tool.
#[derive(Debug, Deserialize)]
struct QueryParams {
    // Cluster the query runs on (applied with `SET CLUSTER`).
    cluster: String,
    // A single SELECT/SHOW/EXPLAIN statement.
    sql_query: String,
}

/// Arguments of the developer `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    // A single SELECT/SHOW/EXPLAIN statement over system catalog objects.
    sql_query: String,
}
236
/// A JSON-RPC 2.0 response envelope.
///
/// Exactly one of `result` / `error` is populated by this module; the other is omitted
/// from the serialized JSON.
#[derive(Debug, Serialize)]
struct McpResponse {
    jsonrpc: String,
    // Echo of the request id.
    id: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}

/// Successful method results.
///
/// Untagged, so each variant serializes as its inner object with no wrapper key.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    Initialize(InitializeResult),
    ToolsList(ToolsListResult),
    ToolContent(ToolContentResult),
}
255
/// Result of `initialize`: protocol revision, capabilities and server identity.
#[derive(Debug, Serialize)]
struct InitializeResult {
    #[serde(rename = "protocolVersion")]
    protocol_version: String,
    capabilities: Capabilities,
    #[serde(rename = "serverInfo")]
    server_info: ServerInfo,
    // Optional usage instructions for the client; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}

/// Server capabilities advertised in `initialize`.
#[derive(Debug, Serialize)]
struct Capabilities {
    // Presence of this key advertises tool support; its value holds tool sub-options.
    tools: serde_json::Value,
}

/// Server name and version advertised in `initialize`.
#[derive(Debug, Serialize)]
struct ServerInfo {
    name: String,
    version: String,
}
277
/// Result of `tools/list`.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    tools: Vec<ToolDefinition>,
}

/// One tool as advertised to MCP clients.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    // Wire name clients pass back in `tools/call`.
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    title: Option<String>,
    description: String,
    // JSON Schema describing the tool's `arguments` object.
    #[serde(rename = "inputSchema")]
    input_schema: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    annotations: Option<ToolAnnotations>,
}

/// Behavioral hints attached to a tool definition; `None` hints are omitted from the JSON.
#[derive(Debug, Serialize)]
struct ToolAnnotations {
    #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
    read_only_hint: Option<bool>,
    #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
    destructive_hint: Option<bool>,
    #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
    idempotent_hint: Option<bool>,
    #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
    open_world_hint: Option<bool>,
}

/// Annotations shared by every tool in this module: read-only, non-destructive,
/// idempotent, closed-world.
const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
    read_only_hint: Some(true),
    destructive_hint: Some(false),
    idempotent_hint: Some(true),
    open_world_hint: Some(false),
};
316
/// Result of `tools/call`: content blocks plus the tool-level error flag.
#[derive(Debug, Serialize)]
struct ToolContentResult {
    content: Vec<ContentBlock>,
    // Serialized as `isError`; this module always sets it to `false` on success paths.
    #[serde(rename = "isError")]
    is_error: bool,
}

/// A single piece of tool output; serialized with a `type` key (e.g. `"text"`).
#[derive(Debug, Serialize)]
struct ContentBlock {
    #[serde(rename = "type")]
    content_type: String,
    text: String,
}
330
/// Standard JSON-RPC 2.0 error codes used for the `code` field of `McpError`.
mod error_codes {
    /// The request object is not a valid JSON-RPC request.
    pub const INVALID_REQUEST: i32 = -32_600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32_601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32_602;
    /// An internal error occurred while handling the request.
    pub const INTERNAL_ERROR: i32 = -32_603;
}
338
/// The `error` member of a JSON-RPC response.
#[derive(Debug, Serialize)]
struct McpError {
    // JSON-RPC error code (one of `error_codes`).
    code: i32,
    // Human-readable description.
    message: String,
    // Optional structured details; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
346
347impl From<McpRequestError> for McpError {
348 fn from(err: McpRequestError) -> Self {
349 McpError {
350 code: err.error_code(),
351 message: err.to_string(),
352 data: Some(json!({
353 "error_type": err.error_type(),
354 })),
355 }
356 }
357}
358
/// Which MCP endpoint a request arrived on; selects the feature flag and tool set.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// Data-product tools (`get_data_products`, `read_data_product`, optional `query`).
    Agent,
    /// System catalog tool (`query_system_catalog`).
    Developer,
}

impl std::fmt::Display for McpEndpointType {
    /// Writes the lowercase endpoint name used in logs and the server name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Agent => "agent",
            Self::Developer => "developer",
        })
    }
}
373
/// Handler that always responds `405 Method Not Allowed`.
///
/// Presumably mounted on the MCP routes for HTTP methods other than the one that carries
/// JSON-RPC bodies — confirm against the router configuration.
pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
    StatusCode::METHOD_NOT_ALLOWED
}
379
380pub async fn handle_mcp_agent(
382 headers: HeaderMap,
383 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
384 client: AuthedClient,
385 Json(body): Json<McpRequest>,
386) -> axum::response::Response {
387 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
388 return resp;
389 }
390 handle_mcp_request(client, body, McpEndpointType::Agent)
391 .await
392 .into_response()
393}
394
395pub async fn handle_mcp_developer(
397 headers: HeaderMap,
398 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
399 client: AuthedClient,
400 Json(body): Json<McpRequest>,
401) -> axum::response::Response {
402 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
403 return resp;
404 }
405 handle_mcp_request(client, body, McpEndpointType::Developer)
406 .await
407 .into_response()
408}
409
410fn validate_origin(
419 headers: &HeaderMap,
420 allowed: &[HeaderValue],
421) -> Option<axum::response::Response> {
422 let origin = headers.get(http::header::ORIGIN)?;
423 if mz_http_util::origin_is_allowed(origin, allowed) {
424 return None;
425 }
426 warn!(
427 origin = ?origin,
428 "MCP request rejected: origin not in allowlist",
429 );
430 Some(StatusCode::FORBIDDEN.into_response())
431}
432
433async fn handle_mcp_request(
434 mut client: AuthedClient,
435 request: McpRequest,
436 endpoint_type: McpEndpointType,
437) -> impl IntoResponse {
438 let catalog = client.client.catalog_snapshot("mcp").await;
440 let dyncfgs = catalog.system_config().dyncfgs();
441 let enabled = match endpoint_type {
442 McpEndpointType::Agent => ENABLE_MCP_AGENT.get(dyncfgs),
443 McpEndpointType::Developer => ENABLE_MCP_DEVELOPER.get(dyncfgs),
444 };
445 if !enabled {
446 debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
447 return StatusCode::SERVICE_UNAVAILABLE.into_response();
448 }
449
450 let query_tool_enabled = ENABLE_MCP_AGENT_QUERY_TOOL.get(dyncfgs);
451 let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
452
453 let user = client.client.session().user().name.clone();
454 let is_notification = request.id.is_none();
455
456 debug!(
457 method = %request.method,
458 endpoint = %endpoint_type,
459 user = %user,
460 is_notification = is_notification,
461 "MCP request received"
462 );
463
464 if is_notification {
466 debug!(method = %request.method, "Received notification (no response will be sent)");
467 return StatusCode::OK.into_response();
468 }
469
470 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
471
472 let result = tokio::time::timeout(
474 MCP_REQUEST_TIMEOUT,
475 mz_ore::task::spawn(|| "mcp_request", async move {
476 handle_mcp_request_inner(
477 &mut client,
478 request,
479 endpoint_type,
480 query_tool_enabled,
481 max_response_size,
482 )
483 .await
484 }),
485 )
486 .await;
487
488 let response = match result {
489 Ok(inner) => inner,
490 Err(_elapsed) => {
491 warn!(
492 endpoint = %endpoint_type,
493 timeout = ?MCP_REQUEST_TIMEOUT,
494 "MCP request timed out",
495 );
496 McpResponse {
497 jsonrpc: JSONRPC_VERSION.to_string(),
498 id: request_id,
499 result: None,
500 error: Some(
501 McpRequestError::QueryExecutionFailed(format!(
502 "Request timed out after {} seconds.",
503 MCP_REQUEST_TIMEOUT.as_secs(),
504 ))
505 .into(),
506 ),
507 }
508 }
509 };
510
511 (StatusCode::OK, Json(response)).into_response()
512}
513
514async fn handle_mcp_request_inner(
515 client: &mut AuthedClient,
516 request: McpRequest,
517 endpoint_type: McpEndpointType,
518 query_tool_enabled: bool,
519 max_response_size: usize,
520) -> McpResponse {
521 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
523
524 let result = handle_mcp_method(
525 client,
526 &request,
527 endpoint_type,
528 query_tool_enabled,
529 max_response_size,
530 )
531 .await;
532
533 match result {
534 Ok(result_value) => McpResponse {
535 jsonrpc: JSONRPC_VERSION.to_string(),
536 id: request_id,
537 result: Some(result_value),
538 error: None,
539 },
540 Err(e) => {
541 if !matches!(
543 e,
544 McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
545 ) {
546 warn!(error = %e, method = %request.method, "MCP method execution failed");
547 }
548 McpResponse {
549 jsonrpc: JSONRPC_VERSION.to_string(),
550 id: request_id,
551 result: None,
552 error: Some(e.into()),
553 }
554 }
555 }
556}
557
558async fn handle_mcp_method(
559 client: &mut AuthedClient,
560 request: &McpRequest,
561 endpoint_type: McpEndpointType,
562 query_tool_enabled: bool,
563 max_response_size: usize,
564) -> Result<McpResult, McpRequestError> {
565 if request.jsonrpc != JSONRPC_VERSION {
567 return Err(McpRequestError::InvalidJsonRpcVersion);
568 }
569
570 match &request.method {
572 McpMethod::Initialize(_) => {
573 debug!(endpoint = %endpoint_type, "Processing initialize");
574 handle_initialize(endpoint_type).await
575 }
576 McpMethod::ToolsList => {
577 debug!(endpoint = %endpoint_type, "Processing tools/list");
578 handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
579 }
580 McpMethod::ToolsCall(params) => {
581 debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
582 handle_tools_call(
583 client,
584 params,
585 endpoint_type,
586 query_tool_enabled,
587 max_response_size,
588 )
589 .await
590 }
591 McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
592 "unknown method".to_string(),
593 )),
594 }
595}
596
597async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
598 Ok(McpResult::Initialize(InitializeResult {
599 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
600 capabilities: Capabilities { tools: json!({}) },
601 server_info: ServerInfo {
602 name: format!("materialize-mcp-{}", endpoint_type),
603 version: env!("CARGO_PKG_VERSION").to_string(),
604 },
605 instructions: None,
606 }))
607}
608
609async fn handle_tools_list(
610 endpoint_type: McpEndpointType,
611 query_tool_enabled: bool,
612 max_response_size: usize,
613) -> Result<McpResult, McpRequestError> {
614 let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
615
616 let tools = match endpoint_type {
617 McpEndpointType::Agent => {
618 let mut tools = vec![
619 ToolDefinition {
620 name: "get_data_products".to_string(),
621 title: Some("List Data Products".to_string()),
622 description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
623 input_schema: json!({
624 "type": "object",
625 "properties": {},
626 "required": []
627 }),
628 annotations: Some(READ_ONLY_ANNOTATIONS),
629 },
630 ToolDefinition {
631 name: "get_data_product_details".to_string(),
632 title: Some("Get Data Product Details".to_string()),
633 description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
634 input_schema: json!({
635 "type": "object",
636 "properties": {
637 "name": {
638 "type": "string",
639 "description": "Exact name of the data product from get_data_products() list"
640 }
641 },
642 "required": ["name"]
643 }),
644 annotations: Some(READ_ONLY_ANNOTATIONS),
645 },
646 ToolDefinition {
647 name: "read_data_product".to_string(),
648 title: Some("Read Data Product".to_string()),
649 description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500, max 1000). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
650 input_schema: json!({
651 "type": "object",
652 "properties": {
653 "name": {
654 "type": "string",
655 "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
656 },
657 "limit": {
658 "type": "integer",
659 "description": "Maximum number of rows to return (default 500, max 1000)",
660 "default": 500
661 },
662 "cluster": {
663 "type": "string",
664 "description": "Optional cluster override. If omitted, uses the cluster from the data product catalog."
665 }
666 },
667 "required": ["name"]
668 }),
669 annotations: Some(READ_ONLY_ANNOTATIONS),
670 },
671 ];
672 if query_tool_enabled {
673 tools.push(ToolDefinition {
674 name: "query".to_string(),
675 title: Some("Query Data Products".to_string()),
676 description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
677 input_schema: json!({
678 "type": "object",
679 "properties": {
680 "cluster": {
681 "type": "string",
682 "description": "Exact cluster name from the data product details - required for query execution"
683 },
684 "sql_query": {
685 "type": "string",
686 "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
687 }
688 },
689 "required": ["cluster", "sql_query"]
690 }),
691 annotations: Some(READ_ONLY_ANNOTATIONS),
692 });
693 }
694 tools
695 }
696 McpEndpointType::Developer => {
697 vec![ToolDefinition {
698 name: "query_system_catalog".to_string(),
699 title: Some("Query System Catalog".to_string()),
700 description: concat!(
701 "Query Materialize system catalog tables for troubleshooting and observability. ",
702 "Only mz_*, pg_catalog, and information_schema tables are accessible.\n\n",
703 "Key tables by scenario:\n",
704 "- Freshness: mz_internal.mz_wallclock_global_lag_recent_history, mz_internal.mz_materialization_lag, mz_internal.mz_hydration_statuses\n",
705 "- Memory: mz_internal.mz_cluster_replica_utilization, mz_internal.mz_cluster_replica_metrics, mz_internal.mz_dataflow_arrangement_sizes\n",
706 "- Cluster health: mz_internal.mz_cluster_replica_statuses, mz_catalog.mz_cluster_replicas\n",
707 "- Source/Sink health: mz_internal.mz_source_statuses, mz_internal.mz_sink_statuses, mz_internal.mz_source_statistics, mz_internal.mz_sink_statistics\n",
708 "- Object catalog: mz_catalog.mz_objects (all objects), mz_catalog.mz_tables, mz_catalog.mz_materialized_views, mz_catalog.mz_sources, mz_catalog.mz_sinks\n\n",
709 "Use SHOW TABLES FROM mz_internal or mz_catalog to discover more tables.",
710 ).to_owned() + &format!(" {size_hint}"),
711 input_schema: json!({
712 "type": "object",
713 "properties": {
714 "sql_query": {
715 "type": "string",
716 "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
717 }
718 },
719 "required": ["sql_query"]
720 }),
721 annotations: Some(READ_ONLY_ANNOTATIONS),
722 }]
723 }
724 };
725
726 Ok(McpResult::ToolsList(ToolsListResult { tools }))
727}
728
729async fn handle_tools_call(
730 client: &mut AuthedClient,
731 params: &ToolsCallParams,
732 endpoint_type: McpEndpointType,
733 query_tool_enabled: bool,
734 max_response_size: usize,
735) -> Result<McpResult, McpRequestError> {
736 match (endpoint_type, params) {
737 (McpEndpointType::Agent, ToolsCallParams::GetDataProducts(_)) => {
738 get_data_products(client, max_response_size).await
739 }
740 (McpEndpointType::Agent, ToolsCallParams::GetDataProductDetails(p)) => {
741 get_data_product_details(client, &p.name, max_response_size).await
742 }
743 (McpEndpointType::Agent, ToolsCallParams::ReadDataProduct(p)) => {
744 read_data_product(
745 client,
746 &p.name,
747 p.limit,
748 p.cluster.as_deref(),
749 max_response_size,
750 )
751 .await
752 }
753 (McpEndpointType::Agent, ToolsCallParams::Query(_)) if !query_tool_enabled => {
754 Err(McpRequestError::ToolNotFound(
755 "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
756 ))
757 }
758 (McpEndpointType::Agent, ToolsCallParams::Query(p)) => {
759 execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
760 }
761 (McpEndpointType::Developer, ToolsCallParams::QuerySystemCatalog(p)) => {
762 query_system_catalog(client, &p.sql_query, max_response_size).await
763 }
764 (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
766 "{} is not available on {} endpoint",
767 tool, endpoint
768 ))),
769 }
770}
771
772async fn execute_sql(
774 client: &mut AuthedClient,
775 query: &str,
776) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
777 let mut response = SqlResponse::new();
778
779 execute_request(
780 client,
781 SqlRequest::Simple {
782 query: query.to_string(),
783 },
784 &mut response,
785 )
786 .await
787 .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
788
789 for result in response.results {
792 match result {
793 SqlResult::Rows { rows, .. } => return Ok(rows),
794 SqlResult::Err { error, .. } => {
795 return Err(McpRequestError::QueryExecutionFailed(error.message));
796 }
797 SqlResult::Ok { .. } => continue,
798 }
799 }
800
801 Err(McpRequestError::QueryExecutionFailed(
802 "Query did not return any results".to_string(),
803 ))
804}
805
806fn format_rows_response(
813 rows: Vec<Vec<serde_json::Value>>,
814 max_size: usize,
815) -> Result<McpResult, McpRequestError> {
816 let text =
817 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
818
819 if text.len() > max_size {
820 return Err(McpRequestError::QueryExecutionFailed(format!(
821 "Response size ({} bytes) exceeds the {} byte limit. \
822 Use LIMIT or WHERE to narrow your query.",
823 text.len(),
824 max_size,
825 )));
826 }
827
828 Ok(McpResult::ToolContent(ToolContentResult {
829 content: vec![ContentBlock {
830 content_type: "text".to_string(),
831 text,
832 }],
833 is_error: false,
834 }))
835}
836
837async fn get_data_products(
838 client: &mut AuthedClient,
839 max_response_size: usize,
840) -> Result<McpResult, McpRequestError> {
841 debug!("Executing get_data_products");
842 let rows = execute_sql(client, DISCOVERY_QUERY).await?;
843 debug!("get_data_products returned {} rows", rows.len());
844
845 format_rows_response(rows, max_response_size)
846}
847
848async fn get_data_product_details(
849 client: &mut AuthedClient,
850 name: &str,
851 max_response_size: usize,
852) -> Result<McpResult, McpRequestError> {
853 debug!(name = %name, "Executing get_data_product_details");
854
855 let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
856
857 let rows = execute_sql(client, &query).await?;
858
859 if rows.is_empty() {
860 return Err(McpRequestError::DataProductNotFound(name.to_string()));
861 }
862
863 format_rows_response(rows, max_response_size)
864}
865
866fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
873 let name = name.trim();
874 if name.is_empty() {
875 return Err(McpRequestError::QueryValidationFailed(
876 "Data product name cannot be empty".to_string(),
877 ));
878 }
879
880 let parsed = parse_item_name(name).map_err(|_| {
881 McpRequestError::QueryValidationFailed(format!(
882 "Invalid data product name: {}. Expected a valid object name, \
883 e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
884 name
885 ))
886 })?;
887
888 Ok(parsed.to_ast_string_stable())
891}
892
893async fn read_data_product(
901 client: &mut AuthedClient,
902 name: &str,
903 limit: u32,
904 cluster_override: Option<&str>,
905 max_response_size: usize,
906) -> Result<McpResult, McpRequestError> {
907 debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
908
909 let safe_name = safe_data_product_name(name)?;
911
912 let lookup_query = format!(
920 "SELECT 1 FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
921 escaped_string_literal(name)
922 );
923 let lookup_rows = execute_sql(client, &lookup_query).await?;
924 if lookup_rows.is_empty() {
925 return Err(McpRequestError::DataProductNotFound(name.to_string()));
926 }
927
928 let clamped_limit = limit.min(MAX_READ_LIMIT);
929
930 let read_query = match cluster_override {
931 Some(cluster) => format!(
932 "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}\n; COMMIT;",
933 escaped_string_literal(cluster),
934 safe_name,
935 clamped_limit,
936 ),
937 None => format!("SELECT * FROM {} LIMIT {}", safe_name, clamped_limit),
939 };
940
941 let rows = execute_sql(client, &read_query).await?;
942
943 format_rows_response(rows, max_response_size)
944}
945
946fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
948 let sql = sql.trim();
949 if sql.is_empty() {
950 return Err(McpRequestError::QueryValidationFailed(
951 "Empty query".to_string(),
952 ));
953 }
954
955 let stmts = parse(sql).map_err(|e| {
957 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
958 })?;
959
960 if stmts.len() != 1 {
962 return Err(McpRequestError::QueryValidationFailed(format!(
963 "Only one query allowed at a time. Found {} statements.",
964 stmts.len()
965 )));
966 }
967
968 let stmt = &stmts[0];
970 use mz_sql_parser::ast::Statement;
971
972 match &stmt.ast {
973 Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
974 Ok(())
976 }
977 _ => Err(McpRequestError::QueryValidationFailed(
978 "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
979 )),
980 }
981}
982
983async fn execute_query(
984 client: &mut AuthedClient,
985 cluster: &str,
986 sql_query: &str,
987 max_response_size: usize,
988) -> Result<McpResult, McpRequestError> {
989 debug!(cluster = %cluster, "Executing user query");
990
991 validate_readonly_query(sql_query)?;
992
993 let combined_query = format!(
996 "BEGIN READ ONLY; SET CLUSTER = {}; {}\n; COMMIT;",
997 escaped_string_literal(cluster),
998 sql_query
999 );
1000
1001 let rows = execute_sql(client, &combined_query).await?;
1002
1003 format_rows_response(rows, max_response_size)
1004}
1005
1006async fn query_system_catalog(
1007 client: &mut AuthedClient,
1008 sql_query: &str,
1009 max_response_size: usize,
1010) -> Result<McpResult, McpRequestError> {
1011 debug!("Executing query_system_catalog");
1012
1013 validate_readonly_query(sql_query)?;
1015
1016 validate_system_catalog_query(sql_query)?;
1018
1019 let combined_query = format!(
1025 "BEGIN READ ONLY; SET search_path = mz_catalog, mz_internal, pg_catalog, information_schema; {}; COMMIT;",
1026 sql_query
1027 );
1028
1029 let rows = execute_sql(client, &combined_query).await?;
1030
1031 format_rows_response(rows, max_response_size)
1032}
1033
/// AST visitor state used to find every table a statement reads from.
#[derive(Default)]
struct TableReferenceCollector {
    /// Referenced tables as `(schema, table)`, lowercased; `schema` is `None` when the
    /// reference is unqualified.
    tables: Vec<(Option<String>, String)>,
    /// Lowercased names of CTEs currently in scope; references to them are not tables.
    cte_names: std::collections::BTreeSet<String>,
}

impl TableReferenceCollector {
    /// Creates a collector with no recorded tables and no CTEs in scope.
    fn new() -> Self {
        Self::default()
    }
}
1050
1051impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1052 fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1053 self.cte_names
1055 .insert(cte.alias.name.as_str().to_lowercase());
1056 visit::visit_cte(self, cte);
1057 }
1058
1059 fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1060 if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1062 match name {
1063 RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1064 let parts = &n.0;
1065 if !parts.is_empty() {
1066 let table_name = parts.last().unwrap().as_str().to_lowercase();
1067
1068 if self.cte_names.contains(&table_name) {
1070 visit::visit_table_factor(self, table_factor);
1071 return;
1072 }
1073
1074 let schema = if parts.len() >= 2 {
1076 Some(parts[parts.len() - 2].as_str().to_lowercase())
1077 } else {
1078 None
1079 };
1080 self.tables.push((schema, table_name));
1081 }
1082 }
1083 }
1084 }
1085 visit::visit_table_factor(self, table_factor);
1086 }
1087}
1088
1089fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1097 let stmts = parse(sql).map_err(|e| {
1099 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1100 })?;
1101
1102 if stmts.is_empty() {
1103 return Err(McpRequestError::QueryValidationFailed(
1104 "Empty query".to_string(),
1105 ));
1106 }
1107
1108 let mut collector = TableReferenceCollector::new();
1110 for stmt in &stmts {
1111 collector.visit_statement(&stmt.ast);
1112 }
1113
1114 let is_allowed_schema =
1117 |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1118
1119 let is_system_table = |(schema, table_name): &(Option<String>, String)| match schema {
1124 Some(s) => is_allowed_schema(s.as_str()),
1125 None => table_name.starts_with("mz_"),
1126 };
1127
1128 let non_system_tables: Vec<String> = collector
1130 .tables
1131 .iter()
1132 .filter(|t| !is_system_table(t))
1133 .map(|(schema, table)| match schema {
1134 Some(s) => format!("{}.{}", s, table),
1135 None => table.clone(),
1136 })
1137 .collect();
1138
1139 if !non_system_tables.is_empty() {
1140 return Err(McpRequestError::QueryValidationFailed(format!(
1141 "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1142 non_system_tables.join(", ")
1143 )));
1144 }
1145
1146 use mz_sql_parser::ast::Statement;
1149 let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1150
1151 if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1152 return Err(McpRequestError::QueryValidationFailed(
1153 "Query must reference at least one system catalog table".to_string(),
1154 ));
1155 }
1156
1157 Ok(())
1158}
1159
#[cfg(test)]
mod tests {
    use super::*;

    /// Tools the agent endpoint must advertise no matter how the `query` tool
    /// flag is set.
    const CORE_AGENT_TOOLS: [&str; 3] = [
        "get_data_products",
        "get_data_product_details",
        "read_data_product",
    ];

    /// Asserts that `validate_readonly_query` accepts every statement in `cases`.
    #[track_caller]
    fn assert_readonly_accepts(cases: &[&str]) {
        for &sql in cases {
            assert!(
                validate_readonly_query(sql).is_ok(),
                "validate_readonly_query should accept {sql:?}"
            );
        }
    }

    /// Asserts that `validate_readonly_query` rejects every statement in `cases`.
    #[track_caller]
    fn assert_readonly_rejects(cases: &[&str]) {
        for &sql in cases {
            assert!(
                validate_readonly_query(sql).is_err(),
                "validate_readonly_query should reject {sql:?}"
            );
        }
    }

    /// Asserts that `validate_system_catalog_query` accepts every query in `cases`.
    #[track_caller]
    fn assert_catalog_accepts(cases: &[&str]) {
        for &sql in cases {
            assert!(
                validate_system_catalog_query(sql).is_ok(),
                "validate_system_catalog_query should accept {sql:?}"
            );
        }
    }

    /// Asserts that `validate_system_catalog_query` rejects every query in `cases`.
    #[track_caller]
    fn assert_catalog_rejects(cases: &[&str]) {
        for &sql in cases {
            assert!(
                validate_system_catalog_query(sql).is_err(),
                "validate_system_catalog_query should reject {sql:?}"
            );
        }
    }

    /// Runs `handle_tools_list` for `endpoint` and returns the names of all
    /// advertised tools.
    ///
    /// Panics if the handler fails or returns something other than
    /// `McpResult::ToolsList`.
    async fn advertised_tool_names(
        endpoint: McpEndpointType,
        query_tool_enabled: bool,
    ) -> Vec<String> {
        let result = handle_tools_list(endpoint, query_tool_enabled, 1_000_000)
            .await
            .unwrap();
        let McpResult::ToolsList(list) = result else {
            panic!("Expected ToolsList result");
        };
        list.tools
            .iter()
            .map(|tool| tool.name.as_str().to_owned())
            .collect()
    }

    /// Returns `true` when `tool` is one of `names`.
    fn advertises(names: &[String], tool: &str) -> bool {
        names.iter().any(|name| name == tool)
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_select() {
        assert_readonly_accepts(&["SELECT * FROM mz_tables", "SELECT 1 + 2", " SELECT 1 "]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_subqueries() {
        assert_readonly_accepts(&[
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
            "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))",
            "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_show() {
        assert_readonly_accepts(&["SHOW CLUSTERS", "SHOW TABLES"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_explain() {
        assert_readonly_accepts(&["EXPLAIN SELECT 1"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_writes() {
        assert_readonly_rejects(&[
            "INSERT INTO t VALUES (1)",
            "UPDATE t SET a = 1",
            "DELETE FROM t",
            "CREATE TABLE t (a INT)",
            "DROP TABLE t",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_multiple() {
        assert_readonly_rejects(&["SELECT 1; SELECT 2"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_empty() {
        assert_readonly_rejects(&["", " "]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_accepts_mz_tables() {
        assert_catalog_accepts(&[
            "SELECT * FROM mz_tables",
            "SELECT * FROM mz_internal.mz_comments",
            "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_subqueries() {
        assert_catalog_accepts(&[
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))",
            "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
        ]);
        // A user table nested anywhere inside a subquery must still be caught.
        assert_catalog_rejects(&[
            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_user_tables() {
        assert_catalog_rejects(&[
            "SELECT * FROM user_data",
            "SELECT * FROM my_table",
            // An `mz_` string literal must not make a user table pass.
            "SELECT * FROM user_data WHERE 'mz_' IS NOT NULL",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_functions() {
        assert_catalog_accepts(&[
            "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1",
            "SELECT length(name) FROM mz_tables",
            "SELECT count(*) FROM mz_sources WHERE now() > created_at",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_schema_qualified() {
        assert_catalog_accepts(&[
            "SELECT * FROM mz_catalog.mz_tables",
            "SELECT * FROM mz_internal.mz_sessions",
            "SELECT * FROM pg_catalog.pg_type",
            "SELECT * FROM information_schema.tables",
        ]);

        assert_catalog_rejects(&[
            "SELECT * FROM public.user_table",
            "SELECT * FROM myschema.mytable",
        ]);

        assert!(
            validate_system_catalog_query("SELECT * FROM mz_unsafe.mz_some_table").is_err(),
            "mz_unsafe schema should be blocked even though it is a system schema"
        );

        assert_catalog_rejects(&["SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_adversarial_cases() {
        // Each entry pairs an attack query with the reason it must be rejected.
        let must_reject = [
            (
                "WITH user_cte AS (SELECT * FROM user_data) \
                 SELECT * FROM mz_tables, user_cte",
                "Should reject CTE referencing user table",
            ),
            (
                "WITH \
                 cte1 AS (SELECT * FROM mz_tables), \
                 cte2 AS (SELECT * FROM cte1), \
                 cte3 AS (SELECT * FROM user_data) \
                 SELECT * FROM cte2",
                "Should reject CTE chain with user table",
            ),
            (
                "SELECT * FROM mz_tables t1 \
                 JOIN user_data u ON t1.id = u.id \
                 JOIN mz_sources s ON t1.id = s.id",
                "Should reject multi-join with user table",
            ),
            (
                "SELECT * FROM mz_tables t \
                 LEFT JOIN user_data u ON t.id = u.table_id \
                 WHERE u.id IS NULL",
                "Should reject LEFT JOIN with user table",
            ),
            (
                "SELECT * FROM mz_tables WHERE id IN \
                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)",
                "Should reject nested subquery with user table",
            ),
            (
                "SELECT name FROM mz_tables \
                 UNION \
                 SELECT name FROM user_data",
                "Should reject UNION with user table",
            ),
            (
                "SELECT id FROM mz_sources \
                 UNION ALL \
                 SELECT id FROM products",
                "Should reject UNION ALL with user table",
            ),
            (
                "SELECT * FROM mz_tables CROSS JOIN user_data",
                "Should reject CROSS JOIN with user table",
            ),
            (
                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t",
                "Should reject subquery in SELECT with user table",
            ),
            (
                "SELECT * FROM mz_catalogg.fake_table",
                "Should reject typo-squatting schema name",
            ),
            (
                "SELECT * FROM mz_catalog_hack.fake_table",
                "Should reject fake schema with mz_catalog prefix",
            ),
            (
                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u",
                "Should reject LATERAL join with user table",
            ),
        ];
        for (sql, why) in must_reject {
            assert!(validate_system_catalog_query(sql).is_err(), "{why}");
        }

        // Complex shapes are still fine when they only touch system tables.
        let must_accept = [
            (
                "WITH \
                 tables AS (SELECT * FROM mz_tables), \
                 sources AS (SELECT * FROM mz_sources) \
                 SELECT t.name, s.name \
                 FROM tables t \
                 JOIN sources s ON t.id = s.id \
                 WHERE t.id IN (SELECT id FROM mz_columns)",
                "Should allow complex query with only system tables",
            ),
            (
                "SELECT name FROM mz_tables \
                 UNION \
                 SELECT name FROM mz_sources",
                "Should allow UNION of system tables",
            ),
        ];
        for (sql, why) in must_accept {
            assert!(validate_system_catalog_query(sql).is_ok(), "{why}");
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_constant_queries() {
        let constant_queries = [
            (
                "SELECT 1",
                "Should reject constant SELECT with no table references",
            ),
            (
                "SELECT 1 + 2, 'hello'",
                "Should reject constant expression SELECT",
            ),
            (
                "SELECT now()",
                "Should reject function-only SELECT with no table references",
            ),
        ];
        for (sql, why) in constant_queries {
            assert!(validate_system_catalog_query(sql).is_err(), "{why}");
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_mixed_tables() {
        assert_catalog_rejects(&["SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_show() {
        let show_statements = [
            "SHOW TABLES FROM mz_internal",
            "SHOW TABLES FROM mz_catalog",
            "SHOW CLUSTERS",
            "SHOW SOURCES",
            "SHOW TABLES",
        ];
        for sql in show_statements {
            assert!(
                validate_system_catalog_query(sql).is_ok(),
                "{sql} should be allowed"
            );
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_explain() {
        let explains = [
            (
                "EXPLAIN SELECT * FROM mz_tables",
                "EXPLAIN of system table query should be allowed",
            ),
            ("EXPLAIN SELECT 1", "EXPLAIN SELECT 1 should be allowed"),
        ];
        for (sql, why) in explains {
            assert!(validate_system_catalog_query(sql).is_ok(), "{why}");
        }
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_disabled() {
        let names = advertised_tool_names(McpEndpointType::Agent, false).await;
        for tool in CORE_AGENT_TOOLS {
            assert!(advertises(&names, tool), "{tool} should always be present");
        }
        assert!(
            !advertises(&names, "query"),
            "query tool should be hidden when disabled"
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_enabled() {
        let names = advertised_tool_names(McpEndpointType::Agent, true).await;
        for tool in CORE_AGENT_TOOLS {
            assert!(advertises(&names, tool), "{tool} should always be present");
        }
        assert!(
            advertises(&names, "query"),
            "query tool should be present when enabled"
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_developer_unaffected_by_query_flag() {
        // The agent-only query flag must not change what the developer endpoint lists.
        for query_tool_enabled in [true, false] {
            let names =
                advertised_tool_names(McpEndpointType::Developer, query_tool_enabled).await;
            assert!(
                advertises(&names, "query_system_catalog"),
                "query_system_catalog should always be present on developer"
            );
            assert!(
                !advertises(&names, "query"),
                "query tool should never appear on developer"
            );
        }
    }

    #[mz_ore::test]
    fn test_format_rows_response_within_limit() {
        let rows = vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
        let McpResult::ToolContent(content) = format_rows_response(rows, 1_000_000).unwrap()
        else {
            panic!("Expected ToolContent");
        };
        assert_eq!(content.content.len(), 1);
        let text = &content.content[0].text;
        assert!(text.contains("\"a\""));
        assert!(text.contains("\"b\""));
    }

    #[mz_ore::test]
    fn test_format_rows_response_errors_when_over_limit() {
        let rows: Vec<Vec<serde_json::Value>> = (0..100)
            .map(|i| vec![json!(format!("row_{i}")), json!(i)])
            .collect();
        let msg = format_rows_response(rows, 500).unwrap_err().to_string();
        assert!(
            msg.contains("exceeds the 500 byte limit"),
            "Error should mention the size limit, got: {msg}"
        );
        assert!(
            msg.contains("Use LIMIT or WHERE"),
            "Error should suggest narrowing the query, got: {msg}"
        );
    }

    #[mz_ore::test]
    fn test_format_rows_response_empty_rows() {
        let no_rows: Vec<Vec<serde_json::Value>> = Vec::new();
        let McpResult::ToolContent(content) = format_rows_response(no_rows, 1000).unwrap() else {
            panic!("Expected ToolContent");
        };
        assert_eq!(content.content.len(), 1);
        assert_eq!(content.content[0].text, "[]");
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_valid() {
        // (input, expected quoted form)
        let cases = [
            (
                r#""materialize"."public"."my_view""#,
                r#""materialize"."public"."my_view""#,
            ),
            (r#""public"."my_view""#, r#""public"."my_view""#),
            ("my_view", r#""my_view""#),
        ];
        for (input, expected) in cases {
            assert_eq!(
                safe_data_product_name(input).unwrap(),
                expected,
                "unexpected quoting for {input:?}"
            );
        }
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_empty() {
        for input in ["", " "] {
            assert!(
                safe_data_product_name(input).is_err(),
                "{input:?} should be rejected"
            );
        }
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_sql_injection() {
        let injections = [
            "my_view; DROP TABLE users",
            "my_view UNION SELECT * FROM secrets",
            "my_view, secrets",
            "my_view WHERE 1=1 --",
        ];
        for input in injections {
            assert!(
                safe_data_product_name(input).is_err(),
                "{input:?} should be rejected"
            );
        }
    }

    #[mz_ore::test]
    fn test_mcp_error_codes() {
        let cases = [
            (
                McpRequestError::InvalidJsonRpcVersion,
                error_codes::INVALID_REQUEST,
            ),
            (
                McpRequestError::MethodNotFound("test".to_string()),
                error_codes::METHOD_NOT_FOUND,
            ),
            (
                McpRequestError::QueryExecutionFailed("test".to_string()),
                error_codes::INTERNAL_ERROR,
            ),
        ];
        for (err, expected) in cases {
            assert_eq!(err.error_code(), expected, "wrong error code for: {err}");
        }
    }
}