1use std::sync::Arc;
27use std::time::Duration;
28
29use anyhow::anyhow;
30use axum::Extension;
31use axum::Json;
32use axum::response::IntoResponse;
33use http::{HeaderMap, HeaderValue, StatusCode};
34use mz_adapter_types::dyncfgs::{
35 ENABLE_MCP_AGENT, ENABLE_MCP_AGENT_QUERY_TOOL, ENABLE_MCP_DEVELOPER, MCP_MAX_RESPONSE_SIZE,
36};
37use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
38use mz_sql::parse::parse;
39use mz_sql::session::metadata::SessionMetadata;
40use mz_sql::session::vars::{APPLICATION_NAME, Var, VarInput};
41use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
42use mz_sql_parser::ast::visit::{self, Visit};
43use mz_sql_parser::ast::{Raw, RawItemName};
44use mz_sql_parser::parser::parse_item_name;
45use serde::{Deserialize, Serialize};
46use serde_json::json;
47use thiserror::Error;
48use tracing::{debug, warn};
49
50use crate::http::AuthedClient;
51use crate::http::mcp_metrics::{McpMetrics, ToolCallGuard};
52use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
53
/// JSON-RPC protocol version. Incoming requests are checked against it and
/// every response is stamped with it.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol version reported in the `initialize` result.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Upper bound on how long a single MCP request may run before a timeout
/// error is returned to the caller.
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Query executed by the `get_data_products` tool.
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";

/// Leading part of the `get_data_product_details` query. The caller appends
/// the requested name as an escaped SQL string literal.
const DETAILS_QUERY_PREFIX: &str = concat!(
    "SELECT * FROM mz_internal.mz_mcp_data_product_details",
    " WHERE object_name = ",
);
73
74#[derive(Debug, Error)]
76enum McpRequestError {
77 #[error("Invalid JSON-RPC version: expected 2.0")]
78 InvalidJsonRpcVersion,
79 #[error("Method not found: {0}")]
80 #[allow(dead_code)] MethodNotFound(String),
82 #[error("Tool not found: {0}")]
83 ToolNotFound(String),
84 #[error("Data product not found: {0}")]
85 DataProductNotFound(String),
86 #[error("{0}")]
87 ClusterPrivilegeMissing(String),
88 #[error("Query validation failed: {0}")]
89 QueryValidationFailed(String),
90 #[error("Query execution failed: {0}")]
91 QueryExecutionFailed(String),
92 #[error("Internal error: {0}")]
93 Internal(#[from] anyhow::Error),
94}
95
96impl McpRequestError {
97 fn error_code(&self) -> i32 {
98 match self {
99 Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
100 Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
101 Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
102 Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
103 Self::ClusterPrivilegeMissing(_) => error_codes::INVALID_PARAMS,
104 Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
105 Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
106 }
107 }
108
109 fn error_type(&self) -> &'static str {
110 match self {
111 Self::InvalidJsonRpcVersion => "InvalidRequest",
112 Self::MethodNotFound(_) => "MethodNotFound",
113 Self::ToolNotFound(_) => "ToolNotFound",
114 Self::DataProductNotFound(_) => "DataProductNotFound",
115 Self::ClusterPrivilegeMissing(_) => "ClusterPrivilegeMissing",
116 Self::QueryValidationFailed(_) => "ValidationError",
117 Self::QueryExecutionFailed(_) => "ExecutionError",
118 Self::Internal(_) => "InternalError",
119 }
120 }
121}
122
123#[derive(Debug, Deserialize)]
125pub(crate) struct McpRequest {
126 jsonrpc: String,
127 id: Option<serde_json::Value>,
128 #[serde(flatten)]
129 method: McpMethod,
130}
131
132#[derive(Debug, Deserialize)]
134#[serde(tag = "method", content = "params")]
135enum McpMethod {
136 #[serde(rename = "initialize")]
138 Initialize(#[allow(dead_code)] InitializeParams),
139 #[serde(rename = "tools/list")]
140 ToolsList,
141 #[serde(rename = "tools/call")]
142 ToolsCall(ToolsCallParams),
143 #[serde(other)]
145 Unknown,
146}
147
148impl std::fmt::Display for McpMethod {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 match self {
151 McpMethod::Initialize(_) => write!(f, "initialize"),
152 McpMethod::ToolsList => write!(f, "tools/list"),
153 McpMethod::ToolsCall(_) => write!(f, "tools/call"),
154 McpMethod::Unknown => write!(f, "unknown"),
155 }
156 }
157}
158
159#[derive(Debug, Deserialize)]
160struct InitializeParams {
161 #[serde(rename = "protocolVersion")]
163 #[allow(dead_code)]
164 protocol_version: String,
165 #[serde(default)]
167 #[allow(dead_code)]
168 capabilities: serde_json::Value,
169 #[serde(rename = "clientInfo")]
171 #[allow(dead_code)]
172 client_info: Option<ClientInfo>,
173}
174
175#[derive(Debug, Deserialize)]
176struct ClientInfo {
177 #[allow(dead_code)]
178 name: String,
179 #[allow(dead_code)]
180 version: String,
181}
182
183#[derive(Debug, Deserialize)]
186#[serde(tag = "name", content = "arguments")]
187#[serde(rename_all = "snake_case")]
188enum ToolsCallParams {
189 GetDataProducts(#[serde(default)] ()),
192 GetDataProductDetails(GetDataProductDetailsParams),
193 ReadDataProduct(ReadDataProductParams),
194 Query(QueryParams),
195 QuerySystemCatalog(QuerySystemCatalogParams),
197}
198
199impl std::fmt::Display for ToolsCallParams {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 match self {
202 ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
203 ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
204 ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
205 ToolsCallParams::Query(_) => write!(f, "query"),
206 ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
207 }
208 }
209}
210
211#[derive(Debug, Deserialize)]
212struct GetDataProductDetailsParams {
213 name: String,
214}
215
216#[derive(Debug, Deserialize)]
217struct ReadDataProductParams {
218 name: String,
219 #[serde(default = "default_read_limit")]
220 limit: u32,
221 cluster: Option<String>,
222}
223
/// Row limit applied to `read_data_product` when the caller gives none.
fn default_read_limit() -> u32 {
    const DEFAULT_ROWS: u32 = 500;
    DEFAULT_ROWS
}
227
228#[derive(Debug, Deserialize)]
229struct QueryParams {
230 cluster: String,
231 sql_query: String,
232}
233
234#[derive(Debug, Deserialize)]
235struct QuerySystemCatalogParams {
236 sql_query: String,
237}
238
239#[derive(Debug, Serialize)]
240struct McpResponse {
241 jsonrpc: String,
242 id: serde_json::Value,
243 #[serde(skip_serializing_if = "Option::is_none")]
244 result: Option<McpResult>,
245 #[serde(skip_serializing_if = "Option::is_none")]
246 error: Option<McpError>,
247}
248
249#[derive(Debug, Serialize)]
251#[serde(untagged)]
252enum McpResult {
253 Initialize(InitializeResult),
254 ToolsList(ToolsListResult),
255 ToolContent(ToolContentResult),
256}
257
258#[derive(Debug, Serialize)]
259struct InitializeResult {
260 #[serde(rename = "protocolVersion")]
261 protocol_version: String,
262 capabilities: Capabilities,
263 #[serde(rename = "serverInfo")]
264 server_info: ServerInfo,
265 #[serde(skip_serializing_if = "Option::is_none")]
266 instructions: Option<String>,
267}
268
269#[derive(Debug, Serialize)]
270struct Capabilities {
271 tools: serde_json::Value,
272}
273
274#[derive(Debug, Serialize)]
275struct ServerInfo {
276 name: String,
277 version: String,
278}
279
280#[derive(Debug, Serialize)]
281struct ToolsListResult {
282 tools: Vec<ToolDefinition>,
283}
284
285#[derive(Debug, Serialize)]
286struct ToolDefinition {
287 name: String,
288 #[serde(skip_serializing_if = "Option::is_none")]
289 title: Option<String>,
290 description: String,
291 #[serde(rename = "inputSchema")]
292 input_schema: serde_json::Value,
293 #[serde(skip_serializing_if = "Option::is_none")]
294 annotations: Option<ToolAnnotations>,
295}
296
297#[derive(Debug, Serialize)]
300struct ToolAnnotations {
301 #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
302 read_only_hint: Option<bool>,
303 #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
304 destructive_hint: Option<bool>,
305 #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
306 idempotent_hint: Option<bool>,
307 #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
308 open_world_hint: Option<bool>,
309}
310
311const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
313 read_only_hint: Some(true),
314 destructive_hint: Some(false),
315 idempotent_hint: Some(true),
316 open_world_hint: Some(false),
317};
318
319#[derive(Debug, Serialize)]
320struct ToolContentResult {
321 content: Vec<ContentBlock>,
322 #[serde(rename = "isError")]
323 is_error: bool,
324}
325
326#[derive(Debug, Serialize)]
327struct ContentBlock {
328 #[serde(rename = "type")]
329 content_type: String,
330 text: String,
331}
332
/// Numeric JSON-RPC error codes used in error responses.
mod error_codes {
    /// The request is not a valid request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// An internal error occurred.
    pub const INTERNAL_ERROR: i32 = -32603;
}
340
341#[derive(Debug, Serialize)]
342struct McpError {
343 code: i32,
344 message: String,
345 #[serde(skip_serializing_if = "Option::is_none")]
346 data: Option<serde_json::Value>,
347}
348
349impl From<McpRequestError> for McpError {
350 fn from(err: McpRequestError) -> Self {
351 McpError {
352 code: err.error_code(),
353 message: err.to_string(),
354 data: Some(json!({
355 "error_type": err.error_type(),
356 })),
357 }
358 }
359}
360
/// Which of the two MCP endpoints a request arrived on.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// The agent endpoint.
    Agent,
    /// The developer endpoint.
    Developer,
}

impl McpEndpointType {
    /// Returns the lowercase label for this endpoint, as used for metrics
    /// labels and by the `Display` implementation.
    fn as_label(self) -> &'static str {
        match self {
            Self::Agent => "agent",
            Self::Developer => "developer",
        }
    }
}

/// Displays the endpoint as its label.
impl std::fmt::Display for McpEndpointType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = self.as_label();
        f.write_str(label)
    }
}
383
384pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
387 StatusCode::METHOD_NOT_ALLOWED
388}
389
390pub async fn handle_mcp_agent(
392 headers: HeaderMap,
393 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
394 Extension(metrics): Extension<McpMetrics>,
395 client: AuthedClient,
396 Json(body): Json<McpRequest>,
397) -> axum::response::Response {
398 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
399 return resp;
400 }
401 handle_mcp_request(client, body, McpEndpointType::Agent, metrics)
402 .await
403 .into_response()
404}
405
406pub async fn handle_mcp_developer(
408 headers: HeaderMap,
409 Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
410 Extension(metrics): Extension<McpMetrics>,
411 client: AuthedClient,
412 Json(body): Json<McpRequest>,
413) -> axum::response::Response {
414 if let Some(resp) = validate_origin(&headers, &allowed_origins) {
415 return resp;
416 }
417 handle_mcp_request(client, body, McpEndpointType::Developer, metrics)
418 .await
419 .into_response()
420}
421
422fn validate_origin(
431 headers: &HeaderMap,
432 allowed: &[HeaderValue],
433) -> Option<axum::response::Response> {
434 let origin = headers.get(http::header::ORIGIN)?;
435 if mz_http_util::origin_is_allowed(origin, allowed) {
436 return None;
437 }
438 warn!(
439 origin = ?origin,
440 "MCP request rejected: origin not in allowlist",
441 );
442 Some(StatusCode::FORBIDDEN.into_response())
443}
444
445async fn handle_mcp_request(
446 mut client: AuthedClient,
447 request: McpRequest,
448 endpoint_type: McpEndpointType,
449 metrics: McpMetrics,
450) -> impl IntoResponse {
451 let endpoint_label = endpoint_type.as_label();
452 let method_label = request.method.to_string();
453 let record_request = |status: &str| {
454 metrics
455 .requests
456 .with_label_values(&[endpoint_label, &method_label, status])
457 .inc();
458 };
459
460 let catalog = client.client.catalog_snapshot("mcp").await;
462 let dyncfgs = catalog.system_config().dyncfgs();
463 let enabled = match endpoint_type {
464 McpEndpointType::Agent => ENABLE_MCP_AGENT.get(dyncfgs),
465 McpEndpointType::Developer => ENABLE_MCP_DEVELOPER.get(dyncfgs),
466 };
467 if !enabled {
468 debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
469 record_request("endpoint_disabled");
470 return StatusCode::SERVICE_UNAVAILABLE.into_response();
471 }
472
473 let query_tool_enabled = ENABLE_MCP_AGENT_QUERY_TOOL.get(dyncfgs);
474 let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
475
476 let app_name = match endpoint_type {
480 McpEndpointType::Agent => "mz_mcp_agents",
481 McpEndpointType::Developer => "mz_mcp_developer",
482 };
483 client
484 .client
485 .session()
486 .vars_mut()
487 .set_default(APPLICATION_NAME.name(), VarInput::Flat(app_name))
488 .expect("application_name is a known session var");
489
490 let user = client.client.session().user().name.clone();
491 let is_notification = request.id.is_none();
492
493 debug!(
494 method = %request.method,
495 endpoint = %endpoint_type,
496 user = %user,
497 is_notification = is_notification,
498 "MCP request received"
499 );
500
501 if is_notification {
503 debug!(method = %request.method, "Received notification (no response will be sent)");
504 record_request("ok");
505 return StatusCode::OK.into_response();
506 }
507
508 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
509
510 let metrics_inner = metrics.clone();
515 let result = tokio::time::timeout(
516 MCP_REQUEST_TIMEOUT,
517 mz_ore::task::spawn(|| "mcp_request", async move {
518 handle_mcp_request_inner(
519 &mut client,
520 request,
521 endpoint_type,
522 query_tool_enabled,
523 max_response_size,
524 metrics_inner,
525 )
526 .await
527 })
528 .abort_on_drop(),
529 )
530 .await;
531
532 let (response, status_label): (McpResponse, &'static str) = match result {
533 Ok(inner) => inner,
534 Err(_elapsed) => {
535 warn!(
536 endpoint = %endpoint_type,
537 timeout = ?MCP_REQUEST_TIMEOUT,
538 "MCP request timed out",
539 );
540 let response = McpResponse {
541 jsonrpc: JSONRPC_VERSION.to_string(),
542 id: request_id,
543 result: None,
544 error: Some(
545 McpRequestError::QueryExecutionFailed(format!(
546 "Request timed out after {} seconds.",
547 MCP_REQUEST_TIMEOUT.as_secs(),
548 ))
549 .into(),
550 ),
551 };
552 (response, "timeout")
553 }
554 };
555
556 record_request(status_label);
557 (StatusCode::OK, Json(response)).into_response()
558}
559
560async fn handle_mcp_request_inner(
561 client: &mut AuthedClient,
562 request: McpRequest,
563 endpoint_type: McpEndpointType,
564 query_tool_enabled: bool,
565 max_response_size: usize,
566 metrics: McpMetrics,
567) -> (McpResponse, &'static str) {
568 let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
570
571 let result = handle_mcp_method(
572 client,
573 &request,
574 endpoint_type,
575 query_tool_enabled,
576 max_response_size,
577 &metrics,
578 )
579 .await;
580
581 let status_label = match &result {
582 Ok(_) => "ok",
583 Err(e) => e.error_type(),
584 };
585
586 let response = match result {
587 Ok(result_value) => McpResponse {
588 jsonrpc: JSONRPC_VERSION.to_string(),
589 id: request_id,
590 result: Some(result_value),
591 error: None,
592 },
593 Err(e) => {
594 if !matches!(
596 e,
597 McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
598 ) {
599 warn!(error = %e, method = %request.method, "MCP method execution failed");
600 }
601 McpResponse {
602 jsonrpc: JSONRPC_VERSION.to_string(),
603 id: request_id,
604 result: None,
605 error: Some(e.into()),
606 }
607 }
608 };
609
610 (response, status_label)
611}
612
613async fn handle_mcp_method(
614 client: &mut AuthedClient,
615 request: &McpRequest,
616 endpoint_type: McpEndpointType,
617 query_tool_enabled: bool,
618 max_response_size: usize,
619 metrics: &McpMetrics,
620) -> Result<McpResult, McpRequestError> {
621 if request.jsonrpc != JSONRPC_VERSION {
623 return Err(McpRequestError::InvalidJsonRpcVersion);
624 }
625
626 match &request.method {
628 McpMethod::Initialize(_) => {
629 debug!(endpoint = %endpoint_type, "Processing initialize");
630 handle_initialize(endpoint_type).await
631 }
632 McpMethod::ToolsList => {
633 debug!(endpoint = %endpoint_type, "Processing tools/list");
634 handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
635 }
636 McpMethod::ToolsCall(params) => {
637 debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
638 handle_tools_call(
639 client,
640 params,
641 endpoint_type,
642 query_tool_enabled,
643 max_response_size,
644 metrics,
645 )
646 .await
647 }
648 McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
649 "unknown method".to_string(),
650 )),
651 }
652}
653
654fn endpoint_instructions(endpoint_type: McpEndpointType) -> Option<String> {
657 match endpoint_type {
658 McpEndpointType::Agent => Some(concat!(
659 "You have access to Materialize data products via MCP. ",
660 "Prefer indexed objects (served from memory) over unindexed materialized views ",
661 "(read from persistent storage). `read_data_product` automatically routes the ",
662 "read to the cluster recorded in the data product catalog so indexes are used; ",
663 "you only need to set the `cluster` parameter if you intentionally want the ",
664 "read to run on a different cluster (e.g. one with larger or more replicas). ",
665 "`get_data_product_details` returns a `hydration` object with `hydrated`, ",
666 "`replica_count`, and `hydrated_replica_count` fields. Reads never return ",
667 "partial data: a read against a not-yet-hydrated product blocks until the ",
668 "dataflow catches up, and may hit the request timeout. Check `hydrated` ",
669 "before reading: if it is false and `replica_count` is greater than 0, the ",
670 "dataflow is still warming up, so wait and retry; if `replica_count` is 0 the ",
671 "cluster has no replicas and the read cannot make progress until one is added.",
672 ).to_string()),
673 McpEndpointType::Developer => Some(concat!(
674 "You are connected to the Materialize developer MCP server. ",
675 "You have read-only access to system catalog tables (mz_*, pg_catalog, information_schema) ",
676 "for troubleshooting and observability.\n\n",
677 "IMPORTANT: Before writing queries, discover table schemas using the mz_ontology tables:\n",
678 "- mz_internal.mz_ontology_entity_types: what catalog entities exist and which tables they map to\n",
679 "- mz_internal.mz_ontology_link_types: relationships between entities (foreign keys, metrics, etc.)\n",
680 "- mz_internal.mz_ontology_properties: column names, types, and descriptions for each entity\n",
681 "- mz_internal.mz_ontology_semantic_types: typed ID domains (CatalogItemId, ReplicaId, etc.)\n\n",
682 "Use these to find the correct tables, join paths, and column names instead of guessing.\n\n",
683 "Key rules:\n",
684 "- mz_source_statuses and mz_sink_statuses use `last_status_change_at` (NOT `updated_at`)\n",
685 "- mz_cluster_replica_utilization only has `replica_id` — JOIN with mz_cluster_replicas and mz_clusters to get names\n",
686 "- Do NOT query mz_introspection.mz_dataflow_arrangement_sizes — it is cluster-scoped and has uint8/text type mismatches\n",
687 "- Use SHOW COLUMNS FROM <table> to verify column names if unsure",
688 ).to_string()),
689 }
690}
691
692async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
693 Ok(McpResult::Initialize(InitializeResult {
694 protocol_version: MCP_PROTOCOL_VERSION.to_string(),
695 capabilities: Capabilities { tools: json!({}) },
696 server_info: ServerInfo {
697 name: format!("materialize-mcp-{}", endpoint_type),
698 version: env!("CARGO_PKG_VERSION").to_string(),
699 },
700 instructions: endpoint_instructions(endpoint_type),
701 }))
702}
703
704async fn handle_tools_list(
705 endpoint_type: McpEndpointType,
706 query_tool_enabled: bool,
707 max_response_size: usize,
708) -> Result<McpResult, McpRequestError> {
709 let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
710
711 let tools = match endpoint_type {
712 McpEndpointType::Agent => {
713 let mut tools = vec![
714 ToolDefinition {
715 name: "get_data_products".to_string(),
716 title: Some("List Data Products".to_string()),
717 description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
718 input_schema: json!({
719 "type": "object",
720 "properties": {},
721 "required": []
722 }),
723 annotations: Some(READ_ONLY_ANNOTATIONS),
724 },
725 ToolDefinition {
726 name: "get_data_product_details".to_string(),
727 title: Some("Get Data Product Details".to_string()),
728 description: "Get the complete schema and structure of a specific data product, plus a `hydration` object reporting whether the dataflow is ready across the cluster's replicas (`{hydrated, replica_count, hydrated_replica_count}`). This shows you exactly what fields are available, their types, and what data you can query. Reads never return partial data, so check `hydration` before reading: if `hydrated` is false and `replica_count` is greater than 0 the dataflow is still warming up (a read would block until it catches up, possibly hitting the request timeout), so wait and retry; if `replica_count` is 0 the cluster has no replicas and the read cannot make progress until one is added.".to_string(),
729 input_schema: json!({
730 "type": "object",
731 "properties": {
732 "name": {
733 "type": "string",
734 "description": "Exact name of the data product from get_data_products() list"
735 }
736 },
737 "required": ["name"]
738 }),
739 annotations: Some(READ_ONLY_ANNOTATIONS),
740 },
741 ToolDefinition {
742 name: "read_data_product".to_string(),
743 title: Some("Read Data Product".to_string()),
744 description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
745 input_schema: json!({
746 "type": "object",
747 "properties": {
748 "name": {
749 "type": "string",
750 "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
751 },
752 "limit": {
753 "type": "integer",
754 "description": "Maximum number of rows to return (default 500)",
755 "default": 500
756 },
757 "cluster": {
758 "type": "string",
759 "description": "Optional override. By default, the read runs on the cluster recorded in the data product catalog (where the index or materialized view dataflow lives), so indexed reads actually hit their arrangement. Set this only to intentionally run the same read on a different cluster — e.g. one with more or larger replicas, or to compare cost/latency."
760 }
761 },
762 "required": ["name"]
763 }),
764 annotations: Some(READ_ONLY_ANNOTATIONS),
765 },
766 ];
767 if query_tool_enabled {
768 tools.push(ToolDefinition {
769 name: "query".to_string(),
770 title: Some("Query Data Products".to_string()),
771 description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
772 input_schema: json!({
773 "type": "object",
774 "properties": {
775 "cluster": {
776 "type": "string",
777 "description": "Exact cluster name from the data product details - required for query execution"
778 },
779 "sql_query": {
780 "type": "string",
781 "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
782 }
783 },
784 "required": ["cluster", "sql_query"]
785 }),
786 annotations: Some(READ_ONLY_ANNOTATIONS),
787 });
788 }
789 tools
790 }
791 McpEndpointType::Developer => {
792 vec![ToolDefinition {
793 name: "query_system_catalog".to_string(),
794 title: Some("Query System Catalog".to_string()),
795 description: concat!(
796 "Query Materialize system catalog tables for troubleshooting and observability. ",
797 "Only mz_*, pg_catalog, and information_schema tables are accessible. ",
798 "Use the mz_internal.mz_ontology_* tables to discover tables, columns, and join paths before writing queries.",
799 ).to_owned() + &format!(" {size_hint}"),
800 input_schema: json!({
801 "type": "object",
802 "properties": {
803 "sql_query": {
804 "type": "string",
805 "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
806 }
807 },
808 "required": ["sql_query"]
809 }),
810 annotations: Some(READ_ONLY_ANNOTATIONS),
811 }]
812 }
813 };
814
815 Ok(McpResult::ToolsList(ToolsListResult { tools }))
816}
817
818async fn handle_tools_call(
819 client: &mut AuthedClient,
820 params: &ToolsCallParams,
821 endpoint_type: McpEndpointType,
822 query_tool_enabled: bool,
823 max_response_size: usize,
824 metrics: &McpMetrics,
825) -> Result<McpResult, McpRequestError> {
826 let mut guard = ToolCallGuard::new(metrics, endpoint_type.as_label(), params.to_string());
828
829 let result = match (endpoint_type, params) {
830 (McpEndpointType::Agent, ToolsCallParams::GetDataProducts(_)) => {
831 get_data_products(client, max_response_size).await
832 }
833 (McpEndpointType::Agent, ToolsCallParams::GetDataProductDetails(p)) => {
834 get_data_product_details(client, &p.name, max_response_size).await
835 }
836 (McpEndpointType::Agent, ToolsCallParams::ReadDataProduct(p)) => {
837 read_data_product(
838 client,
839 &p.name,
840 p.limit,
841 p.cluster.as_deref(),
842 max_response_size,
843 )
844 .await
845 }
846 (McpEndpointType::Agent, ToolsCallParams::Query(_)) if !query_tool_enabled => {
847 Err(McpRequestError::ToolNotFound(
848 "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
849 ))
850 }
851 (McpEndpointType::Agent, ToolsCallParams::Query(p)) => {
852 execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
853 }
854 (McpEndpointType::Developer, ToolsCallParams::QuerySystemCatalog(p)) => {
855 query_system_catalog(client, &p.sql_query, max_response_size).await
856 }
857 (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
859 "{} is not available on {} endpoint",
860 tool, endpoint
861 ))),
862 };
863
864 guard.set_status(match &result {
865 Ok(_) => "ok",
866 Err(e) => e.error_type(),
867 });
868
869 result
870}
871
872async fn execute_sql(
874 client: &mut AuthedClient,
875 query: &str,
876) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
877 let mut response = SqlResponse::new();
878
879 execute_request(
880 client,
881 SqlRequest::Simple {
882 query: query.to_string(),
883 },
884 &mut response,
885 )
886 .await
887 .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
888
889 for result in response.results {
892 match result {
893 SqlResult::Rows { rows, .. } => return Ok(rows),
894 SqlResult::Err { error, .. } => {
895 return Err(McpRequestError::QueryExecutionFailed(error.message));
896 }
897 SqlResult::Ok { .. } => continue,
898 }
899 }
900
901 Err(McpRequestError::QueryExecutionFailed(
902 "Query did not return any results".to_string(),
903 ))
904}
905
906fn format_rows_response(
913 rows: Vec<Vec<serde_json::Value>>,
914 max_size: usize,
915) -> Result<McpResult, McpRequestError> {
916 let text =
917 serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
918
919 if text.len() > max_size {
920 return Err(McpRequestError::QueryExecutionFailed(format!(
921 "Response size ({} bytes) exceeds the {} byte limit. \
922 Use LIMIT or WHERE to narrow your query.",
923 text.len(),
924 max_size,
925 )));
926 }
927
928 Ok(McpResult::ToolContent(ToolContentResult {
929 content: vec![ContentBlock {
930 content_type: "text".to_string(),
931 text,
932 }],
933 is_error: false,
934 }))
935}
936
937async fn get_data_products(
938 client: &mut AuthedClient,
939 max_response_size: usize,
940) -> Result<McpResult, McpRequestError> {
941 debug!("Executing get_data_products");
942 let rows = execute_sql(client, DISCOVERY_QUERY).await?;
943 debug!("get_data_products returned {} rows", rows.len());
944
945 format_rows_response(rows, max_response_size)
946}
947
948async fn get_data_product_details(
949 client: &mut AuthedClient,
950 name: &str,
951 max_response_size: usize,
952) -> Result<McpResult, McpRequestError> {
953 debug!(name = %name, "Executing get_data_product_details");
954
955 let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
956
957 let rows = execute_sql(client, &query).await?;
958
959 if rows.is_empty() {
960 return Err(McpRequestError::DataProductNotFound(name.to_string()));
961 }
962
963 format_rows_response(rows, max_response_size)
964}
965
966fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
973 let name = name.trim();
974 if name.is_empty() {
975 return Err(McpRequestError::QueryValidationFailed(
976 "Data product name cannot be empty".to_string(),
977 ));
978 }
979
980 let parsed = parse_item_name(name).map_err(|_| {
981 McpRequestError::QueryValidationFailed(format!(
982 "Invalid data product name: {}. Expected a valid object name, \
983 e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
984 name
985 ))
986 })?;
987
988 Ok(parsed.to_ast_string_stable())
991}
992
993async fn read_data_product(
1012 client: &mut AuthedClient,
1013 name: &str,
1014 limit: u32,
1015 cluster_override: Option<&str>,
1016 max_response_size: usize,
1017) -> Result<McpResult, McpRequestError> {
1018 debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
1019
1020 let safe_name = safe_data_product_name(name)?;
1022
1023 let lookup_query = format!(
1028 "SELECT \
1029 cluster, \
1030 cluster IS NULL OR has_cluster_privilege(cluster, 'USAGE') \
1031 AS has_cluster_usage \
1032 FROM mz_internal.mz_mcp_data_products \
1033 WHERE object_name = {} \
1034 ORDER BY \
1035 (cluster IS NOT NULL \
1036 AND has_cluster_privilege(cluster, 'USAGE')) DESC, \
1037 cluster NULLS LAST \
1038 LIMIT 1",
1039 escaped_string_literal(name)
1040 );
1041 let lookup_rows = execute_sql(client, &lookup_query).await?;
1042 if lookup_rows.is_empty() {
1043 return Err(McpRequestError::DataProductNotFound(name.to_string()));
1044 }
1045 let lookup_row = lookup_rows.first();
1046 let catalog_cluster: Option<&str> = lookup_row
1047 .and_then(|row| row.first())
1048 .and_then(|v| v.as_str());
1049 let has_cluster_usage: bool = lookup_row
1052 .and_then(|row| row.get(1))
1053 .and_then(|v| v.as_bool())
1054 .unwrap_or(false);
1055
1056 let target_cluster = match cluster_override {
1062 Some(c) => c,
1063 None => match catalog_cluster {
1064 Some(c) if has_cluster_usage => c,
1065 Some(c) => {
1066 return Err(McpRequestError::ClusterPrivilegeMissing(format!(
1067 "Data product {name} is hosted on cluster {c:?}, which your role \
1068 does not have USAGE on. Pass `cluster: \"<a-cluster-you-have-USAGE-on>\"` \
1069 to read it from a different cluster (slower, no index), or have USAGE \
1070 granted on {c:?}.",
1071 )));
1072 }
1073 None => {
1074 return Err(McpRequestError::Internal(anyhow!(
1077 "data product {name} has no cluster in the catalog"
1078 )));
1079 }
1080 },
1081 };
1082
1083 let read_query = build_read_query(&safe_name, limit, target_cluster);
1088
1089 let rows = execute_sql(client, &read_query).await?;
1090
1091 format_rows_response(rows, max_response_size)
1092}
1093
1094fn build_read_query(safe_name: &str, limit: u32, target_cluster: &str) -> String {
1102 format!(
1103 "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}\n; COMMIT;",
1104 escaped_string_literal(target_cluster),
1105 safe_name,
1106 limit,
1107 )
1108}
1109
1110fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
1112 let sql = sql.trim();
1113 if sql.is_empty() {
1114 return Err(McpRequestError::QueryValidationFailed(
1115 "Empty query".to_string(),
1116 ));
1117 }
1118
1119 let stmts = parse(sql).map_err(|e| {
1121 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1122 })?;
1123
1124 if stmts.len() != 1 {
1126 return Err(McpRequestError::QueryValidationFailed(format!(
1127 "Only one query allowed at a time. Found {} statements.",
1128 stmts.len()
1129 )));
1130 }
1131
1132 let stmt = &stmts[0];
1134 use mz_sql_parser::ast::Statement;
1135
1136 match &stmt.ast {
1137 Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
1138 Ok(())
1140 }
1141 _ => Err(McpRequestError::QueryValidationFailed(
1142 "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
1143 )),
1144 }
1145}
1146
1147async fn execute_query(
1148 client: &mut AuthedClient,
1149 cluster: &str,
1150 sql_query: &str,
1151 max_response_size: usize,
1152) -> Result<McpResult, McpRequestError> {
1153 debug!(cluster = %cluster, "Executing user query");
1154
1155 validate_readonly_query(sql_query)?;
1156
1157 let combined_query = format!(
1160 "BEGIN READ ONLY; SET CLUSTER = {}; {}\n; COMMIT;",
1161 escaped_string_literal(cluster),
1162 sql_query
1163 );
1164
1165 let rows = execute_sql(client, &combined_query).await?;
1166
1167 format_rows_response(rows, max_response_size)
1168}
1169
1170async fn query_system_catalog(
1171 client: &mut AuthedClient,
1172 sql_query: &str,
1173 max_response_size: usize,
1174) -> Result<McpResult, McpRequestError> {
1175 debug!("Executing query_system_catalog");
1176
1177 validate_readonly_query(sql_query)?;
1179
1180 validate_system_catalog_query(sql_query)?;
1182
1183 let combined_query = format!(
1189 "BEGIN READ ONLY; SET search_path = mz_catalog, mz_internal, pg_catalog, information_schema; {}; COMMIT;",
1190 sql_query
1191 );
1192
1193 let rows = execute_sql(client, &combined_query).await?;
1194
1195 format_rows_response(rows, max_response_size)
1196}
1197
/// AST visitor state that records which tables a statement refers to.
struct TableReferenceCollector {
    /// Collected table references as `(schema, table)` pairs; `schema` is
    /// `None` when the reference was not schema-qualified.
    tables: Vec<(Option<String>, String)>,
    /// Names of common table expressions seen so far, used to tell CTE
    /// references apart from real table references.
    cte_names: std::collections::BTreeSet<String>,
}

impl TableReferenceCollector {
    /// Creates a collector with no recorded tables and no known CTE names.
    fn new() -> Self {
        let tables = Vec::new();
        let cte_names = std::collections::BTreeSet::new();
        Self { tables, cte_names }
    }
}
1214
1215impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1216 fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1217 self.cte_names
1219 .insert(cte.alias.name.as_str().to_lowercase());
1220 visit::visit_cte(self, cte);
1221 }
1222
1223 fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1224 if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1226 match name {
1227 RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1228 let parts = &n.0;
1229 if !parts.is_empty() {
1230 let table_name = parts.last().unwrap().as_str().to_lowercase();
1231
1232 if self.cte_names.contains(&table_name) {
1234 visit::visit_table_factor(self, table_factor);
1235 return;
1236 }
1237
1238 let schema = if parts.len() >= 2 {
1240 Some(parts[parts.len() - 2].as_str().to_lowercase())
1241 } else {
1242 None
1243 };
1244 self.tables.push((schema, table_name));
1245 }
1246 }
1247 }
1248 }
1249 visit::visit_table_factor(self, table_factor);
1250 }
1251}
1252
1253fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1261 let stmts = parse(sql).map_err(|e| {
1263 McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1264 })?;
1265
1266 if stmts.is_empty() {
1267 return Err(McpRequestError::QueryValidationFailed(
1268 "Empty query".to_string(),
1269 ));
1270 }
1271
1272 let mut collector = TableReferenceCollector::new();
1274 for stmt in &stmts {
1275 collector.visit_statement(&stmt.ast);
1276 }
1277
1278 let is_allowed_schema =
1281 |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1282
1283 let is_system_table = |(schema, table_name): &(Option<String>, String)| match schema {
1288 Some(s) => is_allowed_schema(s.as_str()),
1289 None => table_name.starts_with("mz_"),
1290 };
1291
1292 let non_system_tables: Vec<String> = collector
1294 .tables
1295 .iter()
1296 .filter(|t| !is_system_table(t))
1297 .map(|(schema, table)| match schema {
1298 Some(s) => format!("{}.{}", s, table),
1299 None => table.clone(),
1300 })
1301 .collect();
1302
1303 if !non_system_tables.is_empty() {
1304 return Err(McpRequestError::QueryValidationFailed(format!(
1305 "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1306 non_system_tables.join(", ")
1307 )));
1308 }
1309
1310 use mz_sql_parser::ast::Statement;
1313 let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1314
1315 if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1316 return Err(McpRequestError::QueryValidationFailed(
1317 "Query must reference at least one system catalog table".to_string(),
1318 ));
1319 }
1320
1321 Ok(())
1322}
1323
/// Unit tests for query validation, tool listing, response formatting,
/// data product name sanitisation, read-query construction and error codes.
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain SELECT statements, including ones with surrounding whitespace,
    /// pass read-only validation.
    #[mz_ore::test]
    fn test_validate_readonly_query_select() {
        assert!(validate_readonly_query("SELECT * FROM mz_tables").is_ok());
        assert!(validate_readonly_query("SELECT 1 + 2").is_ok());
        assert!(validate_readonly_query(" SELECT 1 ").is_ok());
    }

    /// SELECT statements containing subqueries still count as a single
    /// read-only statement.
    #[mz_ore::test]
    fn test_validate_readonly_query_subqueries() {
        // Subquery in an IN clause.
        assert!(
            validate_readonly_query(
                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
            )
            .is_ok()
        );

        // Derived table in FROM.
        assert!(
            validate_readonly_query(
                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
            )
            .is_ok()
        );

        // Correlated EXISTS subquery.
        assert!(validate_readonly_query(
            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)"
        )
        .is_ok());

        // Subquery nested inside another subquery.
        assert!(validate_readonly_query(
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))"
        )
        .is_ok());

        // Scalar subquery used in a comparison.
        assert!(
            validate_readonly_query(
                "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)"
            )
            .is_ok()
        );
    }

    /// SHOW statements pass read-only validation.
    #[mz_ore::test]
    fn test_validate_readonly_query_show() {
        assert!(validate_readonly_query("SHOW CLUSTERS").is_ok());
        assert!(validate_readonly_query("SHOW TABLES").is_ok());
    }

    /// EXPLAIN statements pass read-only validation.
    #[mz_ore::test]
    fn test_validate_readonly_query_explain() {
        assert!(validate_readonly_query("EXPLAIN SELECT 1").is_ok());
    }

    /// DML and DDL statements are rejected by read-only validation.
    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_writes() {
        assert!(validate_readonly_query("INSERT INTO t VALUES (1)").is_err());
        assert!(validate_readonly_query("UPDATE t SET a = 1").is_err());
        assert!(validate_readonly_query("DELETE FROM t").is_err());
        assert!(validate_readonly_query("CREATE TABLE t (a INT)").is_err());
        assert!(validate_readonly_query("DROP TABLE t").is_err());
    }

    /// More than one statement in a single request is rejected.
    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_multiple() {
        assert!(validate_readonly_query("SELECT 1; SELECT 2").is_err());
    }

    /// Empty and whitespace-only input is rejected.
    #[mz_ore::test]
    fn test_validate_readonly_query_rejects_empty() {
        assert!(validate_readonly_query("").is_err());
        assert!(validate_readonly_query(" ").is_err());
    }

    /// Unqualified `mz_*` tables, schema-qualified system tables and joins
    /// between them are accepted by system catalog validation.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_accepts_mz_tables() {
        assert!(validate_system_catalog_query("SELECT * FROM mz_tables").is_ok());
        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_comments").is_ok());
        assert!(
            validate_system_catalog_query(
                "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id"
            )
            .is_ok()
        );
    }

    /// Table references inside subqueries are validated as well: system-only
    /// subqueries pass, subqueries touching user tables fail.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_subqueries() {
        // Subquery over a system table.
        assert!(
            validate_system_catalog_query(
                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
            )
            .is_ok()
        );

        // Nested subqueries, all over system tables.
        assert!(validate_system_catalog_query(
            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))"
        )
        .is_ok());

        // Derived table over a system table.
        assert!(
            validate_system_catalog_query(
                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
            )
            .is_ok()
        );

        // Subquery over a user table must be rejected.
        assert!(
            validate_system_catalog_query(
                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)"
            )
            .is_err()
        );

        // User table hidden two subquery levels deep must be rejected.
        assert!(validate_system_catalog_query(
            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)"
        )
        .is_err());
    }

    /// Queries that only reference user tables are rejected.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_user_tables() {
        assert!(validate_system_catalog_query("SELECT * FROM user_data").is_err());
        assert!(validate_system_catalog_query("SELECT * FROM my_table").is_err());
        assert!(
            // An `mz_` string literal elsewhere in the query must not make a
            // user table look like a system table.
            validate_system_catalog_query("SELECT * FROM user_data WHERE 'mz_' IS NOT NULL")
                .is_err()
        );
    }

    /// Function calls in the projection or predicate do not count as table
    /// references and do not cause rejection.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_functions() {
        assert!(
            validate_system_catalog_query(
                "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1"
            )
            .is_ok()
        );
        assert!(validate_system_catalog_query("SELECT length(name) FROM mz_tables").is_ok());
        assert!(
            validate_system_catalog_query(
                "SELECT count(*) FROM mz_sources WHERE now() > created_at"
            )
            .is_ok()
        );
    }

    /// Schema-qualified references are accepted only for allowed system
    /// schemas.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_schema_qualified() {
        // Allowed system schemas.
        assert!(validate_system_catalog_query("SELECT * FROM mz_catalog.mz_tables").is_ok());
        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_sessions").is_ok());
        assert!(validate_system_catalog_query("SELECT * FROM pg_catalog.pg_type").is_ok());
        assert!(validate_system_catalog_query("SELECT * FROM information_schema.tables").is_ok());

        // Non-system schemas.
        assert!(validate_system_catalog_query("SELECT * FROM public.user_table").is_err());
        assert!(validate_system_catalog_query("SELECT * FROM myschema.mytable").is_err());

        assert!(
            // `mz_unsafe` is explicitly excluded from the allowed schemas.
            validate_system_catalog_query("SELECT * FROM mz_unsafe.mz_some_table").is_err(),
            "mz_unsafe schema should be blocked even though it is a system schema"
        );

        assert!(
            // One allowed reference does not excuse a disallowed one.
            validate_system_catalog_query(
                "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"
            )
            .is_err()
        );
    }

    /// Attempts to smuggle a user table into an otherwise system-only query
    /// are rejected, while complex system-only queries are accepted.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_adversarial_cases() {
        assert!(
            // CTE body reads a user table.
            validate_system_catalog_query(
                "WITH user_cte AS (SELECT * FROM user_data) \
                 SELECT * FROM mz_tables, user_cte"
            )
            .is_err(),
            "Should reject CTE referencing user table"
        );

        assert!(
            // Unused CTE in a chain reads a user table.
            validate_system_catalog_query(
                "WITH \
                 cte1 AS (SELECT * FROM mz_tables), \
                 cte2 AS (SELECT * FROM cte1), \
                 cte3 AS (SELECT * FROM user_data) \
                 SELECT * FROM cte2"
            )
            .is_err(),
            "Should reject CTE chain with user table"
        );

        assert!(
            // User table placed between two system tables in a join chain.
            validate_system_catalog_query(
                "SELECT * FROM mz_tables t1 \
                 JOIN user_data u ON t1.id = u.id \
                 JOIN mz_sources s ON t1.id = s.id"
            )
            .is_err(),
            "Should reject multi-join with user table"
        );

        assert!(
            // LEFT JOIN against a user table.
            validate_system_catalog_query(
                "SELECT * FROM mz_tables t \
                 LEFT JOIN user_data u ON t.id = u.table_id \
                 WHERE u.id IS NULL"
            )
            .is_err(),
            "Should reject LEFT JOIN with user table"
        );

        assert!(
            // User table inside a derived table inside an IN subquery.
            validate_system_catalog_query(
                "SELECT * FROM mz_tables WHERE id IN \
                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)"
            )
            .is_err(),
            "Should reject nested subquery with user table"
        );

        assert!(
            // UNION branch reads a user table.
            validate_system_catalog_query(
                "SELECT name FROM mz_tables \
                 UNION \
                 SELECT name FROM user_data"
            )
            .is_err(),
            "Should reject UNION with user table"
        );

        assert!(
            // UNION ALL branch reads a user table.
            validate_system_catalog_query(
                "SELECT id FROM mz_sources \
                 UNION ALL \
                 SELECT id FROM products"
            )
            .is_err(),
            "Should reject UNION ALL with user table"
        );

        assert!(
            // CROSS JOIN against a user table.
            validate_system_catalog_query("SELECT * FROM mz_tables CROSS JOIN user_data").is_err(),
            "Should reject CROSS JOIN with user table"
        );

        assert!(
            // Scalar subquery in the projection reads a user table.
            validate_system_catalog_query(
                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t"
            )
            .is_err(),
            "Should reject subquery in SELECT with user table"
        );

        assert!(
            // Schema names that merely resemble a system schema.
            validate_system_catalog_query("SELECT * FROM mz_catalogg.fake_table").is_err(),
            "Should reject typo-squatting schema name"
        );
        assert!(
            validate_system_catalog_query("SELECT * FROM mz_catalog_hack.fake_table").is_err(),
            "Should reject fake schema with mz_catalog prefix"
        );

        assert!(
            // LATERAL subquery reads a user table.
            validate_system_catalog_query(
                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u"
            )
            .is_err(),
            "Should reject LATERAL join with user table"
        );

        assert!(
            // Positive case: CTEs, a join and a subquery, all over system tables.
            validate_system_catalog_query(
                "WITH \
                 tables AS (SELECT * FROM mz_tables), \
                 sources AS (SELECT * FROM mz_sources) \
                 SELECT t.name, s.name \
                 FROM tables t \
                 JOIN sources s ON t.id = s.id \
                 WHERE t.id IN (SELECT id FROM mz_columns)"
            )
            .is_ok(),
            "Should allow complex query with only system tables"
        );

        assert!(
            // Positive case: UNION of two system tables.
            validate_system_catalog_query(
                "SELECT name FROM mz_tables \
                 UNION \
                 SELECT name FROM mz_sources"
            )
            .is_ok(),
            "Should allow UNION of system tables"
        );
    }

    /// SELECT statements with no table reference at all are rejected.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_constant_queries() {
        assert!(
            validate_system_catalog_query("SELECT 1").is_err(),
            "Should reject constant SELECT with no table references"
        );
        assert!(
            validate_system_catalog_query("SELECT 1 + 2, 'hello'").is_err(),
            "Should reject constant expression SELECT"
        );
        assert!(
            validate_system_catalog_query("SELECT now()").is_err(),
            "Should reject function-only SELECT with no table references"
        );
    }

    /// A join mixing a system table with a user table is rejected.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_rejects_mixed_tables() {
        assert!(
            validate_system_catalog_query(
                "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"
            )
            .is_err()
        );
    }

    /// SHOW statements are accepted by system catalog validation.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_show() {
        assert!(
            validate_system_catalog_query("SHOW TABLES FROM mz_internal").is_ok(),
            "SHOW TABLES FROM mz_internal should be allowed"
        );
        assert!(
            validate_system_catalog_query("SHOW TABLES FROM mz_catalog").is_ok(),
            "SHOW TABLES FROM mz_catalog should be allowed"
        );
        assert!(
            validate_system_catalog_query("SHOW CLUSTERS").is_ok(),
            "SHOW CLUSTERS should be allowed"
        );
        assert!(
            validate_system_catalog_query("SHOW SOURCES").is_ok(),
            "SHOW SOURCES should be allowed"
        );
        assert!(
            validate_system_catalog_query("SHOW TABLES").is_ok(),
            "SHOW TABLES should be allowed"
        );
    }

    /// EXPLAIN statements are accepted, including EXPLAIN of a constant
    /// SELECT that would be rejected on its own.
    #[mz_ore::test]
    fn test_validate_system_catalog_query_allows_explain() {
        assert!(
            validate_system_catalog_query("EXPLAIN SELECT * FROM mz_tables").is_ok(),
            "EXPLAIN of system table query should be allowed"
        );
        assert!(
            validate_system_catalog_query("EXPLAIN SELECT 1").is_ok(),
            "EXPLAIN SELECT 1 should be allowed"
        );
    }

    /// On the agent endpoint, passing `false` as the second argument of
    /// `handle_tools_list` hides the `query` tool but keeps the data product
    /// tools.
    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_disabled() {
        let result = handle_tools_list(McpEndpointType::Agent, false, 1_000_000)
            .await
            .unwrap();
        let McpResult::ToolsList(list) = result else {
            panic!("Expected ToolsList result");
        };
        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
        assert!(
            tool_names.contains(&"get_data_products"),
            "get_data_products should always be present"
        );
        assert!(
            tool_names.contains(&"get_data_product_details"),
            "get_data_product_details should always be present"
        );
        assert!(
            tool_names.contains(&"read_data_product"),
            "read_data_product should always be present"
        );
        assert!(
            !tool_names.contains(&"query"),
            "query tool should be hidden when disabled"
        );
    }

    /// On the agent endpoint, passing `true` as the second argument of
    /// `handle_tools_list` exposes the `query` tool alongside the data
    /// product tools.
    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_agent_query_tool_enabled() {
        let result = handle_tools_list(McpEndpointType::Agent, true, 1_000_000)
            .await
            .unwrap();
        let McpResult::ToolsList(list) = result else {
            panic!("Expected ToolsList result");
        };
        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
        assert!(
            tool_names.contains(&"get_data_products"),
            "get_data_products should always be present"
        );
        assert!(
            tool_names.contains(&"get_data_product_details"),
            "get_data_product_details should always be present"
        );
        assert!(
            tool_names.contains(&"read_data_product"),
            "read_data_product should always be present"
        );
        assert!(
            tool_names.contains(&"query"),
            "query tool should be present when enabled"
        );
    }

    /// The developer endpoint lists `query_system_catalog` and never lists
    /// `query`, whichever value the flag takes.
    #[mz_ore::test(tokio::test)]
    async fn test_tools_list_developer_unaffected_by_query_flag() {
        for flag in [true, false] {
            let result = handle_tools_list(McpEndpointType::Developer, flag, 1_000_000)
                .await
                .unwrap();
            let McpResult::ToolsList(list) = result else {
                panic!("Expected ToolsList result");
            };
            let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
            assert!(
                tool_names.contains(&"query_system_catalog"),
                "query_system_catalog should always be present on developer"
            );
            assert!(
                !tool_names.contains(&"query"),
                "query tool should never appear on developer"
            );
        }
    }

    /// Rows that fit within the size limit are returned as a single content
    /// item whose text contains the row values.
    #[mz_ore::test]
    fn test_format_rows_response_within_limit() {
        let rows = vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
        let result = format_rows_response(rows, 1_000_000).unwrap();
        let McpResult::ToolContent(content) = result else {
            panic!("Expected ToolContent");
        };
        assert_eq!(content.content.len(), 1);
        assert!(content.content[0].text.contains("\"a\""));
        assert!(content.content[0].text.contains("\"b\""));
    }

    /// Rows exceeding the size limit produce an error that names the limit
    /// and suggests narrowing the query.
    #[mz_ore::test]
    fn test_format_rows_response_errors_when_over_limit() {
        let rows: Vec<Vec<serde_json::Value>> = (0..100)
            .map(|i| vec![json!(format!("row_{}", i)), json!(i)])
            .collect();
        let err = format_rows_response(rows, 500).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("exceeds the 500 byte limit"),
            "Error should mention the size limit, got: {msg}"
        );
        assert!(
            msg.contains("Use LIMIT or WHERE"),
            "Error should suggest narrowing the query, got: {msg}"
        );
    }

    /// An empty row set is rendered as the text `[]`.
    #[mz_ore::test]
    fn test_format_rows_response_empty_rows() {
        let rows: Vec<Vec<serde_json::Value>> = vec![];
        let result = format_rows_response(rows, 1000).unwrap();
        let McpResult::ToolContent(content) = result else {
            panic!("Expected ToolContent");
        };
        assert_eq!(content.content.len(), 1);
        assert_eq!(content.content[0].text, "[]");
    }

    /// Well-formed one-, two- and three-part names are returned in quoted
    /// form.
    #[mz_ore::test]
    fn test_safe_data_product_name_valid() {
        assert_eq!(
            // Fully qualified, already quoted.
            safe_data_product_name(r#""materialize"."public"."my_view""#).unwrap(),
            r#""materialize"."public"."my_view""#
        );
        assert_eq!(
            // Schema-qualified, already quoted.
            safe_data_product_name(r#""public"."my_view""#).unwrap(),
            r#""public"."my_view""#
        );
        // A bare unquoted name comes back quoted.
        assert_eq!(safe_data_product_name("my_view").unwrap(), r#""my_view""#);
    }

    /// Empty and whitespace-only names are rejected.
    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_empty() {
        assert!(safe_data_product_name("").is_err());
        assert!(safe_data_product_name(" ").is_err());
    }

    /// Names carrying extra SQL after the identifier are rejected.
    #[mz_ore::test]
    fn test_safe_data_product_name_rejects_sql_injection() {
        // Statement terminator followed by a second statement.
        assert!(safe_data_product_name("my_view; DROP TABLE users").is_err());
        // UNION appended to the name.
        assert!(safe_data_product_name("my_view UNION SELECT * FROM secrets").is_err());
        // Second table via a comma.
        assert!(safe_data_product_name("my_view, secrets").is_err());
        // Trailing predicate and comment marker.
        assert!(safe_data_product_name("my_view WHERE 1=1 --").is_err());
    }

    /// The generated read query contains the transaction wrapper, the
    /// cluster setting and the limited SELECT.
    #[mz_ore::test]
    fn test_build_read_query_with_cluster() {
        let sql = build_read_query("\"db\".\"sch\".\"v\"", 50, "prod_cluster");
        assert!(sql.contains("BEGIN READ ONLY"), "{sql}");
        assert!(sql.contains("SET CLUSTER = 'prod_cluster'"), "{sql}");
        assert!(
            sql.contains("SELECT * FROM \"db\".\"sch\".\"v\" LIMIT 50"),
            "{sql}",
        );
        assert!(sql.contains("COMMIT"), "{sql}");
    }

    /// A cluster name containing a single quote stays inside one string
    /// literal and does not add statements to the generated SQL.
    #[mz_ore::test]
    fn test_build_read_query_escapes_cluster_name() {
        let sql = build_read_query("\"db\".\"sch\".\"v\"", 10, "evil'; DROP TABLE secrets; --");
        assert!(
            // The embedded quote is doubled, keeping the payload in the literal.
            sql.contains("SET CLUSTER = 'evil''; DROP TABLE secrets; --'"),
            "single quote should be doubled inside the literal: {sql}",
        );
        assert_eq!(
            // No second SET CLUSTER appears in the output.
            sql.matches("SET CLUSTER").count(),
            1,
            "exactly one SET CLUSTER statement: {sql}",
        );
        assert_eq!(
            sql.matches("DROP TABLE").count(),
            1,
            "DROP TABLE should appear once, inside the quoted literal: {sql}",
        );
    }

    /// Selected error variants map to the expected JSON-RPC error codes.
    #[mz_ore::test]
    fn test_mcp_error_codes() {
        assert_eq!(
            McpRequestError::InvalidJsonRpcVersion.error_code(),
            error_codes::INVALID_REQUEST
        );
        assert_eq!(
            McpRequestError::MethodNotFound("test".to_string()).error_code(),
            error_codes::METHOD_NOT_FOUND
        );
        assert_eq!(
            McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
            error_codes::INTERNAL_ERROR
        );
    }
}