1use std::sync::Arc;
27use std::time::Duration;
28
29use anyhow::anyhow;
30use axum::Extension;
31use axum::Json;
32use axum::response::IntoResponse;
33use http::{HeaderMap, HeaderValue, StatusCode};
34use mz_adapter_types::dyncfgs::{
35 ENABLE_MCP_AGENT, ENABLE_MCP_AGENT_QUERY_TOOL, ENABLE_MCP_DEVELOPER,
36 ENABLE_MCP_DEVELOPER_QUERY_TOOL, MCP_MAX_RESPONSE_SIZE,
37};
38use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
39use mz_sql::parse::parse;
40use mz_sql::session::metadata::SessionMetadata;
41use mz_sql::session::vars::{APPLICATION_NAME, Var, VarInput};
42use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
43use mz_sql_parser::ast::visit::{self, Visit};
44use mz_sql_parser::ast::{Raw, RawItemName};
45use mz_sql_parser::parser::parse_item_name;
46use serde::{Deserialize, Serialize};
47use serde_json::json;
48use thiserror::Error;
49use tracing::{debug, warn};
50
51use crate::http::AuthedClient;
52use crate::http::mcp_metrics::{McpMetrics, ToolCallGuard};
53use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
54
/// JSON-RPC version string. Requests carrying any other `jsonrpc` value are
/// rejected, and every response is stamped with it.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol revision reported in the `initialize` result.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Upper bound on how long a single MCP request may run before the caller
/// receives a timeout error.
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Query backing the `get_data_products` tool.
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
/// Query prefix backing the `get_data_product_details` tool; the escaped
/// object-name literal is appended by the caller.
const DETAILS_QUERY_PREFIX: &str =
    "SELECT * FROM mz_internal.mz_mcp_data_product_details WHERE object_name = ";
74
75#[derive(Debug, Error)]
77enum McpRequestError {
78 #[error("Invalid JSON-RPC version: expected 2.0")]
79 InvalidJsonRpcVersion,
80 #[error("Method not found: {0}")]
81 #[allow(dead_code)] MethodNotFound(String),
83 #[error("Tool not found: {0}")]
84 ToolNotFound(String),
85 #[error("Data product not found: {0}")]
86 DataProductNotFound(String),
87 #[error("{0}")]
88 ClusterPrivilegeMissing(String),
89 #[error("Query validation failed: {0}")]
90 QueryValidationFailed(String),
91 #[error("Query execution failed: {0}")]
92 QueryExecutionFailed(String),
93 #[error("Internal error: {0}")]
94 Internal(#[from] anyhow::Error),
95}
96
97impl McpRequestError {
98 fn error_code(&self) -> i32 {
99 match self {
100 Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
101 Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
102 Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
103 Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
104 Self::ClusterPrivilegeMissing(_) => error_codes::INVALID_PARAMS,
105 Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
106 Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
107 }
108 }
109
110 fn error_type(&self) -> &'static str {
111 match self {
112 Self::InvalidJsonRpcVersion => "InvalidRequest",
113 Self::MethodNotFound(_) => "MethodNotFound",
114 Self::ToolNotFound(_) => "ToolNotFound",
115 Self::DataProductNotFound(_) => "DataProductNotFound",
116 Self::ClusterPrivilegeMissing(_) => "ClusterPrivilegeMissing",
117 Self::QueryValidationFailed(_) => "ValidationError",
118 Self::QueryExecutionFailed(_) => "ExecutionError",
119 Self::Internal(_) => "InternalError",
120 }
121 }
122}
123
124#[derive(Debug, Deserialize)]
126pub(crate) struct McpRequest {
127 jsonrpc: String,
128 id: Option<serde_json::Value>,
129 #[serde(flatten)]
130 method: McpMethod,
131}
132
133#[derive(Debug, Deserialize)]
135#[serde(tag = "method", content = "params")]
136enum McpMethod {
137 #[serde(rename = "initialize")]
139 Initialize(#[allow(dead_code)] InitializeParams),
140 #[serde(rename = "tools/list")]
141 ToolsList,
142 #[serde(rename = "tools/call")]
143 ToolsCall(ToolsCallParams),
144 #[serde(other)]
146 Unknown,
147}
148
149impl std::fmt::Display for McpMethod {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 match self {
152 McpMethod::Initialize(_) => write!(f, "initialize"),
153 McpMethod::ToolsList => write!(f, "tools/list"),
154 McpMethod::ToolsCall(_) => write!(f, "tools/call"),
155 McpMethod::Unknown => write!(f, "unknown"),
156 }
157 }
158}
159
160#[derive(Debug, Deserialize)]
161struct InitializeParams {
162 #[serde(rename = "protocolVersion")]
164 #[allow(dead_code)]
165 protocol_version: String,
166 #[serde(default)]
168 #[allow(dead_code)]
169 capabilities: serde_json::Value,
170 #[serde(rename = "clientInfo")]
172 #[allow(dead_code)]
173 client_info: Option<ClientInfo>,
174}
175
176#[derive(Debug, Deserialize)]
177struct ClientInfo {
178 #[allow(dead_code)]
179 name: String,
180 #[allow(dead_code)]
181 version: String,
182}
183
184#[derive(Debug, Deserialize)]
187#[serde(tag = "name", content = "arguments")]
188#[serde(rename_all = "snake_case")]
189enum ToolsCallParams {
190 GetDataProducts(#[serde(default)] ()),
193 GetDataProductDetails(GetDataProductDetailsParams),
194 ReadDataProduct(ReadDataProductParams),
195 Query(QueryParams),
196 QuerySystemCatalog(QuerySystemCatalogParams),
198}
199
200impl std::fmt::Display for ToolsCallParams {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 match self {
203 ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
204 ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
205 ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
206 ToolsCallParams::Query(_) => write!(f, "query"),
207 ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
208 }
209 }
210}
211
212#[derive(Debug, Deserialize)]
213struct GetDataProductDetailsParams {
214 name: String,
215}
216
217#[derive(Debug, Deserialize)]
218struct ReadDataProductParams {
219 name: String,
220 #[serde(default = "default_read_limit")]
221 limit: u32,
222 cluster: Option<String>,
223}
224
/// Row limit applied to `read_data_product` when the caller omits `limit`.
fn default_read_limit() -> u32 {
    const DEFAULT_READ_LIMIT: u32 = 500;
    DEFAULT_READ_LIMIT
}
228
229#[derive(Debug, Deserialize)]
230struct QueryParams {
231 cluster: String,
232 sql_query: String,
233}
234
235#[derive(Debug, Deserialize)]
236struct QuerySystemCatalogParams {
237 sql_query: String,
238}
239
240#[derive(Debug, Serialize)]
241struct McpResponse {
242 jsonrpc: String,
243 id: serde_json::Value,
244 #[serde(skip_serializing_if = "Option::is_none")]
245 result: Option<McpResult>,
246 #[serde(skip_serializing_if = "Option::is_none")]
247 error: Option<McpError>,
248}
249
250#[derive(Debug, Serialize)]
252#[serde(untagged)]
253enum McpResult {
254 Initialize(InitializeResult),
255 ToolsList(ToolsListResult),
256 ToolContent(ToolContentResult),
257}
258
259#[derive(Debug, Serialize)]
260struct InitializeResult {
261 #[serde(rename = "protocolVersion")]
262 protocol_version: String,
263 capabilities: Capabilities,
264 #[serde(rename = "serverInfo")]
265 server_info: ServerInfo,
266 #[serde(skip_serializing_if = "Option::is_none")]
267 instructions: Option<String>,
268}
269
270#[derive(Debug, Serialize)]
271struct Capabilities {
272 tools: serde_json::Value,
273}
274
275#[derive(Debug, Serialize)]
276struct ServerInfo {
277 name: String,
278 version: String,
279}
280
281#[derive(Debug, Serialize)]
282struct ToolsListResult {
283 tools: Vec<ToolDefinition>,
284}
285
286#[derive(Debug, Serialize)]
287struct ToolDefinition {
288 name: String,
289 #[serde(skip_serializing_if = "Option::is_none")]
290 title: Option<String>,
291 description: String,
292 #[serde(rename = "inputSchema")]
293 input_schema: serde_json::Value,
294 #[serde(skip_serializing_if = "Option::is_none")]
295 annotations: Option<ToolAnnotations>,
296}
297
298#[derive(Debug, Serialize)]
301struct ToolAnnotations {
302 #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
303 read_only_hint: Option<bool>,
304 #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
305 destructive_hint: Option<bool>,
306 #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
307 idempotent_hint: Option<bool>,
308 #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
309 open_world_hint: Option<bool>,
310}
311
312const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
314 read_only_hint: Some(true),
315 destructive_hint: Some(false),
316 idempotent_hint: Some(true),
317 open_world_hint: Some(false),
318};
319
320#[derive(Debug, Serialize)]
321struct ToolContentResult {
322 content: Vec<ContentBlock>,
323 #[serde(rename = "isError")]
324 is_error: bool,
325}
326
327#[derive(Debug, Serialize)]
328struct ContentBlock {
329 #[serde(rename = "type")]
330 content_type: String,
331 text: String,
332}
333
/// JSON-RPC error codes used in MCP error responses.
mod error_codes {
    /// The request envelope is not valid.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// The server failed while processing the request.
    pub const INTERNAL_ERROR: i32 = -32603;
}
341
342#[derive(Debug, Serialize)]
343struct McpError {
344 code: i32,
345 message: String,
346 #[serde(skip_serializing_if = "Option::is_none")]
347 data: Option<serde_json::Value>,
348}
349
350impl From<McpRequestError> for McpError {
351 fn from(err: McpRequestError) -> Self {
352 McpError {
353 code: err.error_code(),
354 message: err.to_string(),
355 data: Some(json!({
356 "error_type": err.error_type(),
357 })),
358 }
359 }
360}
361
/// Which of the two MCP endpoints a request arrived on.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// The agent-facing endpoint.
    Agent,
    /// The developer-facing endpoint.
    Developer,
}

impl McpEndpointType {
    /// Returns the lowercase label for this endpoint (used for metrics
    /// labels and as the `Display` output).
    fn as_label(self) -> &'static str {
        match self {
            Self::Agent => "agent",
            Self::Developer => "developer",
        }
    }
}

impl std::fmt::Display for McpEndpointType {
    /// Writes the same text as [`McpEndpointType::as_label`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = McpEndpointType::as_label(*self);
        f.write_str(label)
    }
}
384
385pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
388 StatusCode::METHOD_NOT_ALLOWED
389}
390
391pub async fn handle_mcp_agent(
393 headers: HeaderMap,
394 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
395 Extension(metrics): Extension<McpMetrics>,
396 client: AuthedClient,
397 Json(body): Json<McpRequest>,
398) -> axum::response::Response {
399 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
400 return resp;
401 }
402 handle_mcp_request(client, body, McpEndpointType::Agent, metrics)
403 .await
404 .into_response()
405}
406
407pub async fn handle_mcp_developer(
409 headers: HeaderMap,
410 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
411 Extension(metrics): Extension<McpMetrics>,
412 client: AuthedClient,
413 Json(body): Json<McpRequest>,
414) -> axum::response::Response {
415 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
416 return resp;
417 }
418 handle_mcp_request(client, body, McpEndpointType::Developer, metrics)
419 .await
420 .into_response()
421}
422
423fn validate_origin(
432 headers: &HeaderMap,
433 allowed: &[HeaderValue],
434) -> Option<axum::response::Response> {
435 let origin = headers.get(http::header::ORIGIN)?;
436 if mz_http_util::origin_is_allowed(origin, allowed) {
437 return None;
438 }
439 warn!(
440 origin = ?origin,
441 "MCP request rejected: origin not in allowlist",
442 );
443 Some(StatusCode::FORBIDDEN.into_response())
444}
445
446async fn handle_mcp_request(
447 mut client: AuthedClient,
448 request: McpRequest,
449 endpoint_type: McpEndpointType,
450 metrics: McpMetrics,
451) -> impl IntoResponse {
452 let endpoint_label = endpoint_type.as_label();
453 let method_label = request.method.to_string();
454 let record_request = |status: &str| {
455 metrics
456 .requests
457 .with_label_values(&[endpoint_label, &method_label, status])
458 .inc();
459 };
460
461 let catalog = client.client.catalog_snapshot("mcp").await;
463 let dyncfgs = catalog.system_config().dyncfgs();
464 let enabled = match endpoint_type {
465 McpEndpointType::Agent => ENABLE_MCP_AGENT.get(dyncfgs),
466 McpEndpointType::Developer => ENABLE_MCP_DEVELOPER.get(dyncfgs),
467 };
468 if !enabled {
469 debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
470 record_request("endpoint_disabled");
471 return StatusCode::SERVICE_UNAVAILABLE.into_response();
472 }
473
474 let query_tool_enabled = match endpoint_type {
479 McpEndpointType::Agent => ENABLE_MCP_AGENT_QUERY_TOOL.get(dyncfgs),
480 McpEndpointType::Developer => ENABLE_MCP_DEVELOPER_QUERY_TOOL.get(dyncfgs),
481 };
482 let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
483
484 let app_name = match endpoint_type {
488 McpEndpointType::Agent => "mz_mcp_agents",
489 McpEndpointType::Developer => "mz_mcp_developer",
490 };
491 client
492 .client
493 .session()
494 .vars_mut()
495 .set_default(APPLICATION_NAME.name(), VarInput::Flat(app_name))
496 .expect("application_name is a known session var");
497
498 let user = client.client.session().user().name.clone();
499 let is_notification = request.id.is_none();
500
501 debug!(
502 method = %request.method,
503 endpoint = %endpoint_type,
504 user = %user,
505 is_notification = is_notification,
506 "MCP request received"
507 );
508
509 if is_notification {
511 debug!(method = %request.method, "Received notification (no response will be sent)");
512 record_request("ok");
513 return StatusCode::OK.into_response();
514 }
515
516 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
517
518 let metrics_inner = metrics.clone();
523 let result = tokio::time::timeout(
524 MCP_REQUEST_TIMEOUT,
525 mz_ore::task::spawn(|| "mcp_request", async move {
526 handle_mcp_request_inner(
527 &mut client,
528 request,
529 endpoint_type,
530 query_tool_enabled,
531 max_response_size,
532 metrics_inner,
533 )
534 .await
535 })
536 .abort_on_drop(),
537 )
538 .await;
539
540 let (response, status_label): (McpResponse, &'static str) = match result {
541 Ok(inner) => inner,
542 Err(_elapsed) => {
543 warn!(
544 endpoint = %endpoint_type,
545 timeout = ?MCP_REQUEST_TIMEOUT,
546 "MCP request timed out",
547 );
548 let response = McpResponse {
549 jsonrpc: JSONRPC_VERSION.to_string(),
550 id: request_id,
551 result: None,
552 error: Some(
553 McpRequestError::QueryExecutionFailed(format!(
554 "Request timed out after {} seconds.",
555 MCP_REQUEST_TIMEOUT.as_secs(),
556 ))
557 .into(),
558 ),
559 };
560 (response, "timeout")
561 }
562 };
563
564 record_request(status_label);
565 (StatusCode::OK, Json(response)).into_response()
566}
567
568async fn handle_mcp_request_inner(
569 client: &mut AuthedClient,
570 request: McpRequest,
571 endpoint_type: McpEndpointType,
572 query_tool_enabled: bool,
573 max_response_size: usize,
574 metrics: McpMetrics,
575) -> (McpResponse, &'static str) {
576 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
578
579 let result = handle_mcp_method(
580 client,
581 &request,
582 endpoint_type,
583 query_tool_enabled,
584 max_response_size,
585 &metrics,
586 )
587 .await;
588
589 let status_label = match &result {
590 Ok(_) => "ok",
591 Err(e) => e.error_type(),
592 };
593
594 let response = match result {
595 Ok(result_value) => McpResponse {
596 jsonrpc: JSONRPC_VERSION.to_string(),
597 id: request_id,
598 result: Some(result_value),
599 error: None,
600 },
601 Err(e) => {
602 if !matches!(
604 e,
605 McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
606 ) {
607 warn!(error = %e, method = %request.method, "MCP method execution failed");
608 }
609 McpResponse {
610 jsonrpc: JSONRPC_VERSION.to_string(),
611 id: request_id,
612 result: None,
613 error: Some(e.into()),
614 }
615 }
616 };
617
618 (response, status_label)
619}
620
621async fn handle_mcp_method(
622 client: &mut AuthedClient,
623 request: &McpRequest,
624 endpoint_type: McpEndpointType,
625 query_tool_enabled: bool,
626 max_response_size: usize,
627 metrics: &McpMetrics,
628) -> Result<McpResult, McpRequestError> {
629 if request.jsonrpc != JSONRPC_VERSION {
631 return Err(McpRequestError::InvalidJsonRpcVersion);
632 }
633
634 match &request.method {
636 McpMethod::Initialize(_) => {
637 debug!(endpoint = %endpoint_type, "Processing initialize");
638 handle_initialize(endpoint_type, query_tool_enabled).await
639 }
640 McpMethod::ToolsList => {
641 debug!(endpoint = %endpoint_type, "Processing tools/list");
642 handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
643 }
644 McpMethod::ToolsCall(params) => {
645 debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
646 handle_tools_call(
647 client,
648 params,
649 endpoint_type,
650 query_tool_enabled,
651 max_response_size,
652 metrics,
653 )
654 .await
655 }
656 McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
657 "unknown method".to_string(),
658 )),
659 }
660}
661
662fn endpoint_instructions(
665 endpoint_type: McpEndpointType,
666 query_tool_enabled: bool,
667) -> Option<String> {
668 match endpoint_type {
669 McpEndpointType::Agent => Some(
670 concat!(
671 "You have access to Materialize data products via MCP. ",
672 "Prefer indexed objects (served from memory) over unindexed materialized views ",
673 "(read from persistent storage). `read_data_product` automatically routes the ",
674 "read to the cluster recorded in the data product catalog so indexes are used; ",
675 "you only need to set the `cluster` parameter if you intentionally want the ",
676 "read to run on a different cluster (e.g. one with larger or more replicas). ",
677 "`get_data_product_details` returns a `hydration` object with `hydrated`, ",
678 "`replica_count`, and `hydrated_replica_count` fields. Reads never return ",
679 "partial data: a read against a not-yet-hydrated product blocks until the ",
680 "dataflow catches up, and may hit the request timeout. Check `hydrated` ",
681 "before reading: if it is false and `replica_count` is greater than 0, the ",
682 "dataflow is still warming up, so wait and retry; if `replica_count` is 0 the ",
683 "cluster has no replicas and the read cannot make progress until one is added.",
684 )
685 .to_string(),
686 ),
687 McpEndpointType::Developer => {
688 let query_tool_line = if query_tool_enabled {
692 "- query: read-only SELECT/SHOW/EXPLAIN that can also reach user objects on a named cluster. Use this for EXPLAIN ANALYZE and for inspecting user objects directly.\n"
693 } else {
694 ""
695 };
696 Some(format!(
697 "You are connected to the Materialize developer MCP server for troubleshooting and observability.\n\n\
698 Tools:\n\
699 - query_system_catalog: read-only SELECT/SHOW/EXPLAIN restricted to system catalog tables (mz_*, pg_catalog, information_schema). No cluster argument; prefer this for most catalog lookups.\n\
700 {query_tool_line}\n\
701 IMPORTANT: Before writing queries, discover table schemas using the mz_ontology tables:\n\
702 - mz_internal.mz_ontology_entity_types: what catalog entities exist and which tables they map to\n\
703 - mz_internal.mz_ontology_link_types: relationships between entities (foreign keys, metrics, etc.)\n\
704 - mz_internal.mz_ontology_properties: column names, types, and descriptions for each entity\n\
705 - mz_internal.mz_ontology_semantic_types: typed ID domains (CatalogItemId, ReplicaId, etc.)\n\n\
706 Use these to find the correct tables, join paths, and column names instead of guessing.\n\n\
707 Key rules:\n\
708 - mz_source_statuses and mz_sink_statuses use `last_status_change_at` (NOT `updated_at`)\n\
709 - mz_cluster_replica_utilization only has `replica_id` — JOIN with mz_cluster_replicas and mz_clusters to get names\n\
710 - Do NOT query mz_introspection.mz_dataflow_arrangement_sizes — it is cluster-scoped and has uint8/text type mismatches\n\
711 - Use SHOW COLUMNS FROM <table> to verify column names if unsure",
712 ))
713 }
714 }
715}
716
717async fn handle_initialize(
718 endpoint_type: McpEndpointType,
719 query_tool_enabled: bool,
720) -> Result<McpResult, McpRequestError> {
721 Ok(McpResult::Initialize(InitializeResult {
722 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
723 capabilities: Capabilities { tools: json!({}) },
724 server_info: ServerInfo {
725 name: format!("materialize-mcp-{}", endpoint_type),
726 version: env!("CARGO_PKG_VERSION").to_string(),
727 },
728 instructions: endpoint_instructions(endpoint_type, query_tool_enabled),
729 }))
730}
731
732async fn handle_tools_list(
733 endpoint_type: McpEndpointType,
734 query_tool_enabled: bool,
735 max_response_size: usize,
736) -> Result<McpResult, McpRequestError> {
737 let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
738
739 let tools = match endpoint_type {
740 McpEndpointType::Agent => {
741 let mut tools = vec![
742 ToolDefinition {
743 name: "get_data_products".to_string(),
744 title: Some("List Data Products".to_string()),
745 description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
746 input_schema: json!({
747 "type": "object",
748 "properties": {},
749 "required": []
750 }),
751 annotations: Some(READ_ONLY_ANNOTATIONS),
752 },
753 ToolDefinition {
754 name: "get_data_product_details".to_string(),
755 title: Some("Get Data Product Details".to_string()),
756 description: "Get the complete schema and structure of a specific data product, plus a `hydration` object reporting whether the dataflow is ready across the cluster's replicas (`{hydrated, replica_count, hydrated_replica_count}`). This shows you exactly what fields are available, their types, and what data you can query. Reads never return partial data, so check `hydration` before reading: if `hydrated` is false and `replica_count` is greater than 0 the dataflow is still warming up (a read would block until it catches up, possibly hitting the request timeout), so wait and retry; if `replica_count` is 0 the cluster has no replicas and the read cannot make progress until one is added.".to_string(),
757 input_schema: json!({
758 "type": "object",
759 "properties": {
760 "name": {
761 "type": "string",
762 "description": "Exact name of the data product from get_data_products() list"
763 }
764 },
765 "required": ["name"]
766 }),
767 annotations: Some(READ_ONLY_ANNOTATIONS),
768 },
769 ToolDefinition {
770 name: "read_data_product".to_string(),
771 title: Some("Read Data Product".to_string()),
772 description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
773 input_schema: json!({
774 "type": "object",
775 "properties": {
776 "name": {
777 "type": "string",
778 "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
779 },
780 "limit": {
781 "type": "integer",
782 "description": "Maximum number of rows to return (default 500)",
783 "default": 500
784 },
785 "cluster": {
786 "type": "string",
787 "description": "Optional override. By default, the read runs on the cluster recorded in the data product catalog (where the index or materialized view dataflow lives), so indexed reads actually hit their arrangement. Set this only to intentionally run the same read on a different cluster — e.g. one with more or larger replicas, or to compare cost/latency."
788 }
789 },
790 "required": ["name"]
791 }),
792 annotations: Some(READ_ONLY_ANNOTATIONS),
793 },
794 ];
795 if query_tool_enabled {
796 tools.push(ToolDefinition {
797 name: "query".to_string(),
798 title: Some("Query Data Products".to_string()),
799 description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
800 input_schema: json!({
801 "type": "object",
802 "properties": {
803 "cluster": {
804 "type": "string",
805 "description": "Exact cluster name from the data product details - required for query execution"
806 },
807 "sql_query": {
808 "type": "string",
809 "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
810 }
811 },
812 "required": ["cluster", "sql_query"]
813 }),
814 annotations: Some(READ_ONLY_ANNOTATIONS),
815 });
816 }
817 tools
818 }
819 McpEndpointType::Developer => {
820 let mut tools = vec![ToolDefinition {
821 name: "query_system_catalog".to_string(),
822 title: Some("Query System Catalog".to_string()),
823 description: concat!(
824 "Query Materialize system catalog tables for troubleshooting and observability. ",
825 "Only mz_*, pg_catalog, and information_schema tables are accessible. ",
826 "Use the mz_internal.mz_ontology_* tables to discover tables, columns, and join paths before writing queries.",
827 ).to_owned() + &format!(" {size_hint}"),
828 input_schema: json!({
829 "type": "object",
830 "properties": {
831 "sql_query": {
832 "type": "string",
833 "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
834 }
835 },
836 "required": ["sql_query"]
837 }),
838 annotations: Some(READ_ONLY_ANNOTATIONS),
839 }];
840 if query_tool_enabled {
841 tools.push(ToolDefinition {
842 name: "query".to_string(),
843 title: Some("Query".to_string()),
844 description: format!(
845 "Execute a read-only SQL query (SELECT, SHOW, or EXPLAIN) against any object the role can access, including system catalog and user objects. Requires a cluster, which is what enables EXPLAIN ANALYZE and queries against indexed user objects. For pure system catalog lookups that do not need a cluster, prefer `query_system_catalog`. {size_hint}",
846 ),
847 input_schema: json!({
848 "type": "object",
849 "properties": {
850 "cluster": {
851 "type": "string",
852 "description": "Exact cluster name the query should run on. Required: EXPLAIN ANALYZE and queries against indexed user objects need a specific cluster to execute on."
853 },
854 "sql_query": {
855 "type": "string",
856 "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN statement. Multi-statement queries are rejected."
857 }
858 },
859 "required": ["cluster", "sql_query"]
860 }),
861 annotations: Some(READ_ONLY_ANNOTATIONS),
862 });
863 }
864 tools
865 }
866 };
867
868 Ok(McpResult::ToolsList(ToolsListResult { tools }))
869}
870
871async fn handle_tools_call(
872 client: &mut AuthedClient,
873 params: &ToolsCallParams,
874 endpoint_type: McpEndpointType,
875 query_tool_enabled: bool,
876 max_response_size: usize,
877 metrics: &McpMetrics,
878) -> Result<McpResult, McpRequestError> {
879 let mut guard = ToolCallGuard::new(metrics, endpoint_type.as_label(), params.to_string());
881
882 let result = match (endpoint_type, params) {
883 (McpEndpointType::Agent, ToolsCallParams::GetDataProducts(_)) => {
884 get_data_products(client, max_response_size).await
885 }
886 (McpEndpointType::Agent, ToolsCallParams::GetDataProductDetails(p)) => {
887 get_data_product_details(client, &p.name, max_response_size).await
888 }
889 (McpEndpointType::Agent, ToolsCallParams::ReadDataProduct(p)) => {
890 read_data_product(
891 client,
892 &p.name,
893 p.limit,
894 p.cluster.as_deref(),
895 max_response_size,
896 )
897 .await
898 }
899 (McpEndpointType::Agent, ToolsCallParams::Query(_)) if !query_tool_enabled => {
900 Err(McpRequestError::ToolNotFound(
901 "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
902 ))
903 }
904 (McpEndpointType::Agent, ToolsCallParams::Query(p)) => {
905 execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
906 }
907 (McpEndpointType::Developer, ToolsCallParams::QuerySystemCatalog(p)) => {
908 query_system_catalog(client, &p.sql_query, max_response_size).await
909 }
910 (McpEndpointType::Developer, ToolsCallParams::Query(_)) if !query_tool_enabled => {
911 Err(McpRequestError::ToolNotFound(
912 "query tool is not available. Use query_system_catalog instead.".to_string(),
913 ))
914 }
915 (McpEndpointType::Developer, ToolsCallParams::Query(p)) => {
916 execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
917 }
918 (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
920 "{} is not available on {} endpoint",
921 tool, endpoint
922 ))),
923 };
924
925 guard.set_status(match &result {
926 Ok(_) => "ok",
927 Err(e) => e.error_type(),
928 });
929
930 result
931}
932
933async fn execute_sql(
935 client: &mut AuthedClient,
936 query: &str,
937) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
938 let mut response = SqlResponse::new();
939
940 execute_request(
941 client,
942 SqlRequest::Simple {
943 query: mz_ore::sql::Sql::trusted_external_request(query.to_string()),
944 },
945 &mut response,
946 )
947 .await
948 .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
949
950 for result in response.results {
953 match result {
954 SqlResult::Rows { rows, .. } => return Ok(rows),
955 SqlResult::Err { error, .. } => {
956 return Err(McpRequestError::QueryExecutionFailed(error.message));
957 }
958 SqlResult::Ok { .. } => continue,
959 }
960 }
961
962 Err(McpRequestError::QueryExecutionFailed(
963 "Query did not return any results".to_string(),
964 ))
965}
966
967fn format_rows_response(
974 rows: Vec<Vec<serde_json::Value>>,
975 max_size: usize,
976) -> Result<McpResult, McpRequestError> {
977 let text =
978 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
979
980 if text.len() > max_size {
981 return Err(McpRequestError::QueryExecutionFailed(format!(
982 "Response size ({} bytes) exceeds the {} byte limit. \
983 Use LIMIT or WHERE to narrow your query.",
984 text.len(),
985 max_size,
986 )));
987 }
988
989 Ok(McpResult::ToolContent(ToolContentResult {
990 content: vec![ContentBlock {
991 content_type: "text".to_string(),
992 text,
993 }],
994 is_error: false,
995 }))
996}
997
998async fn get_data_products(
999 client: &mut AuthedClient,
1000 max_response_size: usize,
1001) -> Result<McpResult, McpRequestError> {
1002 debug!("Executing get_data_products");
1003 let rows = execute_sql(client, DISCOVERY_QUERY).await?;
1004 debug!("get_data_products returned {} rows", rows.len());
1005
1006 format_rows_response(rows, max_response_size)
1007}
1008
1009async fn get_data_product_details(
1010 client: &mut AuthedClient,
1011 name: &str,
1012 max_response_size: usize,
1013) -> Result<McpResult, McpRequestError> {
1014 debug!(name = %name, "Executing get_data_product_details");
1015
1016 let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
1017
1018 let rows = execute_sql(client, &query).await?;
1019
1020 if rows.is_empty() {
1021 return Err(McpRequestError::DataProductNotFound(name.to_string()));
1022 }
1023
1024 format_rows_response(rows, max_response_size)
1025}
1026
1027fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
1034 let name = name.trim();
1035 if name.is_empty() {
1036 return Err(McpRequestError::QueryValidationFailed(
1037 "Data product name cannot be empty".to_string(),
1038 ));
1039 }
1040
1041 let parsed = parse_item_name(name).map_err(|_| {
1042 McpRequestError::QueryValidationFailed(format!(
1043 "Invalid data product name: {}. Expected a valid object name, \
1044 e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
1045 name
1046 ))
1047 })?;
1048
1049 Ok(parsed.to_ast_string_stable())
1052}
1053
1054async fn read_data_product(
1073 client: &mut AuthedClient,
1074 name: &str,
1075 limit: u32,
1076 cluster_override: Option<&str>,
1077 max_response_size: usize,
1078) -> Result<McpResult, McpRequestError> {
1079 debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
1080
1081 let safe_name = safe_data_product_name(name)?;
1083
1084 let lookup_query = format!(
1092 "SELECT \
1093 dp.cluster, \
1094 dp.cluster IS NULL OR cp.name IS NOT NULL AS has_cluster_usage \
1095 FROM mz_internal.mz_mcp_data_products dp \
1096 LEFT JOIN mz_internal.mz_show_my_cluster_privileges cp \
1097 ON cp.name = dp.cluster AND cp.privilege_type = 'USAGE' \
1098 WHERE dp.object_name = {} \
1099 ORDER BY \
1100 (dp.cluster IS NOT NULL AND cp.name IS NOT NULL) DESC, \
1101 dp.cluster NULLS LAST \
1102 LIMIT 1",
1103 escaped_string_literal(name)
1104 );
1105 let lookup_rows = execute_sql(client, &lookup_query).await?;
1106 if lookup_rows.is_empty() {
1107 return Err(McpRequestError::DataProductNotFound(name.to_string()));
1108 }
1109 let lookup_row = lookup_rows.first();
1110 let catalog_cluster: Option<&str> = lookup_row
1111 .and_then(|row| row.first())
1112 .and_then(|v| v.as_str());
1113 let has_cluster_usage: bool = lookup_row
1116 .and_then(|row| row.get(1))
1117 .and_then(|v| v.as_bool())
1118 .unwrap_or(false);
1119
1120 let target_cluster = match cluster_override {
1126 Some(c) => c,
1127 None => match catalog_cluster {
1128 Some(c) if has_cluster_usage => c,
1129 Some(c) => {
1130 return Err(McpRequestError::ClusterPrivilegeMissing(format!(
1131 "Data product {name} is hosted on cluster {c:?}, which your role \
1132 does not have USAGE on. Pass `cluster: \"<a-cluster-you-have-USAGE-on>\"` \
1133 to read it from a different cluster (slower, no index), or have USAGE \
1134 granted on {c:?}.",
1135 )));
1136 }
1137 None => {
1138 return Err(McpRequestError::Internal(anyhow!(
1141 "data product {name} has no cluster in the catalog"
1142 )));
1143 }
1144 },
1145 };
1146
1147 let read_query = build_read_query(&safe_name, limit, target_cluster);
1152
1153 let rows = execute_sql(client, &read_query).await?;
1154
1155 format_rows_response(rows, max_response_size)
1156}
1157
1158fn build_read_query(safe_name: &str, limit: u32, target_cluster: &str) -> String {
1166 format!(
1167 "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}\n; COMMIT;",
1168 escaped_string_literal(target_cluster),
1169 safe_name,
1170 limit,
1171 )
1172}
1173
1174fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
1176 let sql = sql.trim();
1177 if sql.is_empty() {
1178 return Err(McpRequestError::QueryValidationFailed(
1179 "Empty query".to_string(),
1180 ));
1181 }
1182
1183 let stmts = parse(sql).map_err(|e| {
1185 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1186 })?;
1187
1188 if stmts.len() != 1 {
1190 return Err(McpRequestError::QueryValidationFailed(format!(
1191 "Only one query allowed at a time. Found {} statements.",
1192 stmts.len()
1193 )));
1194 }
1195
1196 let stmt = &stmts[0];
1202 use mz_sql_parser::ast::Statement;
1203
1204 match &stmt.ast {
1205 Statement::Select(_)
1206 | Statement::Show(_)
1207 | Statement::ExplainPlan(_)
1208 | Statement::ExplainPushdown(_)
1209 | Statement::ExplainTimestamp(_)
1210 | Statement::ExplainSinkSchema(_)
1211 | Statement::ExplainAnalyzeObject(_)
1212 | Statement::ExplainAnalyzeCluster(_) => Ok(()),
1213 _ => Err(McpRequestError::QueryValidationFailed(
1214 "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
1215 )),
1216 }
1217}
1218
1219async fn execute_query(
1220 client: &mut AuthedClient,
1221 cluster: &str,
1222 sql_query: &str,
1223 max_response_size: usize,
1224) -> Result<McpResult, McpRequestError> {
1225 debug!(cluster = %cluster, "Executing user query");
1226
1227 validate_readonly_query(sql_query)?;
1228
1229 let combined_query = format!(
1232 "BEGIN READ ONLY; SET CLUSTER = {}; {}\n; COMMIT;",
1233 escaped_string_literal(cluster),
1234 sql_query
1235 );
1236
1237 let rows = execute_sql(client, &combined_query).await?;
1238
1239 format_rows_response(rows, max_response_size)
1240}
1241
1242async fn query_system_catalog(
1243 client: &mut AuthedClient,
1244 sql_query: &str,
1245 max_response_size: usize,
1246) -> Result<McpResult, McpRequestError> {
1247 debug!("Executing query_system_catalog");
1248
1249 validate_readonly_query(sql_query)?;
1251
1252 validate_system_catalog_query(sql_query)?;
1254
1255 let combined_query = format!(
1261 "BEGIN READ ONLY; SET search_path = mz_catalog, mz_internal, pg_catalog, information_schema; {}; COMMIT;",
1262 sql_query
1263 );
1264
1265 let rows = execute_sql(client, &combined_query).await?;
1266
1267 format_rows_response(rows, max_response_size)
1268}
1269
/// Accumulator for relation references found while walking a SQL AST.
struct TableReferenceCollector {
    /// Referenced relations as `(schema, name)` pairs; `schema` is `None` for
    /// names written without a qualifier.
    tables: Vec<(Option<String>, String)>,
    /// Aliases of the CTEs seen so far.
    cte_names: std::collections::BTreeSet<String>,
}

impl TableReferenceCollector {
    /// Creates a collector with no recorded references or CTE aliases.
    fn new() -> Self {
        let tables = Vec::new();
        let cte_names = std::collections::BTreeSet::new();
        Self { tables, cte_names }
    }
}
1286
1287impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1288 fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1289 self.cte_names
1291 .insert(cte.alias.name.as_str().to_lowercase());
1292 visit::visit_cte(self, cte);
1293 }
1294
1295 fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1296 if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1298 match name {
1299 RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1300 let parts = &n.0;
1301 if !parts.is_empty() {
1302 let table_name = parts.last().unwrap().as_str().to_lowercase();
1303
1304 if self.cte_names.contains(&table_name) {
1306 visit::visit_table_factor(self, table_factor);
1307 return;
1308 }
1309
1310 let schema = if parts.len() >= 2 {
1312 Some(parts[parts.len() - 2].as_str().to_lowercase())
1313 } else {
1314 None
1315 };
1316 self.tables.push((schema, table_name));
1317 }
1318 }
1319 }
1320 }
1321 visit::visit_table_factor(self, table_factor);
1322 }
1323}
1324
1325fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1333 let stmts = parse(sql).map_err(|e| {
1335 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1336 })?;
1337
1338 if stmts.is_empty() {
1339 return Err(McpRequestError::QueryValidationFailed(
1340 "Empty query".to_string(),
1341 ));
1342 }
1343
1344 let mut collector = TableReferenceCollector::new();
1346 for stmt in &stmts {
1347 collector.visit_statement(&stmt.ast);
1348 }
1349
1350 let is_allowed_schema =
1353 |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1354
1355 let is_system_table = |(schema, table_name): &(Option<String>, String)| match schema {
1360 Some(s) => is_allowed_schema(s.as_str()),
1361 None => table_name.starts_with("mz_"),
1362 };
1363
1364 let non_system_tables: Vec<String> = collector
1366 .tables
1367 .iter()
1368 .filter(|t| !is_system_table(t))
1369 .map(|(schema, table)| match schema {
1370 Some(s) => format!("{}.{}", s, table),
1371 None => table.clone(),
1372 })
1373 .collect();
1374
1375 if !non_system_tables.is_empty() {
1376 return Err(McpRequestError::QueryValidationFailed(format!(
1377 "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1378 non_system_tables.join(", ")
1379 )));
1380 }
1381
1382 use mz_sql_parser::ast::Statement;
1385 let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1386
1387 if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1388 return Err(McpRequestError::QueryValidationFailed(
1389 "Query must reference at least one system catalog table".to_string(),
1390 ));
1391 }
1392
1393 Ok(())
1394}
1395
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // Shared helpers
    // ---------------------------------------------------------------------

    /// Asserts that `validate_readonly_query` accepts every query, in order.
    #[track_caller]
    fn assert_readonly_accepts(queries: &[&str]) {
        for &sql in queries {
            assert!(
                validate_readonly_query(sql).is_ok(),
                "read-only validation should accept: {sql:?}"
            );
        }
    }

    /// Asserts that `validate_readonly_query` rejects every query, in order.
    #[track_caller]
    fn assert_readonly_rejects(queries: &[&str]) {
        for &sql in queries {
            assert!(
                validate_readonly_query(sql).is_err(),
                "read-only validation should reject: {sql:?}"
            );
        }
    }

    /// Asserts that `validate_system_catalog_query` accepts every query.
    #[track_caller]
    fn assert_catalog_accepts(queries: &[&str]) {
        for &sql in queries {
            assert!(
                validate_system_catalog_query(sql).is_ok(),
                "system catalog validation should accept: {sql:?}"
            );
        }
    }

    /// Asserts that `validate_system_catalog_query` rejects every query.
    #[track_caller]
    fn assert_catalog_rejects(queries: &[&str]) {
        for &sql in queries {
            assert!(
                validate_system_catalog_query(sql).is_err(),
                "system catalog validation should reject: {sql:?}"
            );
        }
    }

    /// Like `assert_catalog_accepts`, with a per-case failure message.
    #[track_caller]
    fn assert_catalog_accepts_because(cases: &[(&str, &str)]) {
        for &(sql, why) in cases {
            assert!(validate_system_catalog_query(sql).is_ok(), "{why}");
        }
    }

    /// Like `assert_catalog_rejects`, with a per-case failure message.
    #[track_caller]
    fn assert_catalog_rejects_because(cases: &[(&str, &str)]) {
        for &(sql, why) in cases {
            assert!(validate_system_catalog_query(sql).is_err(), "{why}");
        }
    }

    /// Calls `handle_tools_list` and returns the advertised tool names.
    async fn listed_tool_names(endpoint: McpEndpointType, query_tool_enabled: bool) -> Vec<String> {
        let result = handle_tools_list(endpoint, query_tool_enabled, 1_000_000)
            .await
            .unwrap();
        let McpResult::ToolsList(list) = result else {
            panic!("Expected ToolsList result");
        };
        list.tools
            .iter()
            .map(|t| t.name.as_str().to_owned())
            .collect()
    }

    /// Returns whether `tool` appears in `names`.
    fn has_tool(names: &[String], tool: &str) -> bool {
        names.iter().any(|n| n.as_str() == tool)
    }

    /// Asserts the three data-product tools the agent endpoint always lists.
    #[track_caller]
    fn assert_agent_core_tools(names: &[String]) {
        for tool in [
            "get_data_products",
            "get_data_product_details",
            "read_data_product",
        ] {
            assert!(has_tool(names, tool), "{tool} should always be present");
        }
    }

    /// Unwraps a `ToolContent` result that must hold exactly one block and
    /// returns that block's text.
    #[track_caller]
    fn sole_text_block(result: McpResult) -> String {
        let McpResult::ToolContent(payload) = result else {
            panic!("Expected ToolContent");
        };
        assert_eq!(payload.content.len(), 1);
        payload.content[0].text.clone()
    }

    // ---------------------------------------------------------------------
    // validate_readonly_query
    // ---------------------------------------------------------------------

    #[mz_ore::test]
    fn test_validate_readonly_query_select() {
        assert_readonly_accepts(&["SELECT * FROM mz_tables", "SELECT 1 + 2", " SELECT 1 "]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_subqueries() {
        assert_readonly_accepts(&[
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
            "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))",
            "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_show() {
        assert_readonly_accepts(&["SHOW CLUSTERS", "SHOW TABLES"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_explain() {
        assert_readonly_accepts(&[
            "EXPLAIN SELECT 1",
            "EXPLAIN FILTER PUSHDOWN FOR SELECT * FROM mz_tables",
            "EXPLAIN TIMESTAMP FOR SELECT 1",
            "EXPLAIN ANALYZE MEMORY FOR INDEX foo",
            "EXPLAIN ANALYZE MEMORY FOR MATERIALIZED VIEW mv",
            "EXPLAIN ANALYZE CLUSTER MEMORY",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_writes() {
        assert_readonly_rejects(&[
            "INSERT INTO t VALUES (1)",
            "UPDATE t SET a = 1",
            "DELETE FROM t",
            "CREATE TABLE t (a INT)",
            "DROP TABLE t",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_multiple() {
        assert_readonly_rejects(&["SELECT 1; SELECT 2"]);
    }

    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_empty() {
        assert_readonly_rejects(&["", " "]);
    }

    // ---------------------------------------------------------------------
    // validate_system_catalog_query
    // ---------------------------------------------------------------------

    #[mz_ore::test]
    fn test_validate_system_catalog_query_accepts_mz_tables() {
        assert_catalog_accepts(&[
            "SELECT * FROM mz_tables",
            "SELECT * FROM mz_internal.mz_comments",
            "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_subqueries() {
        // Subqueries over system tables only.
        assert_catalog_accepts(&[
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))",
            "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t",
        ]);

        // A user table hidden inside a subquery must still be caught.
        assert_catalog_rejects(&[
            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)",
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_user_tables() {
        assert_catalog_rejects(&[
            "SELECT * FROM user_data",
            "SELECT * FROM my_table",
            // An `mz_` string literal must not count as a system table.
            "SELECT * FROM user_data WHERE 'mz_' IS NOT NULL",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_functions() {
        assert_catalog_accepts(&[
            "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1",
            "SELECT length(name) FROM mz_tables",
            "SELECT count(*) FROM mz_sources WHERE now() > created_at",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_schema_qualified() {
        assert_catalog_accepts(&[
            "SELECT * FROM mz_catalog.mz_tables",
            "SELECT * FROM mz_internal.mz_sessions",
            "SELECT * FROM pg_catalog.pg_type",
            "SELECT * FROM information_schema.tables",
        ]);

        assert_catalog_rejects(&[
            "SELECT * FROM public.user_table",
            "SELECT * FROM myschema.mytable",
        ]);

        assert_catalog_rejects_because(&[(
            "SELECT * FROM mz_unsafe.mz_some_table",
            "mz_unsafe schema should be blocked even though it is a system schema",
        )]);

        assert_catalog_rejects(&[
            "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_adversarial_cases() {
        const MUST_REJECT: &[(&str, &str)] = &[
            (
                "WITH user_cte AS (SELECT * FROM user_data) \
                 SELECT * FROM mz_tables, user_cte",
                "Should reject CTE referencing user table",
            ),
            (
                "WITH \
                 cte1 AS (SELECT * FROM mz_tables), \
                 cte2 AS (SELECT * FROM cte1), \
                 cte3 AS (SELECT * FROM user_data) \
                 SELECT * FROM cte2",
                "Should reject CTE chain with user table",
            ),
            (
                "SELECT * FROM mz_tables t1 \
                 JOIN user_data u ON t1.id = u.id \
                 JOIN mz_sources s ON t1.id = s.id",
                "Should reject multi-join with user table",
            ),
            (
                "SELECT * FROM mz_tables t \
                 LEFT JOIN user_data u ON t.id = u.table_id \
                 WHERE u.id IS NULL",
                "Should reject LEFT JOIN with user table",
            ),
            (
                "SELECT * FROM mz_tables WHERE id IN \
                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)",
                "Should reject nested subquery with user table",
            ),
            (
                "SELECT name FROM mz_tables \
                 UNION \
                 SELECT name FROM user_data",
                "Should reject UNION with user table",
            ),
            (
                "SELECT id FROM mz_sources \
                 UNION ALL \
                 SELECT id FROM products",
                "Should reject UNION ALL with user table",
            ),
            (
                "SELECT * FROM mz_tables CROSS JOIN user_data",
                "Should reject CROSS JOIN with user table",
            ),
            (
                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t",
                "Should reject subquery in SELECT with user table",
            ),
            (
                "SELECT * FROM mz_catalogg.fake_table",
                "Should reject typo-squatting schema name",
            ),
            (
                "SELECT * FROM mz_catalog_hack.fake_table",
                "Should reject fake schema with mz_catalog prefix",
            ),
            (
                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u",
                "Should reject LATERAL join with user table",
            ),
        ];

        const MUST_ACCEPT: &[(&str, &str)] = &[
            (
                "WITH \
                 tables AS (SELECT * FROM mz_tables), \
                 sources AS (SELECT * FROM mz_sources) \
                 SELECT t.name, s.name \
                 FROM tables t \
                 JOIN sources s ON t.id = s.id \
                 WHERE t.id IN (SELECT id FROM mz_columns)",
                "Should allow complex query with only system tables",
            ),
            (
                "SELECT name FROM mz_tables \
                 UNION \
                 SELECT name FROM mz_sources",
                "Should allow UNION of system tables",
            ),
        ];

        assert_catalog_rejects_because(MUST_REJECT);
        assert_catalog_accepts_because(MUST_ACCEPT);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_constant_queries() {
        assert_catalog_rejects_because(&[
            (
                "SELECT 1",
                "Should reject constant SELECT with no table references",
            ),
            (
                "SELECT 1 + 2, 'hello'",
                "Should reject constant expression SELECT",
            ),
            (
                "SELECT now()",
                "Should reject function-only SELECT with no table references",
            ),
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_mixed_tables() {
        assert_catalog_rejects(&[
            "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id",
        ]);
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_show() {
        for stmt in [
            "SHOW TABLES FROM mz_internal",
            "SHOW TABLES FROM mz_catalog",
            "SHOW CLUSTERS",
            "SHOW SOURCES",
            "SHOW TABLES",
        ] {
            assert!(
                validate_system_catalog_query(stmt).is_ok(),
                "{stmt} should be allowed"
            );
        }
    }

    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_explain() {
        assert_catalog_accepts_because(&[
            (
                "EXPLAIN SELECT * FROM mz_tables",
                "EXPLAIN of system table query should be allowed",
            ),
            ("EXPLAIN SELECT 1", "EXPLAIN SELECT 1 should be allowed"),
        ]);
    }

    // ---------------------------------------------------------------------
    // tools/list
    // ---------------------------------------------------------------------

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_disabled() {
        let names = listed_tool_names(McpEndpointType::Agent, false).await;
        assert_agent_core_tools(&names);
        assert!(
            !has_tool(&names, "query"),
            "query tool should be hidden when disabled"
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_enabled() {
        let names = listed_tool_names(McpEndpointType::Agent, true).await;
        assert_agent_core_tools(&names);
        assert!(
            has_tool(&names, "query"),
            "query tool should be present when enabled"
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_developer_query_tool_disabled() {
        let names = listed_tool_names(McpEndpointType::Developer, false).await;
        assert!(
            has_tool(&names, "query_system_catalog"),
            "query_system_catalog should always be present on developer"
        );
        assert!(
            !has_tool(&names, "query"),
            "query tool should be hidden when disabled"
        );
    }

    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_developer_query_tool_enabled() {
        let names = listed_tool_names(McpEndpointType::Developer, true).await;
        assert!(
            has_tool(&names, "query_system_catalog"),
            "query_system_catalog should always be present on developer"
        );
        assert!(
            has_tool(&names, "query"),
            "query tool should be present on developer when enabled"
        );
    }

    // ---------------------------------------------------------------------
    // format_rows_response
    // ---------------------------------------------------------------------

    #[mz_ore::test]
    fn test_format_rows_response_within_limit() {
        let rows = vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
        let text = sole_text_block(format_rows_response(rows, 1_000_000).unwrap());
        for needle in ["\"a\"", "\"b\""] {
            assert!(text.contains(needle));
        }
    }

    #[mz_ore::test]
    fn test_format_rows_response_errors_when_over_limit() {
        let mut rows: Vec<Vec<serde_json::Value>> = Vec::with_capacity(100);
        for i in 0..100 {
            rows.push(vec![json!(format!("row_{}", i)), json!(i)]);
        }

        let msg = format_rows_response(rows, 500).unwrap_err().to_string();
        assert!(
            msg.contains("exceeds the 500 byte limit"),
            "Error should mention the size limit, got: {msg}"
        );
        assert!(
            msg.contains("Use LIMIT or WHERE"),
            "Error should suggest narrowing the query, got: {msg}"
        );
    }

    #[mz_ore::test]
    fn test_format_rows_response_empty_rows() {
        let rows: Vec<Vec<serde_json::Value>> = Vec::new();
        let text = sole_text_block(format_rows_response(rows, 1000).unwrap());
        assert_eq!(text, "[]");
    }

    // ---------------------------------------------------------------------
    // safe_data_product_name
    // ---------------------------------------------------------------------

    #[mz_ore::test]
    fn test_safe_data_product_name_valid() {
        let cases = [
            // Fully qualified, partially qualified, and bare names.
            (
                r#""materialize"."public"."my_view""#,
                r#""materialize"."public"."my_view""#,
            ),
            (r#""public"."my_view""#, r#""public"."my_view""#),
            ("my_view", r#""my_view""#),
        ];
        for (input, expected) in cases {
            assert_eq!(safe_data_product_name(input).unwrap(), expected);
        }
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_empty() {
        for input in ["", " "] {
            assert!(safe_data_product_name(input).is_err());
        }
    }

    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_sql_injection() {
        for input in [
            "my_view; DROP TABLE users",
            "my_view UNION SELECT * FROM secrets",
            "my_view, secrets",
            "my_view WHERE 1=1 --",
        ] {
            assert!(safe_data_product_name(input).is_err());
        }
    }

    // ---------------------------------------------------------------------
    // build_read_query
    // ---------------------------------------------------------------------

    #[mz_ore::test]
    fn test_build_read_query_with_cluster() {
        let sql = build_read_query("\"db\".\"sch\".\"v\"", 50, "prod_cluster");
        for fragment in [
            "BEGIN READ ONLY",
            "SET CLUSTER = 'prod_cluster'",
            "SELECT * FROM \"db\".\"sch\".\"v\" LIMIT 50",
            "COMMIT",
        ] {
            assert!(sql.contains(fragment), "{sql}");
        }
    }

    #[mz_ore::test]
    fn test_build_read_query_escapes_cluster_name() {
        let sql = build_read_query("\"db\".\"sch\".\"v\"", 10, "evil'; DROP TABLE secrets; --");
        assert!(
            sql.contains("SET CLUSTER = 'evil''; DROP TABLE secrets; --'"),
            "single quote should be doubled inside the literal: {sql}",
        );
        // Each keyword must occur exactly once: the injected text stays inside
        // the quoted literal and never becomes a statement of its own.
        for (keyword, expectation) in [
            ("SET CLUSTER", "exactly one SET CLUSTER statement"),
            (
                "DROP TABLE",
                "DROP TABLE should appear once, inside the quoted literal",
            ),
        ] {
            assert_eq!(sql.matches(keyword).count(), 1, "{expectation}: {sql}");
        }
    }

    // ---------------------------------------------------------------------
    // Error codes
    // ---------------------------------------------------------------------

    #[mz_ore::test]
    fn test_mcp_error_codes() {
        let cases = [
            (
                McpRequestError::InvalidJsonRpcVersion,
                error_codes::INVALID_REQUEST,
            ),
            (
                McpRequestError::MethodNotFound("test".to_string()),
                error_codes::METHOD_NOT_FOUND,
            ),
            (
                McpRequestError::QueryExecutionFailed("test".to_string()),
                error_codes::INTERNAL_ERROR,
            ),
        ];
        for (error, expected_code) in cases {
            assert_eq!(error.error_code(), expected_code);
        }
    }
}