1use std::time::Duration;
27
28use anyhow::anyhow;
29use axum::Json;
30use axum::response::IntoResponse;
31use http::{HeaderMap, StatusCode};
32use mz_adapter_types::dyncfgs::{
33 ENABLE_MCP_AGENTS, ENABLE_MCP_AGENTS_QUERY_TOOL, ENABLE_MCP_OBSERVATORY, MCP_MAX_RESPONSE_SIZE,
34};
35use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
36use mz_sql::parse::parse;
37use mz_sql::session::metadata::SessionMetadata;
38use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
39use mz_sql_parser::ast::visit::{self, Visit};
40use mz_sql_parser::ast::{Raw, RawItemName};
41use mz_sql_parser::parser::parse_item_name;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use thiserror::Error;
45use tracing::{debug, warn};
46
47use crate::http::AuthedClient;
48use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
49
/// JSON-RPC version string. Incoming requests must carry exactly this value
/// in `jsonrpc`, and every response is stamped with it.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol revision reported to clients in the `initialize` result.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// How long a single MCP request is awaited before a timeout error response
/// is returned to the client.
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Query executed by the `get_data_products` tool.
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
/// Prefix of the query executed by the `get_data_product_details` tool; the
/// escaped data product name is appended as a SQL string literal.
const DETAILS_QUERY_PREFIX: &str =
    "SELECT * FROM mz_internal.mz_mcp_data_product_details WHERE object_name = ";
70
/// Errors that can occur while handling an MCP request.
///
/// Each variant is converted into a JSON-RPC error object via
/// `McpError::from`, using `error_code` and `error_type` for the code and the
/// `data.error_type` field respectively.
#[derive(Debug, Error)]
enum McpRequestError {
    /// The request's `jsonrpc` field was not `"2.0"`.
    #[error("Invalid JSON-RPC version: expected 2.0")]
    InvalidJsonRpcVersion,
    /// The requested JSON-RPC method is not one this server implements.
    #[error("Method not found: {0}")]
    #[allow(dead_code)]
    MethodNotFound(String),
    /// The requested tool does not exist or is not available on this endpoint.
    #[error("Tool not found: {0}")]
    ToolNotFound(String),
    /// No data product with the given name was found.
    #[error("Data product not found: {0}")]
    DataProductNotFound(String),
    /// The supplied SQL or object name was rejected before execution.
    #[error("Query validation failed: {0}")]
    QueryValidationFailed(String),
    /// The query was executed but failed, timed out, or produced an oversized
    /// or empty result.
    #[error("Query execution failed: {0}")]
    QueryExecutionFailed(String),
    /// Any other internal failure.
    #[error("Internal error: {0}")]
    Internal(#[from] anyhow::Error),
}
90
91impl McpRequestError {
92 fn error_code(&self) -> i32 {
93 match self {
94 Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
95 Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
96 Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
97 Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
98 Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
99 Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
100 }
101 }
102
103 fn error_type(&self) -> &'static str {
104 match self {
105 Self::InvalidJsonRpcVersion => "InvalidRequest",
106 Self::MethodNotFound(_) => "MethodNotFound",
107 Self::ToolNotFound(_) => "ToolNotFound",
108 Self::DataProductNotFound(_) => "DataProductNotFound",
109 Self::QueryValidationFailed(_) => "ValidationError",
110 Self::QueryExecutionFailed(_) => "ExecutionError",
111 Self::Internal(_) => "InternalError",
112 }
113 }
114}
115
/// A deserialized JSON-RPC request received on an MCP endpoint.
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    /// JSON-RPC version string; validated against `JSONRPC_VERSION`.
    jsonrpc: String,
    /// Request id. A request without an id is treated as a notification and
    /// receives no JSON-RPC response body.
    id: Option<serde_json::Value>,
    /// The method and its parameters, read from the sibling `method` and
    /// `params` fields of the request object.
    #[serde(flatten)]
    method: McpMethod,
}
124
/// The JSON-RPC methods understood by the MCP endpoints.
///
/// Deserialized from the request's `method` (tag) and `params` (content)
/// fields.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// `initialize` handshake; the parameters are parsed but not otherwise used.
    #[serde(rename = "initialize")]
    Initialize(#[allow(dead_code)] InitializeParams),
    /// `tools/list`: enumerate the tools offered by the endpoint.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// `tools/call`: invoke a tool with arguments.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Catch-all for any other method name; answered with `MethodNotFound`.
    #[serde(other)]
    Unknown,
}
140
141impl std::fmt::Display for McpMethod {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 match self {
144 McpMethod::Initialize(_) => write!(f, "initialize"),
145 McpMethod::ToolsList => write!(f, "tools/list"),
146 McpMethod::ToolsCall(_) => write!(f, "tools/call"),
147 McpMethod::Unknown => write!(f, "unknown"),
148 }
149 }
150}
151
/// Parameters of the `initialize` request.
///
/// All fields are deserialized for validation of the request shape but are
/// not read afterwards, hence the `dead_code` allowances.
#[derive(Debug, Deserialize)]
struct InitializeParams {
    /// Protocol version announced by the client (`protocolVersion`).
    #[serde(rename = "protocolVersion")]
    #[allow(dead_code)]
    protocol_version: String,
    /// Client capabilities; defaults when omitted.
    #[serde(default)]
    #[allow(dead_code)]
    capabilities: serde_json::Value,
    /// Optional client identification (`clientInfo`).
    #[serde(rename = "clientInfo")]
    #[allow(dead_code)]
    client_info: Option<ClientInfo>,
}
167
/// Client identification sent with `initialize`; parsed but not read.
#[derive(Debug, Deserialize)]
struct ClientInfo {
    /// Client name.
    #[allow(dead_code)]
    name: String,
    /// Client version.
    #[allow(dead_code)]
    version: String,
}
175
/// Parameters of a `tools/call` request.
///
/// The tool is selected by the snake_case `name` field and its arguments are
/// read from the `arguments` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    /// `get_data_products`: takes no arguments.
    GetDataProducts(#[serde(default)] ()),
    /// `get_data_product_details`: describe one data product.
    GetDataProductDetails(GetDataProductDetailsParams),
    /// `read_data_product`: read rows from one data product.
    ReadDataProduct(ReadDataProductParams),
    /// `query`: run a read-only SQL query on a given cluster.
    Query(QueryParams),
    /// `query_system_catalog`: run a read-only query over system catalog tables.
    QuerySystemCatalog(QuerySystemCatalogParams),
}
191
192impl std::fmt::Display for ToolsCallParams {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 match self {
195 ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
196 ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
197 ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
198 ToolsCallParams::Query(_) => write!(f, "query"),
199 ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
200 }
201 }
202}
203
/// Arguments of the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    /// Name of the data product to describe; matched against `object_name`.
    name: String,
}
208
/// Arguments of the `read_data_product` tool.
#[derive(Debug, Deserialize)]
struct ReadDataProductParams {
    /// Name of the data product to read from.
    name: String,
    /// Maximum number of rows to return; falls back to `default_read_limit()`
    /// when omitted and is clamped to `MAX_READ_LIMIT` before use.
    #[serde(default = "default_read_limit")]
    limit: u32,
    /// Optional cluster to run the read on.
    cluster: Option<String>,
}
216
/// Serde default for `ReadDataProductParams::limit` when the caller omits it.
fn default_read_limit() -> u32 {
    const DEFAULT_READ_LIMIT: u32 = 500;
    DEFAULT_READ_LIMIT
}
220
/// Hard upper bound on the row limit of `read_data_product`; larger requested
/// limits are clamped to this value.
const MAX_READ_LIMIT: u32 = 1000;
223
/// Arguments of the `query` tool.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Cluster the query is executed on (applied via `SET CLUSTER`).
    cluster: String,
    /// The SQL text; must pass `validate_readonly_query`.
    sql_query: String,
}
229
/// Arguments of the `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    /// The SQL text; must be read-only and reference only system catalog tables.
    sql_query: String,
}
234
/// A JSON-RPC response. Handlers populate exactly one of `result` and
/// `error`; the unset one is omitted from the serialized output.
#[derive(Debug, Serialize)]
struct McpResponse {
    /// Always `JSONRPC_VERSION`.
    jsonrpc: String,
    /// Id echoed back from the request.
    id: serde_json::Value,
    /// Successful result payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    /// Error payload, if the request failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}
244
/// Successful result payloads. Serialized untagged, i.e. as the bare inner
/// value without a variant wrapper.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    /// Result of `initialize`.
    Initialize(InitializeResult),
    /// Result of `tools/list`.
    ToolsList(ToolsListResult),
    /// Result of `tools/call`.
    ToolContent(ToolContentResult),
}
253
/// Result payload of the `initialize` method.
#[derive(Debug, Serialize)]
struct InitializeResult {
    /// Protocol version the server speaks (`protocolVersion`).
    #[serde(rename = "protocolVersion")]
    protocol_version: String,
    /// Capabilities advertised by the server.
    capabilities: Capabilities,
    /// Server identification (`serverInfo`).
    #[serde(rename = "serverInfo")]
    server_info: ServerInfo,
    /// Optional free-form instructions; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}
264
/// Server capabilities reported during `initialize`.
#[derive(Debug, Serialize)]
struct Capabilities {
    /// Tool capability object; `handle_initialize` sends an empty object.
    tools: serde_json::Value,
}
269
/// Server identification reported during `initialize`.
#[derive(Debug, Serialize)]
struct ServerInfo {
    /// Server name, `materialize-mcp-<endpoint>`.
    name: String,
    /// Crate version taken from `CARGO_PKG_VERSION`.
    version: String,
}
275
/// Result payload of the `tools/list` method.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    /// Tools available on the endpoint.
    tools: Vec<ToolDefinition>,
}
280
/// Description of a single tool as returned by `tools/list`.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    /// Tool name used in `tools/call`.
    name: String,
    /// Optional human-readable title; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    title: Option<String>,
    /// Free-form description of what the tool does.
    description: String,
    /// JSON Schema describing the tool's arguments (`inputSchema`).
    #[serde(rename = "inputSchema")]
    input_schema: serde_json::Value,
    /// Optional behavioral hints; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    annotations: Option<ToolAnnotations>,
}
292
/// Optional behavioral hints attached to a tool definition. Each hint is
/// serialized under its camelCase name and omitted when `None`.
#[derive(Debug, Serialize)]
struct ToolAnnotations {
    /// `readOnlyHint`.
    #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
    read_only_hint: Option<bool>,
    /// `destructiveHint`.
    #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
    destructive_hint: Option<bool>,
    /// `idempotentHint`.
    #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
    idempotent_hint: Option<bool>,
    /// `openWorldHint`.
    #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
    open_world_hint: Option<bool>,
}
306
/// Annotations shared by every tool this server exposes: read-only,
/// non-destructive, idempotent, and not open-world.
const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
    read_only_hint: Some(true),
    destructive_hint: Some(false),
    idempotent_hint: Some(true),
    open_world_hint: Some(false),
};
314
/// Result payload of a `tools/call` invocation.
#[derive(Debug, Serialize)]
struct ToolContentResult {
    /// Content blocks produced by the tool.
    content: Vec<ContentBlock>,
    /// Whether the content describes an error (`isError`).
    #[serde(rename = "isError")]
    is_error: bool,
}
321
/// A single block of tool output.
#[derive(Debug, Serialize)]
struct ContentBlock {
    /// Content kind, serialized as `type`; `format_rows_response` uses `"text"`.
    #[serde(rename = "type")]
    content_type: String,
    /// The textual payload.
    text: String,
}
328
/// JSON-RPC error codes used in MCP error responses.
mod error_codes {
    /// The request is not a valid JSON-RPC request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// An internal error occurred while handling the request.
    pub const INTERNAL_ERROR: i32 = -32603;
}
336
/// A JSON-RPC error object.
#[derive(Debug, Serialize)]
struct McpError {
    /// Numeric error code (see `error_codes`).
    code: i32,
    /// Human-readable error message.
    message: String,
    /// Optional structured details; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
344
345impl From<McpRequestError> for McpError {
346 fn from(err: McpRequestError) -> Self {
347 McpError {
348 code: err.error_code(),
349 message: err.to_string(),
350 data: Some(json!({
351 "error_type": err.error_type(),
352 })),
353 }
354 }
355}
356
/// The MCP endpoint a request arrived on; determines which feature flag
/// gates it and which tools are offered.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// Data product endpoint served by `handle_mcp_agents`.
    Agents,
    /// System catalog endpoint served by `handle_mcp_observatory`.
    Observatory,
}
362
363impl std::fmt::Display for McpEndpointType {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 match self {
366 McpEndpointType::Agents => write!(f, "agents"),
367 McpEndpointType::Observatory => write!(f, "observatory"),
368 }
369 }
370}
371
/// Handler that unconditionally answers `405 Method Not Allowed`.
// NOTE(review): presumably registered for HTTP verbs the MCP routes do not
// support — confirm against the router setup.
pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
    StatusCode::METHOD_NOT_ALLOWED
}
377
378pub async fn handle_mcp_agents(
380 headers: HeaderMap,
381 client: AuthedClient,
382 Json(body): Json<McpRequest>,
383) -> axum::response::Response {
384 if let Some(resp) = validate_origin(&headers) {
385 return resp;
386 }
387 handle_mcp_request(client, body, McpEndpointType::Agents)
388 .await
389 .into_response()
390}
391
392pub async fn handle_mcp_observatory(
394 headers: HeaderMap,
395 client: AuthedClient,
396 Json(body): Json<McpRequest>,
397) -> axum::response::Response {
398 if let Some(resp) = validate_origin(&headers) {
399 return resp;
400 }
401 handle_mcp_request(client, body, McpEndpointType::Observatory)
402 .await
403 .into_response()
404}
405
406fn validate_origin(headers: &HeaderMap) -> Option<axum::response::Response> {
410 let origin = match headers.get(http::header::ORIGIN) {
411 Some(o) => o,
412 None => return None, };
414
415 let host = headers
416 .get(http::header::HOST)
417 .and_then(|h| h.to_str().ok());
418
419 let origin_str = origin.to_str().ok().unwrap_or("");
420
421 let origin_host = origin_str
423 .strip_prefix("https://")
424 .or_else(|| origin_str.strip_prefix("http://"))
425 .unwrap_or(origin_str);
426
427 match host {
428 Some(h) if h == origin_host => None, _ => {
430 warn!(
431 origin = origin_str,
432 host = ?host,
433 "MCP request rejected: Origin does not match Host",
434 );
435 Some(StatusCode::FORBIDDEN.into_response())
436 }
437 }
438}
439
440async fn handle_mcp_request(
441 mut client: AuthedClient,
442 request: McpRequest,
443 endpoint_type: McpEndpointType,
444) -> impl IntoResponse {
445 let catalog = client.client.catalog_snapshot("mcp").await;
447 let dyncfgs = catalog.system_config().dyncfgs();
448 let enabled = match endpoint_type {
449 McpEndpointType::Agents => ENABLE_MCP_AGENTS.get(dyncfgs),
450 McpEndpointType::Observatory => ENABLE_MCP_OBSERVATORY.get(dyncfgs),
451 };
452 if !enabled {
453 debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
454 return StatusCode::SERVICE_UNAVAILABLE.into_response();
455 }
456
457 let query_tool_enabled = ENABLE_MCP_AGENTS_QUERY_TOOL.get(dyncfgs);
458 let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
459
460 let user = client.client.session().user().name.clone();
461 let is_notification = request.id.is_none();
462
463 debug!(
464 method = %request.method,
465 endpoint = %endpoint_type,
466 user = %user,
467 is_notification = is_notification,
468 "MCP request received"
469 );
470
471 if is_notification {
473 debug!(method = %request.method, "Received notification (no response will be sent)");
474 return StatusCode::OK.into_response();
475 }
476
477 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
478
479 let result = tokio::time::timeout(
481 MCP_REQUEST_TIMEOUT,
482 mz_ore::task::spawn(|| "mcp_request", async move {
483 handle_mcp_request_inner(
484 &mut client,
485 request,
486 endpoint_type,
487 query_tool_enabled,
488 max_response_size,
489 )
490 .await
491 }),
492 )
493 .await;
494
495 let response = match result {
496 Ok(inner) => inner,
497 Err(_elapsed) => {
498 warn!(
499 endpoint = %endpoint_type,
500 timeout = ?MCP_REQUEST_TIMEOUT,
501 "MCP request timed out",
502 );
503 McpResponse {
504 jsonrpc: JSONRPC_VERSION.to_string(),
505 id: request_id,
506 result: None,
507 error: Some(
508 McpRequestError::QueryExecutionFailed(format!(
509 "Request timed out after {} seconds.",
510 MCP_REQUEST_TIMEOUT.as_secs(),
511 ))
512 .into(),
513 ),
514 }
515 }
516 };
517
518 (StatusCode::OK, Json(response)).into_response()
519}
520
521async fn handle_mcp_request_inner(
522 client: &mut AuthedClient,
523 request: McpRequest,
524 endpoint_type: McpEndpointType,
525 query_tool_enabled: bool,
526 max_response_size: usize,
527) -> McpResponse {
528 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
530
531 let result = handle_mcp_method(
532 client,
533 &request,
534 endpoint_type,
535 query_tool_enabled,
536 max_response_size,
537 )
538 .await;
539
540 match result {
541 Ok(result_value) => McpResponse {
542 jsonrpc: JSONRPC_VERSION.to_string(),
543 id: request_id,
544 result: Some(result_value),
545 error: None,
546 },
547 Err(e) => {
548 if !matches!(
550 e,
551 McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
552 ) {
553 warn!(error = %e, method = %request.method, "MCP method execution failed");
554 }
555 McpResponse {
556 jsonrpc: JSONRPC_VERSION.to_string(),
557 id: request_id,
558 result: None,
559 error: Some(e.into()),
560 }
561 }
562 }
563}
564
565async fn handle_mcp_method(
566 client: &mut AuthedClient,
567 request: &McpRequest,
568 endpoint_type: McpEndpointType,
569 query_tool_enabled: bool,
570 max_response_size: usize,
571) -> Result<McpResult, McpRequestError> {
572 if request.jsonrpc != JSONRPC_VERSION {
574 return Err(McpRequestError::InvalidJsonRpcVersion);
575 }
576
577 match &request.method {
579 McpMethod::Initialize(_) => {
580 debug!(endpoint = %endpoint_type, "Processing initialize");
581 handle_initialize(endpoint_type).await
582 }
583 McpMethod::ToolsList => {
584 debug!(endpoint = %endpoint_type, "Processing tools/list");
585 handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
586 }
587 McpMethod::ToolsCall(params) => {
588 debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
589 handle_tools_call(
590 client,
591 params,
592 endpoint_type,
593 query_tool_enabled,
594 max_response_size,
595 )
596 .await
597 }
598 McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
599 "unknown method".to_string(),
600 )),
601 }
602}
603
604async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
605 Ok(McpResult::Initialize(InitializeResult {
606 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
607 capabilities: Capabilities { tools: json!({}) },
608 server_info: ServerInfo {
609 name: format!("materialize-mcp-{}", endpoint_type),
610 version: env!("CARGO_PKG_VERSION").to_string(),
611 },
612 instructions: None,
613 }))
614}
615
616async fn handle_tools_list(
617 endpoint_type: McpEndpointType,
618 query_tool_enabled: bool,
619 max_response_size: usize,
620) -> Result<McpResult, McpRequestError> {
621 let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
622
623 let tools = match endpoint_type {
624 McpEndpointType::Agents => {
625 let mut tools = vec![
626 ToolDefinition {
627 name: "get_data_products".to_string(),
628 title: Some("List Data Products".to_string()),
629 description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
630 input_schema: json!({
631 "type": "object",
632 "properties": {},
633 "required": []
634 }),
635 annotations: Some(READ_ONLY_ANNOTATIONS),
636 },
637 ToolDefinition {
638 name: "get_data_product_details".to_string(),
639 title: Some("Get Data Product Details".to_string()),
640 description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
641 input_schema: json!({
642 "type": "object",
643 "properties": {
644 "name": {
645 "type": "string",
646 "description": "Exact name of the data product from get_data_products() list"
647 }
648 },
649 "required": ["name"]
650 }),
651 annotations: Some(READ_ONLY_ANNOTATIONS),
652 },
653 ToolDefinition {
654 name: "read_data_product".to_string(),
655 title: Some("Read Data Product".to_string()),
656 description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500, max 1000). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
657 input_schema: json!({
658 "type": "object",
659 "properties": {
660 "name": {
661 "type": "string",
662 "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
663 },
664 "limit": {
665 "type": "integer",
666 "description": "Maximum number of rows to return (default 500, max 1000)",
667 "default": 500
668 },
669 "cluster": {
670 "type": "string",
671 "description": "Optional cluster override. If omitted, uses the cluster from the data product catalog."
672 }
673 },
674 "required": ["name"]
675 }),
676 annotations: Some(READ_ONLY_ANNOTATIONS),
677 },
678 ];
679 if query_tool_enabled {
680 tools.push(ToolDefinition {
681 name: "query".to_string(),
682 title: Some("Query Data Products".to_string()),
683 description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
684 input_schema: json!({
685 "type": "object",
686 "properties": {
687 "cluster": {
688 "type": "string",
689 "description": "Exact cluster name from the data product details - required for query execution"
690 },
691 "sql_query": {
692 "type": "string",
693 "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
694 }
695 },
696 "required": ["cluster", "sql_query"]
697 }),
698 annotations: Some(READ_ONLY_ANNOTATIONS),
699 });
700 }
701 tools
702 }
703 McpEndpointType::Observatory => {
704 vec![ToolDefinition {
705 name: "query_system_catalog".to_string(),
706 title: Some("Query System Catalog".to_string()),
707 description: concat!(
708 "Query Materialize system catalog tables for troubleshooting and observability. ",
709 "Only mz_*, pg_catalog, and information_schema tables are accessible.\n\n",
710 "Key tables by scenario:\n",
711 "- Freshness: mz_internal.mz_wallclock_global_lag_recent_history, mz_internal.mz_materialization_lag, mz_internal.mz_hydration_statuses\n",
712 "- Memory: mz_internal.mz_cluster_replica_utilization, mz_internal.mz_cluster_replica_metrics, mz_internal.mz_dataflow_arrangement_sizes\n",
713 "- Cluster health: mz_internal.mz_cluster_replica_statuses, mz_catalog.mz_cluster_replicas\n",
714 "- Source/Sink health: mz_internal.mz_source_statuses, mz_internal.mz_sink_statuses, mz_internal.mz_source_statistics, mz_internal.mz_sink_statistics\n",
715 "- Object catalog: mz_catalog.mz_objects (all objects), mz_catalog.mz_tables, mz_catalog.mz_materialized_views, mz_catalog.mz_sources, mz_catalog.mz_sinks\n\n",
716 "Use SHOW TABLES FROM mz_internal or mz_catalog to discover more tables.",
717 ).to_owned() + &format!(" {size_hint}"),
718 input_schema: json!({
719 "type": "object",
720 "properties": {
721 "sql_query": {
722 "type": "string",
723 "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
724 }
725 },
726 "required": ["sql_query"]
727 }),
728 annotations: Some(READ_ONLY_ANNOTATIONS),
729 }]
730 }
731 };
732
733 Ok(McpResult::ToolsList(ToolsListResult { tools }))
734}
735
736async fn handle_tools_call(
737 client: &mut AuthedClient,
738 params: &ToolsCallParams,
739 endpoint_type: McpEndpointType,
740 query_tool_enabled: bool,
741 max_response_size: usize,
742) -> Result<McpResult, McpRequestError> {
743 match (endpoint_type, params) {
744 (McpEndpointType::Agents, ToolsCallParams::GetDataProducts(_)) => {
745 get_data_products(client, max_response_size).await
746 }
747 (McpEndpointType::Agents, ToolsCallParams::GetDataProductDetails(p)) => {
748 get_data_product_details(client, &p.name, max_response_size).await
749 }
750 (McpEndpointType::Agents, ToolsCallParams::ReadDataProduct(p)) => {
751 read_data_product(
752 client,
753 &p.name,
754 p.limit,
755 p.cluster.as_deref(),
756 max_response_size,
757 )
758 .await
759 }
760 (McpEndpointType::Agents, ToolsCallParams::Query(_)) if !query_tool_enabled => {
761 Err(McpRequestError::ToolNotFound(
762 "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
763 ))
764 }
765 (McpEndpointType::Agents, ToolsCallParams::Query(p)) => {
766 execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
767 }
768 (McpEndpointType::Observatory, ToolsCallParams::QuerySystemCatalog(p)) => {
769 query_system_catalog(client, &p.sql_query, max_response_size).await
770 }
771 (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
773 "{} is not available on {} endpoint",
774 tool, endpoint
775 ))),
776 }
777}
778
779async fn execute_sql(
781 client: &mut AuthedClient,
782 query: &str,
783) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
784 let mut response = SqlResponse::new();
785
786 execute_request(
787 client,
788 SqlRequest::Simple {
789 query: query.to_string(),
790 },
791 &mut response,
792 )
793 .await
794 .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
795
796 for result in response.results {
799 match result {
800 SqlResult::Rows { rows, .. } => return Ok(rows),
801 SqlResult::Err { error, .. } => {
802 return Err(McpRequestError::QueryExecutionFailed(error.message));
803 }
804 SqlResult::Ok { .. } => continue,
805 }
806 }
807
808 Err(McpRequestError::QueryExecutionFailed(
809 "Query did not return any results".to_string(),
810 ))
811}
812
813fn format_rows_response(
820 rows: Vec<Vec<serde_json::Value>>,
821 max_size: usize,
822) -> Result<McpResult, McpRequestError> {
823 let text =
824 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
825
826 if text.len() > max_size {
827 return Err(McpRequestError::QueryExecutionFailed(format!(
828 "Response size ({} bytes) exceeds the {} byte limit. \
829 Use LIMIT or WHERE to narrow your query.",
830 text.len(),
831 max_size,
832 )));
833 }
834
835 Ok(McpResult::ToolContent(ToolContentResult {
836 content: vec![ContentBlock {
837 content_type: "text".to_string(),
838 text,
839 }],
840 is_error: false,
841 }))
842}
843
844async fn get_data_products(
845 client: &mut AuthedClient,
846 max_response_size: usize,
847) -> Result<McpResult, McpRequestError> {
848 debug!("Executing get_data_products");
849 let rows = execute_sql(client, DISCOVERY_QUERY).await?;
850 debug!("get_data_products returned {} rows", rows.len());
851
852 format_rows_response(rows, max_response_size)
853}
854
855async fn get_data_product_details(
856 client: &mut AuthedClient,
857 name: &str,
858 max_response_size: usize,
859) -> Result<McpResult, McpRequestError> {
860 debug!(name = %name, "Executing get_data_product_details");
861
862 let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
863
864 let rows = execute_sql(client, &query).await?;
865
866 if rows.is_empty() {
867 return Err(McpRequestError::DataProductNotFound(name.to_string()));
868 }
869
870 format_rows_response(rows, max_response_size)
871}
872
873fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
880 let name = name.trim();
881 if name.is_empty() {
882 return Err(McpRequestError::QueryValidationFailed(
883 "Data product name cannot be empty".to_string(),
884 ));
885 }
886
887 let parsed = parse_item_name(name).map_err(|_| {
888 McpRequestError::QueryValidationFailed(format!(
889 "Invalid data product name: {}. Expected a valid object name, \
890 e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
891 name
892 ))
893 })?;
894
895 Ok(parsed.to_ast_string_stable())
898}
899
900async fn read_data_product(
908 client: &mut AuthedClient,
909 name: &str,
910 limit: u32,
911 cluster_override: Option<&str>,
912 max_response_size: usize,
913) -> Result<McpResult, McpRequestError> {
914 debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
915
916 let safe_name = safe_data_product_name(name)?;
918
919 let lookup_query = format!(
927 "SELECT 1 FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
928 escaped_string_literal(name)
929 );
930 let lookup_rows = execute_sql(client, &lookup_query).await?;
931 if lookup_rows.is_empty() {
932 return Err(McpRequestError::DataProductNotFound(name.to_string()));
933 }
934
935 let clamped_limit = limit.min(MAX_READ_LIMIT);
936
937 let read_query = match cluster_override {
938 Some(cluster) => format!(
939 "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}; COMMIT;",
940 escaped_string_literal(cluster),
941 safe_name,
942 clamped_limit,
943 ),
944 None => format!("SELECT * FROM {} LIMIT {}", safe_name, clamped_limit),
946 };
947
948 let rows = execute_sql(client, &read_query).await?;
949
950 format_rows_response(rows, max_response_size)
951}
952
953fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
955 let sql = sql.trim();
956 if sql.is_empty() {
957 return Err(McpRequestError::QueryValidationFailed(
958 "Empty query".to_string(),
959 ));
960 }
961
962 let stmts = parse(sql).map_err(|e| {
964 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
965 })?;
966
967 if stmts.len() != 1 {
969 return Err(McpRequestError::QueryValidationFailed(format!(
970 "Only one query allowed at a time. Found {} statements.",
971 stmts.len()
972 )));
973 }
974
975 let stmt = &stmts[0];
977 use mz_sql_parser::ast::Statement;
978
979 match &stmt.ast {
980 Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
981 Ok(())
983 }
984 _ => Err(McpRequestError::QueryValidationFailed(
985 "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
986 )),
987 }
988}
989
990async fn execute_query(
991 client: &mut AuthedClient,
992 cluster: &str,
993 sql_query: &str,
994 max_response_size: usize,
995) -> Result<McpResult, McpRequestError> {
996 debug!(cluster = %cluster, "Executing user query");
997
998 validate_readonly_query(sql_query)?;
999
1000 let combined_query = format!(
1003 "BEGIN READ ONLY; SET CLUSTER = {}; {}; COMMIT;",
1004 escaped_string_literal(cluster),
1005 sql_query
1006 );
1007
1008 let rows = execute_sql(client, &combined_query).await?;
1009
1010 format_rows_response(rows, max_response_size)
1011}
1012
1013async fn query_system_catalog(
1014 client: &mut AuthedClient,
1015 sql_query: &str,
1016 max_response_size: usize,
1017) -> Result<McpResult, McpRequestError> {
1018 debug!("Executing query_system_catalog");
1019
1020 validate_readonly_query(sql_query)?;
1022
1023 validate_system_catalog_query(sql_query)?;
1025
1026 let rows = execute_sql(client, sql_query).await?;
1029
1030 format_rows_response(rows, max_response_size)
1031}
1032
/// AST visitor state that records which tables a statement references.
struct TableReferenceCollector {
    /// Referenced tables as `(schema, table)` pairs, lowercased; `schema` is
    /// `None` for names with a single component.
    tables: Vec<(Option<String>, String)>,
    /// Lowercased aliases of every CTE visited so far; references to these
    /// names are not recorded in `tables`.
    cte_names: std::collections::BTreeSet<String>,
}
1040
1041impl TableReferenceCollector {
1042 fn new() -> Self {
1043 Self {
1044 tables: Vec::new(),
1045 cte_names: std::collections::BTreeSet::new(),
1046 }
1047 }
1048}
1049
1050impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1051 fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1052 self.cte_names
1054 .insert(cte.alias.name.as_str().to_lowercase());
1055 visit::visit_cte(self, cte);
1056 }
1057
1058 fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1059 if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1061 match name {
1062 RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1063 let parts = &n.0;
1064 if !parts.is_empty() {
1065 let table_name = parts.last().unwrap().as_str().to_lowercase();
1066
1067 if self.cte_names.contains(&table_name) {
1069 visit::visit_table_factor(self, table_factor);
1070 return;
1071 }
1072
1073 let schema = if parts.len() >= 2 {
1075 Some(parts[parts.len() - 2].as_str().to_lowercase())
1076 } else {
1077 None
1078 };
1079 self.tables.push((schema, table_name));
1080 }
1081 }
1082 }
1083 }
1084 visit::visit_table_factor(self, table_factor);
1085 }
1086}
1087
1088fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1096 let stmts = parse(sql).map_err(|e| {
1098 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1099 })?;
1100
1101 if stmts.is_empty() {
1102 return Err(McpRequestError::QueryValidationFailed(
1103 "Empty query".to_string(),
1104 ));
1105 }
1106
1107 let mut collector = TableReferenceCollector::new();
1109 for stmt in &stmts {
1110 collector.visit_statement(&stmt.ast);
1111 }
1112
1113 let is_allowed_schema =
1116 |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1117
1118 let is_system_table = |(schema, table_name): &(Option<String>, String)| {
1120 match schema {
1121 Some(s) => is_allowed_schema(s.as_str()),
1123 None => table_name.starts_with("mz_"),
1127 }
1128 };
1129
1130 let non_system_tables: Vec<String> = collector
1132 .tables
1133 .iter()
1134 .filter(|t| !is_system_table(t))
1135 .map(|(schema, table)| match schema {
1136 Some(s) => format!("{}.{}", s, table),
1137 None => table.clone(),
1138 })
1139 .collect();
1140
1141 if !non_system_tables.is_empty() {
1142 return Err(McpRequestError::QueryValidationFailed(format!(
1143 "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1144 non_system_tables.join(", ")
1145 )));
1146 }
1147
1148 use mz_sql_parser::ast::Statement;
1151 let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1152
1153 if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1154 return Err(McpRequestError::QueryValidationFailed(
1155 "Query must reference at least one system catalog table".to_string(),
1156 ));
1157 }
1158
1159 Ok(())
1160}
1161
1162#[cfg(test)]
1163mod tests {
1164 use super::*;
1165
1166 #[mz_ore::test]
1167 fn test_validate_readonly_query_select() {
1168 assert!(validate_readonly_query("SELECT * FROM mz_tables").is_ok());
1169 assert!(validate_readonly_query("SELECT 1 + 2").is_ok());
1170 assert!(validate_readonly_query(" SELECT 1 ").is_ok());
1171 }
1172
1173 #[mz_ore::test]
1174 fn test_validate_readonly_query_subqueries() {
1175 assert!(
1177 validate_readonly_query(
1178 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1179 )
1180 .is_ok()
1181 );
1182
1183 assert!(
1185 validate_readonly_query(
1186 "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1187 )
1188 .is_ok()
1189 );
1190
1191 assert!(validate_readonly_query(
1193 "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)"
1194 )
1195 .is_ok());
1196
1197 assert!(validate_readonly_query(
1199 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))"
1200 )
1201 .is_ok());
1202
1203 assert!(
1205 validate_readonly_query(
1206 "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)"
1207 )
1208 .is_ok()
1209 );
1210 }
1211
1212 #[mz_ore::test]
1213 fn test_validate_readonly_query_show() {
1214 assert!(validate_readonly_query("SHOW CLUSTERS").is_ok());
1215 assert!(validate_readonly_query("SHOW TABLES").is_ok());
1216 }
1217
1218 #[mz_ore::test]
1219 fn test_validate_readonly_query_explain() {
1220 assert!(validate_readonly_query("EXPLAIN SELECT 1").is_ok());
1221 }
1222
1223 #[mz_ore::test]
1224 fn test_validate_readonly_query_rejects_writes() {
1225 assert!(validate_readonly_query("INSERT INTO t VALUES (1)").is_err());
1226 assert!(validate_readonly_query("UPDATE t SET a = 1").is_err());
1227 assert!(validate_readonly_query("DELETE FROM t").is_err());
1228 assert!(validate_readonly_query("CREATE TABLE t (a INT)").is_err());
1229 assert!(validate_readonly_query("DROP TABLE t").is_err());
1230 }
1231
1232 #[mz_ore::test]
1233 fn test_validate_readonly_query_rejects_multiple() {
1234 assert!(validate_readonly_query("SELECT 1; SELECT 2").is_err());
1235 }
1236
1237 #[mz_ore::test]
1238 fn test_validate_readonly_query_rejects_empty() {
1239 assert!(validate_readonly_query("").is_err());
1240 assert!(validate_readonly_query(" ").is_err());
1241 }
1242
1243 #[mz_ore::test]
1244 fn test_validate_system_catalog_query_accepts_mz_tables() {
1245 assert!(validate_system_catalog_query("SELECT * FROM mz_tables").is_ok());
1246 assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_comments").is_ok());
1247 assert!(
1248 validate_system_catalog_query(
1249 "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id"
1250 )
1251 .is_ok()
1252 );
1253 }
1254
1255 #[mz_ore::test]
1256 fn test_validate_system_catalog_query_subqueries() {
1257 assert!(
1259 validate_system_catalog_query(
1260 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1261 )
1262 .is_ok()
1263 );
1264
1265 assert!(validate_system_catalog_query(
1267 "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))"
1268 )
1269 .is_ok());
1270
1271 assert!(
1273 validate_system_catalog_query(
1274 "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1275 )
1276 .is_ok()
1277 );
1278
1279 assert!(
1281 validate_system_catalog_query(
1282 "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)"
1283 )
1284 .is_err()
1285 );
1286
1287 assert!(validate_system_catalog_query(
1289 "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)"
1290 )
1291 .is_err());
1292 }
1293
1294 #[mz_ore::test]
1295 fn test_validate_system_catalog_query_rejects_user_tables() {
1296 assert!(validate_system_catalog_query("SELECT * FROM user_data").is_err());
1297 assert!(validate_system_catalog_query("SELECT * FROM my_table").is_err());
1298 assert!(
1300 validate_system_catalog_query("SELECT * FROM user_data WHERE 'mz_' IS NOT NULL")
1301 .is_err()
1302 );
1303 }
1304
1305 #[mz_ore::test]
1306 fn test_validate_system_catalog_query_allows_functions() {
1307 assert!(
1309 validate_system_catalog_query(
1310 "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1"
1311 )
1312 .is_ok()
1313 );
1314 assert!(validate_system_catalog_query("SELECT length(name) FROM mz_tables").is_ok());
1315 assert!(
1316 validate_system_catalog_query(
1317 "SELECT count(*) FROM mz_sources WHERE now() > created_at"
1318 )
1319 .is_ok()
1320 );
1321 }
1322
1323 #[mz_ore::test]
1324 fn test_validate_system_catalog_query_schema_qualified() {
1325 assert!(validate_system_catalog_query("SELECT * FROM mz_catalog.mz_tables").is_ok());
1327 assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_sessions").is_ok());
1328 assert!(validate_system_catalog_query("SELECT * FROM pg_catalog.pg_type").is_ok());
1329 assert!(validate_system_catalog_query("SELECT * FROM information_schema.tables").is_ok());
1330
1331 assert!(validate_system_catalog_query("SELECT * FROM public.user_table").is_err());
1333 assert!(validate_system_catalog_query("SELECT * FROM myschema.mytable").is_err());
1334
1335 assert!(
1337 validate_system_catalog_query("SELECT * FROM mz_unsafe.mz_some_table").is_err(),
1338 "mz_unsafe schema should be blocked even though it is a system schema"
1339 );
1340
1341 assert!(
1343 validate_system_catalog_query(
1344 "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"
1345 )
1346 .is_err()
1347 );
1348 }
1349
1350 #[mz_ore::test]
1351 fn test_validate_system_catalog_query_adversarial_cases() {
1352 assert!(
1354 validate_system_catalog_query(
1355 "WITH user_cte AS (SELECT * FROM user_data) \
1356 SELECT * FROM mz_tables, user_cte"
1357 )
1358 .is_err(),
1359 "Should reject CTE referencing user table"
1360 );
1361
1362 assert!(
1364 validate_system_catalog_query(
1365 "WITH \
1366 cte1 AS (SELECT * FROM mz_tables), \
1367 cte2 AS (SELECT * FROM cte1), \
1368 cte3 AS (SELECT * FROM user_data) \
1369 SELECT * FROM cte2"
1370 )
1371 .is_err(),
1372 "Should reject CTE chain with user table"
1373 );
1374
1375 assert!(
1377 validate_system_catalog_query(
1378 "SELECT * FROM mz_tables t1 \
1379 JOIN user_data u ON t1.id = u.id \
1380 JOIN mz_sources s ON t1.id = s.id"
1381 )
1382 .is_err(),
1383 "Should reject multi-join with user table"
1384 );
1385
1386 assert!(
1388 validate_system_catalog_query(
1389 "SELECT * FROM mz_tables t \
1390 LEFT JOIN user_data u ON t.id = u.table_id \
1391 WHERE u.id IS NULL"
1392 )
1393 .is_err(),
1394 "Should reject LEFT JOIN with user table"
1395 );
1396
1397 assert!(
1399 validate_system_catalog_query(
1400 "SELECT * FROM mz_tables WHERE id IN \
1401 (SELECT table_id FROM (SELECT * FROM user_data) AS u)"
1402 )
1403 .is_err(),
1404 "Should reject nested subquery with user table"
1405 );
1406
1407 assert!(
1409 validate_system_catalog_query(
1410 "SELECT name FROM mz_tables \
1411 UNION \
1412 SELECT name FROM user_data"
1413 )
1414 .is_err(),
1415 "Should reject UNION with user table"
1416 );
1417
1418 assert!(
1420 validate_system_catalog_query(
1421 "SELECT id FROM mz_sources \
1422 UNION ALL \
1423 SELECT id FROM products"
1424 )
1425 .is_err(),
1426 "Should reject UNION ALL with user table"
1427 );
1428
1429 assert!(
1431 validate_system_catalog_query("SELECT * FROM mz_tables CROSS JOIN user_data").is_err(),
1432 "Should reject CROSS JOIN with user table"
1433 );
1434
1435 assert!(
1437 validate_system_catalog_query(
1438 "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t"
1439 )
1440 .is_err(),
1441 "Should reject subquery in SELECT with user table"
1442 );
1443
1444 assert!(
1446 validate_system_catalog_query("SELECT * FROM mz_catalogg.fake_table").is_err(),
1447 "Should reject typo-squatting schema name"
1448 );
1449 assert!(
1450 validate_system_catalog_query("SELECT * FROM mz_catalog_hack.fake_table").is_err(),
1451 "Should reject fake schema with mz_catalog prefix"
1452 );
1453
1454 assert!(
1456 validate_system_catalog_query(
1457 "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u"
1458 )
1459 .is_err(),
1460 "Should reject LATERAL join with user table"
1461 );
1462
1463 assert!(
1465 validate_system_catalog_query(
1466 "WITH \
1467 tables AS (SELECT * FROM mz_tables), \
1468 sources AS (SELECT * FROM mz_sources) \
1469 SELECT t.name, s.name \
1470 FROM tables t \
1471 JOIN sources s ON t.id = s.id \
1472 WHERE t.id IN (SELECT id FROM mz_columns)"
1473 )
1474 .is_ok(),
1475 "Should allow complex query with only system tables"
1476 );
1477
1478 assert!(
1480 validate_system_catalog_query(
1481 "SELECT name FROM mz_tables \
1482 UNION \
1483 SELECT name FROM mz_sources"
1484 )
1485 .is_ok(),
1486 "Should allow UNION of system tables"
1487 );
1488 }
1489
1490 #[mz_ore::test]
1491 fn test_validate_system_catalog_query_rejects_constant_queries() {
1492 assert!(
1495 validate_system_catalog_query("SELECT 1").is_err(),
1496 "Should reject constant SELECT with no table references"
1497 );
1498 assert!(
1499 validate_system_catalog_query("SELECT 1 + 2, 'hello'").is_err(),
1500 "Should reject constant expression SELECT"
1501 );
1502 assert!(
1503 validate_system_catalog_query("SELECT now()").is_err(),
1504 "Should reject function-only SELECT with no table references"
1505 );
1506 }
1507
1508 #[mz_ore::test]
1509 fn test_validate_system_catalog_query_rejects_mixed_tables() {
1510 assert!(
1511 validate_system_catalog_query(
1512 "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"
1513 )
1514 .is_err()
1515 );
1516 }
1517
1518 #[mz_ore::test]
1519 fn test_validate_system_catalog_query_allows_show() {
1520 assert!(
1522 validate_system_catalog_query("SHOW TABLES FROM mz_internal").is_ok(),
1523 "SHOW TABLES FROM mz_internal should be allowed"
1524 );
1525 assert!(
1526 validate_system_catalog_query("SHOW TABLES FROM mz_catalog").is_ok(),
1527 "SHOW TABLES FROM mz_catalog should be allowed"
1528 );
1529 assert!(
1530 validate_system_catalog_query("SHOW CLUSTERS").is_ok(),
1531 "SHOW CLUSTERS should be allowed"
1532 );
1533 assert!(
1534 validate_system_catalog_query("SHOW SOURCES").is_ok(),
1535 "SHOW SOURCES should be allowed"
1536 );
1537 assert!(
1538 validate_system_catalog_query("SHOW TABLES").is_ok(),
1539 "SHOW TABLES should be allowed"
1540 );
1541 }
1542
1543 #[mz_ore::test]
1544 fn test_validate_system_catalog_query_allows_explain() {
1545 assert!(
1546 validate_system_catalog_query("EXPLAIN SELECT * FROM mz_tables").is_ok(),
1547 "EXPLAIN of system table query should be allowed"
1548 );
1549 assert!(
1550 validate_system_catalog_query("EXPLAIN SELECT 1").is_ok(),
1551 "EXPLAIN SELECT 1 should be allowed"
1552 );
1553 }
1554
1555 #[mz_ore::test(tokio::test)]
1558 async fn test_tools_list_agents_query_tool_disabled() {
1559 let result = handle_tools_list(McpEndpointType::Agents, false, 1_000_000)
1560 .await
1561 .unwrap();
1562 let McpResult::ToolsList(list) = result else {
1563 panic!("Expected ToolsList result");
1564 };
1565 let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1566 assert!(
1567 tool_names.contains(&"get_data_products"),
1568 "get_data_products should always be present"
1569 );
1570 assert!(
1571 tool_names.contains(&"get_data_product_details"),
1572 "get_data_product_details should always be present"
1573 );
1574 assert!(
1575 tool_names.contains(&"read_data_product"),
1576 "read_data_product should always be present"
1577 );
1578 assert!(
1579 !tool_names.contains(&"query"),
1580 "query tool should be hidden when disabled"
1581 );
1582 }
1583
1584 #[mz_ore::test(tokio::test)]
1585 async fn test_tools_list_agents_query_tool_enabled() {
1586 let result = handle_tools_list(McpEndpointType::Agents, true, 1_000_000)
1587 .await
1588 .unwrap();
1589 let McpResult::ToolsList(list) = result else {
1590 panic!("Expected ToolsList result");
1591 };
1592 let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1593 assert!(
1594 tool_names.contains(&"get_data_products"),
1595 "get_data_products should always be present"
1596 );
1597 assert!(
1598 tool_names.contains(&"get_data_product_details"),
1599 "get_data_product_details should always be present"
1600 );
1601 assert!(
1602 tool_names.contains(&"read_data_product"),
1603 "read_data_product should always be present"
1604 );
1605 assert!(
1606 tool_names.contains(&"query"),
1607 "query tool should be present when enabled"
1608 );
1609 }
1610
1611 #[mz_ore::test(tokio::test)]
1612 async fn test_tools_list_observatory_unaffected_by_query_flag() {
1613 for flag in [true, false] {
1615 let result = handle_tools_list(McpEndpointType::Observatory, flag, 1_000_000)
1616 .await
1617 .unwrap();
1618 let McpResult::ToolsList(list) = result else {
1619 panic!("Expected ToolsList result");
1620 };
1621 let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1622 assert!(
1623 tool_names.contains(&"query_system_catalog"),
1624 "query_system_catalog should always be present on observatory"
1625 );
1626 assert!(
1627 !tool_names.contains(&"query"),
1628 "query tool should never appear on observatory"
1629 );
1630 }
1631 }
1632
1633 #[mz_ore::test]
1636 fn test_safe_data_product_name_valid() {
1637 assert_eq!(
1639 safe_data_product_name(r#""materialize"."public"."my_view""#).unwrap(),
1640 r#""materialize"."public"."my_view""#
1641 );
1642 assert_eq!(
1644 safe_data_product_name(r#""public"."my_view""#).unwrap(),
1645 r#""public"."my_view""#
1646 );
1647 assert_eq!(safe_data_product_name("my_view").unwrap(), r#""my_view""#);
1649 }
1650
1651 #[mz_ore::test]
1652 fn test_safe_data_product_name_rejects_empty() {
1653 assert!(safe_data_product_name("").is_err());
1654 assert!(safe_data_product_name(" ").is_err());
1655 }
1656
1657 #[mz_ore::test]
1658 fn test_safe_data_product_name_rejects_sql_injection() {
1659 assert!(safe_data_product_name("my_view; DROP TABLE users").is_err());
1661 assert!(safe_data_product_name("my_view UNION SELECT * FROM secrets").is_err());
1663 assert!(safe_data_product_name("my_view, secrets").is_err());
1665 assert!(safe_data_product_name("my_view WHERE 1=1 --").is_err());
1667 }
1668
1669 #[mz_ore::test]
1672 fn test_validate_origin_no_header() {
1673 let headers = HeaderMap::new();
1674 assert!(validate_origin(&headers).is_none(), "No Origin = allow");
1675 }
1676
1677 #[mz_ore::test]
1678 fn test_validate_origin_matching() {
1679 let mut headers = HeaderMap::new();
1680 headers.insert(http::header::ORIGIN, "https://example.com".parse().unwrap());
1681 headers.insert(http::header::HOST, "example.com".parse().unwrap());
1682 assert!(
1683 validate_origin(&headers).is_none(),
1684 "Matching Origin = allow"
1685 );
1686 }
1687
1688 #[mz_ore::test]
1689 fn test_validate_origin_mismatch() {
1690 let mut headers = HeaderMap::new();
1691 headers.insert(http::header::ORIGIN, "https://evil.com".parse().unwrap());
1692 headers.insert(http::header::HOST, "example.com".parse().unwrap());
1693 assert!(
1694 validate_origin(&headers).is_some(),
1695 "Mismatched Origin = reject"
1696 );
1697 }
1698
1699 #[mz_ore::test]
1700 fn test_validate_origin_no_host() {
1701 let mut headers = HeaderMap::new();
1702 headers.insert(http::header::ORIGIN, "https://example.com".parse().unwrap());
1703 assert!(
1704 validate_origin(&headers).is_some(),
1705 "Origin with no Host = reject"
1706 );
1707 }
1708
1709 #[mz_ore::test]
1710 fn test_mcp_error_codes() {
1711 assert_eq!(
1712 McpRequestError::InvalidJsonRpcVersion.error_code(),
1713 error_codes::INVALID_REQUEST
1714 );
1715 assert_eq!(
1716 McpRequestError::MethodNotFound("test".to_string()).error_code(),
1717 error_codes::METHOD_NOT_FOUND
1718 );
1719 assert_eq!(
1720 McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
1721 error_codes::INTERNAL_ERROR
1722 );
1723 }
1724}