Skip to main content

mz_environmentd/http/
mcp.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Model Context Protocol (MCP) HTTP handlers.
11//!
12//! Exposes Materialize data products to AI agents via JSON-RPC 2.0 over HTTP POST.
13//!
14//! ## Endpoints
15//!
16//! - `/api/mcp/agent` - User data products for customer AI agents
17//! - `/api/mcp/developer` - System catalog (`mz_*`) for troubleshooting
18//!
19//! ## Tools
20//!
21//! **Agent:** `get_data_products`, `get_data_product_details`, `read_data_product`, `query`
22//! **Developer:** `query_system_catalog`
23//!
24//! Data products are discovered via `mz_internal.mz_mcp_data_products` system view.
25
26use std::sync::Arc;
27use std::time::Duration;
28
29use anyhow::anyhow;
30use axum::Extension;
31use axum::Json;
32use axum::response::IntoResponse;
33use http::{HeaderMap, HeaderValue, StatusCode};
34use mz_adapter_types::dyncfgs::{
35    ENABLE_MCP_AGENT, ENABLE_MCP_AGENT_QUERY_TOOL, ENABLE_MCP_DEVELOPER, MCP_MAX_RESPONSE_SIZE,
36};
37use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
38use mz_sql::parse::parse;
39use mz_sql::session::metadata::SessionMetadata;
40use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
41use mz_sql_parser::ast::visit::{self, Visit};
42use mz_sql_parser::ast::{Raw, RawItemName};
43use mz_sql_parser::parser::parse_item_name;
44use serde::{Deserialize, Serialize};
45use serde_json::json;
46use thiserror::Error;
47use tracing::{debug, warn};
48
49use crate::http::AuthedClient;
50use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
51
52// To add a new tool: add entry to tools/list, add handler function, add dispatch case.
53
/// JSON-RPC protocol version used in all MCP requests and responses.
///
/// Requests whose `jsonrpc` field differs from this value are rejected with
/// `McpRequestError::InvalidJsonRpcVersion`.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol version returned in the `initialize` response.
/// Spec: <https://modelcontextprotocol.io/specification/2025-11-25>
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Maximum time an MCP tool call can run before the HTTP response is returned.
/// The timeout wraps the whole spawned request task, so it applies to every
/// non-notification method, not only `tools/call`.
/// Note: this returns a clean JSON-RPC error to the caller, but the underlying
/// query may continue running on the cluster until it completes or is cancelled
/// separately (see database-issues#9947 for SELECT timeout gaps).
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

// Discovery uses the lightweight view (no JSON schema computation).
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
// Details uses the full view with JSON schema.
// The prefix ends in `object_name = `, so the caller has to append the value to
// compare against. NOTE(review): presumably an escaped string literal; confirm
// at the use site.
const DETAILS_QUERY_PREFIX: &str =
    "SELECT * FROM mz_internal.mz_mcp_data_product_details WHERE object_name = ";
72
73/// MCP request errors, mapped to JSON-RPC error codes.
74#[derive(Debug, Error)]
75enum McpRequestError {
76    #[error("Invalid JSON-RPC version: expected 2.0")]
77    InvalidJsonRpcVersion,
78    #[error("Method not found: {0}")]
79    #[allow(dead_code)] // Handled by serde deserialization, kept for error mapping
80    MethodNotFound(String),
81    #[error("Tool not found: {0}")]
82    ToolNotFound(String),
83    #[error("Data product not found: {0}")]
84    DataProductNotFound(String),
85    #[error("Query validation failed: {0}")]
86    QueryValidationFailed(String),
87    #[error("Query execution failed: {0}")]
88    QueryExecutionFailed(String),
89    #[error("Internal error: {0}")]
90    Internal(#[from] anyhow::Error),
91}
92
93impl McpRequestError {
94    fn error_code(&self) -> i32 {
95        match self {
96            Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
97            Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
98            Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
99            Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
100            Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
101            Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
102        }
103    }
104
105    fn error_type(&self) -> &'static str {
106        match self {
107            Self::InvalidJsonRpcVersion => "InvalidRequest",
108            Self::MethodNotFound(_) => "MethodNotFound",
109            Self::ToolNotFound(_) => "ToolNotFound",
110            Self::DataProductNotFound(_) => "DataProductNotFound",
111            Self::QueryValidationFailed(_) => "ValidationError",
112            Self::QueryExecutionFailed(_) => "ExecutionError",
113            Self::Internal(_) => "InternalError",
114        }
115    }
116}
117
/// JSON-RPC 2.0 request. Requests have `id`; notifications don't.
///
/// Because `method` is flattened, the `method` and `params` keys are read from
/// the same top-level JSON object as `jsonrpc` and `id`.
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    /// Protocol version string; compared against `JSONRPC_VERSION` when the
    /// request is dispatched.
    jsonrpc: String,
    /// Request identifier echoed back in the response. `None` marks the message
    /// as a notification, which receives no JSON-RPC response.
    id: Option<serde_json::Value>,
    /// The method being invoked together with its parameters.
    #[serde(flatten)]
    method: McpMethod,
}
126
/// MCP method variants with their associated parameters.
///
/// Deserialized as an adjacently tagged enum: the `method` key selects the
/// variant and the `params` key carries its payload.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// Initialize method - params accepted but not currently used
    #[serde(rename = "initialize")]
    Initialize(#[allow(dead_code)] InitializeParams),
    /// Lists the tools offered by the endpoint; takes no parameters.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// Invokes a single tool; the payload names the tool and its arguments.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Catch-all for unknown methods (e.g. `notifications/initialized`)
    #[serde(other)]
    Unknown,
}
142
143impl std::fmt::Display for McpMethod {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        match self {
146            McpMethod::Initialize(_) => write!(f, "initialize"),
147            McpMethod::ToolsList => write!(f, "tools/list"),
148            McpMethod::ToolsCall(_) => write!(f, "tools/call"),
149            McpMethod::Unknown => write!(f, "unknown"),
150        }
151    }
152}
153
154#[derive(Debug, Deserialize)]
155struct InitializeParams {
156    /// Protocol version from client. Not currently validated but accepted for MCP compliance.
157    #[serde(rename = "protocolVersion")]
158    #[allow(dead_code)]
159    protocol_version: String,
160    /// Client capabilities. Not currently used but accepted for MCP compliance.
161    #[serde(default)]
162    #[allow(dead_code)]
163    capabilities: serde_json::Value,
164    /// Client information (name, version). Not currently used but accepted for MCP compliance.
165    #[serde(rename = "clientInfo")]
166    #[allow(dead_code)]
167    client_info: Option<ClientInfo>,
168}
169
170#[derive(Debug, Deserialize)]
171struct ClientInfo {
172    #[allow(dead_code)]
173    name: String,
174    #[allow(dead_code)]
175    version: String,
176}
177
/// Tool call parameters, deserialized via adjacently tagged enum.
/// Serde maps `name` to the variant and `arguments` to the variant's data.
/// Variant names are matched in snake_case (e.g. `get_data_products`).
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    // Agent endpoint tools
    // `get_data_products` takes no arguments, so its payload is the unit type `()`.
    // The intent is that MCP clients sending `"arguments": {}` can still deserialize.
    GetDataProducts(#[serde(default)] ()),
    GetDataProductDetails(GetDataProductDetailsParams),
    ReadDataProduct(ReadDataProductParams),
    Query(QueryParams),
    // Developer endpoint tools
    QuerySystemCatalog(QuerySystemCatalogParams),
}
193
194impl std::fmt::Display for ToolsCallParams {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        match self {
197            ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
198            ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
199            ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
200            ToolsCallParams::Query(_) => write!(f, "query"),
201            ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
202        }
203    }
204}
205
/// Arguments of the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    /// Name of the data product to describe, as listed by `get_data_products`.
    name: String,
}
210
/// Arguments of the `read_data_product` tool.
#[derive(Debug, Deserialize)]
struct ReadDataProductParams {
    /// Name of the data product to read from.
    name: String,
    /// Requested row limit; `default_read_limit()` is used when the key is absent.
    #[serde(default = "default_read_limit")]
    limit: u32,
    /// Optional cluster override supplied by the caller.
    cluster: Option<String>,
}
218
/// Row limit applied to `read_data_product` when the caller omits `limit`.
/// Referenced by name from the serde `default` attribute on the params struct.
fn default_read_limit() -> u32 {
    const DEFAULT_ROWS: u32 = 500;
    DEFAULT_ROWS
}
222
/// Maximum number of rows that can be returned by read_data_product.
/// The tool description advertises this cap to clients ("max 1000").
// NOTE(review): the clamping itself happens outside this part of the file; confirm there.
const MAX_READ_LIMIT: u32 = 1000;
225
/// Arguments of the agent `query` tool. Both fields are required.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Name of the cluster to run the query on.
    cluster: String,
    /// SQL text supplied by the caller.
    sql_query: String,
}
231
/// Arguments of the developer `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    /// SQL text supplied by the caller.
    sql_query: String,
}
236
/// JSON-RPC 2.0 response envelope.
///
/// `result` and `error` are each omitted from the JSON when `None`; the code in
/// this file always populates exactly one of them.
#[derive(Debug, Serialize)]
struct McpResponse {
    /// Always `JSONRPC_VERSION`.
    jsonrpc: String,
    /// The `id` of the request being answered.
    id: serde_json::Value,
    /// Success payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    /// Failure payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}
246
/// Typed MCP response results.
///
/// Serialized untagged: the JSON `result` is the inner struct itself, with no
/// wrapper naming the variant.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    /// Reply to `initialize`.
    Initialize(InitializeResult),
    /// Reply to `tools/list`.
    ToolsList(ToolsListResult),
    /// Reply to `tools/call`.
    ToolContent(ToolContentResult),
}
255
256#[derive(Debug, Serialize)]
257struct InitializeResult {
258    #[serde(rename = "protocolVersion")]
259    protocol_version: String,
260    capabilities: Capabilities,
261    #[serde(rename = "serverInfo")]
262    server_info: ServerInfo,
263    #[serde(skip_serializing_if = "Option::is_none")]
264    instructions: Option<String>,
265}
266
/// Server capabilities advertised in the `initialize` response.
#[derive(Debug, Serialize)]
struct Capabilities {
    /// Tool capability object; `handle_initialize` sends an empty JSON object.
    tools: serde_json::Value,
}
271
/// Server identification returned in the `initialize` response.
#[derive(Debug, Serialize)]
struct ServerInfo {
    /// Server name; `handle_initialize` uses `materialize-mcp-<endpoint>`.
    name: String,
    /// Server version; `handle_initialize` uses the crate's `CARGO_PKG_VERSION`.
    version: String,
}
277
/// Payload of the `tools/list` response.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    /// Tools available on the endpoint that was queried.
    tools: Vec<ToolDefinition>,
}
282
283#[derive(Debug, Serialize)]
284struct ToolDefinition {
285    name: String,
286    #[serde(skip_serializing_if = "Option::is_none")]
287    title: Option<String>,
288    description: String,
289    #[serde(rename = "inputSchema")]
290    input_schema: serde_json::Value,
291    #[serde(skip_serializing_if = "Option::is_none")]
292    annotations: Option<ToolAnnotations>,
293}
294
295/// MCP 2025-11-25 tool annotations that describe tool behavior.
296/// These hints help clients make trust and safety decisions.
297#[derive(Debug, Serialize)]
298struct ToolAnnotations {
299    #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
300    read_only_hint: Option<bool>,
301    #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
302    destructive_hint: Option<bool>,
303    #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
304    idempotent_hint: Option<bool>,
305    #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
306    open_world_hint: Option<bool>,
307}
308
/// Annotations for all MCP tools: read-only, non-destructive, idempotent.
/// Also sets `openWorldHint` to `false`. Every tool definition built in this
/// file attaches this same value.
const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
    read_only_hint: Some(true),
    destructive_hint: Some(false),
    idempotent_hint: Some(true),
    open_world_hint: Some(false),
};
316
317#[derive(Debug, Serialize)]
318struct ToolContentResult {
319    content: Vec<ContentBlock>,
320    #[serde(rename = "isError")]
321    is_error: bool,
322}
323
/// One piece of tool output inside a `tools/call` response.
#[derive(Debug, Serialize)]
struct ContentBlock {
    /// Content kind, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    content_type: String,
    /// Text payload of the block.
    text: String,
}
330
/// JSON-RPC 2.0 error codes.
mod error_codes {
    /// The JSON sent is not a valid request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// Invalid method parameters.
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: i32 = -32603;
}
338
/// JSON-RPC 2.0 error object carried in `McpResponse::error`.
#[derive(Debug, Serialize)]
struct McpError {
    /// Numeric JSON-RPC error code (see `error_codes`).
    code: i32,
    /// Human-readable error message.
    message: String,
    /// Optional structured details; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
346
347impl From<McpRequestError> for McpError {
348    fn from(err: McpRequestError) -> Self {
349        McpError {
350            code: err.error_code(),
351            message: err.to_string(),
352            data: Some(json!({
353                "error_type": err.error_type(),
354            })),
355        }
356    }
357}
358
/// Which of the two MCP endpoints a request arrived on.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// `/api/mcp/agent`
    Agent,
    /// `/api/mcp/developer`
    Developer,
}

/// Formats the endpoint as its lowercase name (`agent` / `developer`), which is
/// used in log fields, error messages, and the advertised server name.
impl std::fmt::Display for McpEndpointType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Agent => "agent",
            Self::Developer => "developer",
        };
        f.write_str(label)
    }
}
373
/// MCP 2025-11-25 requires servers to return 405 for GET requests
/// on endpoints that only support POST.
///
/// Always responds with `405 Method Not Allowed` and no body. Which routes and
/// HTTP methods use this handler is decided where the router is built.
pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
    StatusCode::METHOD_NOT_ALLOWED
}
379
380/// Agent endpoint: exposes user data products.
381pub async fn handle_mcp_agent(
382    headers: HeaderMap,
383    Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
384    client: AuthedClient,
385    Json(body): Json<McpRequest>,
386) -> axum::response::Response {
387    if let Some(resp) = validate_origin(&headers, &allowed_origins) {
388        return resp;
389    }
390    handle_mcp_request(client, body, McpEndpointType::Agent)
391        .await
392        .into_response()
393}
394
395/// Developer endpoint: exposes system catalog (mz_*) only.
396pub async fn handle_mcp_developer(
397    headers: HeaderMap,
398    Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
399    client: AuthedClient,
400    Json(body): Json<McpRequest>,
401) -> axum::response::Response {
402    if let Some(resp) = validate_origin(&headers, &allowed_origins) {
403        return resp;
404    }
405    handle_mcp_request(client, body, McpEndpointType::Developer)
406        .await
407        .into_response()
408}
409
410/// Validates the Origin header against the CORS allowlist to prevent DNS
411/// rebinding attacks (MCP spec 2025-11-25). Returns Some(403) if Origin is
412/// present but not on the allowlist. Returns None if absent (non-browser
413/// client) or allowed.
414///
415/// Note: this server-side check is required in addition to the CorsLayer.
416/// CorsLayer only controls response headers and can be bypassed when the
417/// attacker arranges same-origin DNS rebinding (no preflight fires).
418fn validate_origin(
419    headers: &HeaderMap,
420    allowed: &[HeaderValue],
421) -> Option<axum::response::Response> {
422    let origin = headers.get(http::header::ORIGIN)?;
423    if mz_http_util::origin_is_allowed(origin, allowed) {
424        return None;
425    }
426    warn!(
427        origin = ?origin,
428        "MCP request rejected: origin not in allowlist",
429    );
430    Some(StatusCode::FORBIDDEN.into_response())
431}
432
433async fn handle_mcp_request(
434    mut client: AuthedClient,
435    request: McpRequest,
436    endpoint_type: McpEndpointType,
437) -> impl IntoResponse {
438    // Check the per-endpoint feature flag via a catalog snapshot, similar to frontend_peek.rs.
439    let catalog = client.client.catalog_snapshot("mcp").await;
440    let dyncfgs = catalog.system_config().dyncfgs();
441    let enabled = match endpoint_type {
442        McpEndpointType::Agent => ENABLE_MCP_AGENT.get(dyncfgs),
443        McpEndpointType::Developer => ENABLE_MCP_DEVELOPER.get(dyncfgs),
444    };
445    if !enabled {
446        debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
447        return StatusCode::SERVICE_UNAVAILABLE.into_response();
448    }
449
450    let query_tool_enabled = ENABLE_MCP_AGENT_QUERY_TOOL.get(dyncfgs);
451    let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
452
453    let user = client.client.session().user().name.clone();
454    let is_notification = request.id.is_none();
455
456    debug!(
457        method = %request.method,
458        endpoint = %endpoint_type,
459        user = %user,
460        is_notification = is_notification,
461        "MCP request received"
462    );
463
464    // Handle notifications (no response needed)
465    if is_notification {
466        debug!(method = %request.method, "Received notification (no response will be sent)");
467        return StatusCode::OK.into_response();
468    }
469
470    let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
471
472    // Spawn task for fault isolation, with a timeout safety net.
473    let result = tokio::time::timeout(
474        MCP_REQUEST_TIMEOUT,
475        mz_ore::task::spawn(|| "mcp_request", async move {
476            handle_mcp_request_inner(
477                &mut client,
478                request,
479                endpoint_type,
480                query_tool_enabled,
481                max_response_size,
482            )
483            .await
484        }),
485    )
486    .await;
487
488    let response = match result {
489        Ok(inner) => inner,
490        Err(_elapsed) => {
491            warn!(
492                endpoint = %endpoint_type,
493                timeout = ?MCP_REQUEST_TIMEOUT,
494                "MCP request timed out",
495            );
496            McpResponse {
497                jsonrpc: JSONRPC_VERSION.to_string(),
498                id: request_id,
499                result: None,
500                error: Some(
501                    McpRequestError::QueryExecutionFailed(format!(
502                        "Request timed out after {} seconds.",
503                        MCP_REQUEST_TIMEOUT.as_secs(),
504                    ))
505                    .into(),
506                ),
507            }
508        }
509    };
510
511    (StatusCode::OK, Json(response)).into_response()
512}
513
514async fn handle_mcp_request_inner(
515    client: &mut AuthedClient,
516    request: McpRequest,
517    endpoint_type: McpEndpointType,
518    query_tool_enabled: bool,
519    max_response_size: usize,
520) -> McpResponse {
521    // Extract request ID (guaranteed to be Some since notifications are filtered earlier)
522    let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
523
524    let result = handle_mcp_method(
525        client,
526        &request,
527        endpoint_type,
528        query_tool_enabled,
529        max_response_size,
530    )
531    .await;
532
533    match result {
534        Ok(result_value) => McpResponse {
535            jsonrpc: JSONRPC_VERSION.to_string(),
536            id: request_id,
537            result: Some(result_value),
538            error: None,
539        },
540        Err(e) => {
541            // Log non-trivial errors
542            if !matches!(
543                e,
544                McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
545            ) {
546                warn!(error = %e, method = %request.method, "MCP method execution failed");
547            }
548            McpResponse {
549                jsonrpc: JSONRPC_VERSION.to_string(),
550                id: request_id,
551                result: None,
552                error: Some(e.into()),
553            }
554        }
555    }
556}
557
558async fn handle_mcp_method(
559    client: &mut AuthedClient,
560    request: &McpRequest,
561    endpoint_type: McpEndpointType,
562    query_tool_enabled: bool,
563    max_response_size: usize,
564) -> Result<McpResult, McpRequestError> {
565    // Validate JSON-RPC version
566    if request.jsonrpc != JSONRPC_VERSION {
567        return Err(McpRequestError::InvalidJsonRpcVersion);
568    }
569
570    // Handle different MCP methods using pattern matching
571    match &request.method {
572        McpMethod::Initialize(_) => {
573            debug!(endpoint = %endpoint_type, "Processing initialize");
574            handle_initialize(endpoint_type).await
575        }
576        McpMethod::ToolsList => {
577            debug!(endpoint = %endpoint_type, "Processing tools/list");
578            handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
579        }
580        McpMethod::ToolsCall(params) => {
581            debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
582            handle_tools_call(
583                client,
584                params,
585                endpoint_type,
586                query_tool_enabled,
587                max_response_size,
588            )
589            .await
590        }
591        McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
592            "unknown method".to_string(),
593        )),
594    }
595}
596
597/// Instructions returned in the `initialize` response for each endpoint type.
598/// These guide the AI agent on how to use the server correctly.
599fn endpoint_instructions(endpoint_type: McpEndpointType) -> Option<String> {
600    match endpoint_type {
601        McpEndpointType::Agent => None,
602        McpEndpointType::Developer => Some(concat!(
603            "You are connected to the Materialize developer MCP server. ",
604            "You have read-only access to system catalog tables (mz_*, pg_catalog, information_schema) ",
605            "for troubleshooting and observability.\n\n",
606            "IMPORTANT: Before writing queries, discover table schemas using the mz_ontology tables:\n",
607            "- mz_internal.mz_ontology_entity_types: what catalog entities exist and which tables they map to\n",
608            "- mz_internal.mz_ontology_link_types: relationships between entities (foreign keys, metrics, etc.)\n",
609            "- mz_internal.mz_ontology_properties: column names, types, and descriptions for each entity\n",
610            "- mz_internal.mz_ontology_semantic_types: typed ID domains (CatalogItemId, ReplicaId, etc.)\n\n",
611            "Use these to find the correct tables, join paths, and column names instead of guessing.\n\n",
612            "Key rules:\n",
613            "- mz_source_statuses and mz_sink_statuses use `last_status_change_at` (NOT `updated_at`)\n",
614            "- mz_cluster_replica_utilization only has `replica_id` — JOIN with mz_cluster_replicas and mz_clusters to get names\n",
615            "- Do NOT query mz_introspection.mz_dataflow_arrangement_sizes — it is cluster-scoped and has uint8/text type mismatches\n",
616            "- Use SHOW COLUMNS FROM <table> to verify column names if unsure",
617        ).to_string()),
618    }
619}
620
621async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
622    Ok(McpResult::Initialize(InitializeResult {
623        protocol_version: MCP_PROTOCOL_VERSION.to_string(),
624        capabilities: Capabilities { tools: json!({}) },
625        server_info: ServerInfo {
626            name: format!("materialize-mcp-{}", endpoint_type),
627            version: env!("CARGO_PKG_VERSION").to_string(),
628        },
629        instructions: endpoint_instructions(endpoint_type),
630    }))
631}
632
633async fn handle_tools_list(
634    endpoint_type: McpEndpointType,
635    query_tool_enabled: bool,
636    max_response_size: usize,
637) -> Result<McpResult, McpRequestError> {
638    let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
639
640    let tools = match endpoint_type {
641        McpEndpointType::Agent => {
642            let mut tools = vec![
643                ToolDefinition {
644                    name: "get_data_products".to_string(),
645                    title: Some("List Data Products".to_string()),
646                    description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
647                    input_schema: json!({
648                        "type": "object",
649                        "properties": {},
650                        "required": []
651                    }),
652                    annotations: Some(READ_ONLY_ANNOTATIONS),
653                },
654                ToolDefinition {
655                    name: "get_data_product_details".to_string(),
656                    title: Some("Get Data Product Details".to_string()),
657                    description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
658                    input_schema: json!({
659                        "type": "object",
660                        "properties": {
661                            "name": {
662                                "type": "string",
663                                "description": "Exact name of the data product from get_data_products() list"
664                            }
665                        },
666                        "required": ["name"]
667                    }),
668                    annotations: Some(READ_ONLY_ANNOTATIONS),
669                },
670                ToolDefinition {
671                    name: "read_data_product".to_string(),
672                    title: Some("Read Data Product".to_string()),
673                    description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500, max 1000). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
674                    input_schema: json!({
675                        "type": "object",
676                        "properties": {
677                            "name": {
678                                "type": "string",
679                                "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
680                            },
681                            "limit": {
682                                "type": "integer",
683                                "description": "Maximum number of rows to return (default 500, max 1000)",
684                                "default": 500
685                            },
686                            "cluster": {
687                                "type": "string",
688                                "description": "Optional cluster override. If omitted, uses the cluster from the data product catalog."
689                            }
690                        },
691                        "required": ["name"]
692                    }),
693                    annotations: Some(READ_ONLY_ANNOTATIONS),
694                },
695            ];
696            if query_tool_enabled {
697                tools.push(ToolDefinition {
698                    name: "query".to_string(),
699                    title: Some("Query Data Products".to_string()),
700                    description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
701                    input_schema: json!({
702                        "type": "object",
703                        "properties": {
704                            "cluster": {
705                                "type": "string",
706                                "description": "Exact cluster name from the data product details - required for query execution"
707                            },
708                            "sql_query": {
709                                "type": "string",
710                                "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
711                            }
712                        },
713                        "required": ["cluster", "sql_query"]
714                    }),
715                    annotations: Some(READ_ONLY_ANNOTATIONS),
716                });
717            }
718            tools
719        }
720        McpEndpointType::Developer => {
721            vec![ToolDefinition {
722                name: "query_system_catalog".to_string(),
723                title: Some("Query System Catalog".to_string()),
724                description: concat!(
725                    "Query Materialize system catalog tables for troubleshooting and observability. ",
726                    "Only mz_*, pg_catalog, and information_schema tables are accessible. ",
727                    "Use the mz_internal.mz_ontology_* tables to discover tables, columns, and join paths before writing queries.",
728                ).to_owned() + &format!(" {size_hint}"),
729                input_schema: json!({
730                    "type": "object",
731                    "properties": {
732                        "sql_query": {
733                            "type": "string",
734                            "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
735                        }
736                    },
737                    "required": ["sql_query"]
738                }),
739                annotations: Some(READ_ONLY_ANNOTATIONS),
740            }]
741        }
742    };
743
744    Ok(McpResult::ToolsList(ToolsListResult { tools }))
745}
746
747async fn handle_tools_call(
748    client: &mut AuthedClient,
749    params: &ToolsCallParams,
750    endpoint_type: McpEndpointType,
751    query_tool_enabled: bool,
752    max_response_size: usize,
753) -> Result<McpResult, McpRequestError> {
754    match (endpoint_type, params) {
755        (McpEndpointType::Agent, ToolsCallParams::GetDataProducts(_)) => {
756            get_data_products(client, max_response_size).await
757        }
758        (McpEndpointType::Agent, ToolsCallParams::GetDataProductDetails(p)) => {
759            get_data_product_details(client, &p.name, max_response_size).await
760        }
761        (McpEndpointType::Agent, ToolsCallParams::ReadDataProduct(p)) => {
762            read_data_product(
763                client,
764                &p.name,
765                p.limit,
766                p.cluster.as_deref(),
767                max_response_size,
768            )
769            .await
770        }
771        (McpEndpointType::Agent, ToolsCallParams::Query(_)) if !query_tool_enabled => {
772            Err(McpRequestError::ToolNotFound(
773                "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
774            ))
775        }
776        (McpEndpointType::Agent, ToolsCallParams::Query(p)) => {
777            execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
778        }
779        (McpEndpointType::Developer, ToolsCallParams::QuerySystemCatalog(p)) => {
780            query_system_catalog(client, &p.sql_query, max_response_size).await
781        }
782        // Tool called on wrong endpoint
783        (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
784            "{} is not available on {} endpoint",
785            tool, endpoint
786        ))),
787    }
788}
789
790/// Execute SQL via `execute_request` from sql.rs.
791async fn execute_sql(
792    client: &mut AuthedClient,
793    query: &str,
794) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
795    let mut response = SqlResponse::new();
796
797    execute_request(
798        client,
799        SqlRequest::Simple {
800            query: query.to_string(),
801        },
802        &mut response,
803    )
804    .await
805    .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
806
807    // Extract the result with rows (the user's single SELECT/SHOW query)
808    // Other results will be OK (from BEGIN, SET, COMMIT) or Err
809    for result in response.results {
810        match result {
811            SqlResult::Rows { rows, .. } => return Ok(rows),
812            SqlResult::Err { error, .. } => {
813                return Err(McpRequestError::QueryExecutionFailed(error.message));
814            }
815            SqlResult::Ok { .. } => continue,
816        }
817    }
818
819    Err(McpRequestError::QueryExecutionFailed(
820        "Query did not return any results".to_string(),
821    ))
822}
823
824/// Serialize rows to JSON and enforce the response size cap.
825///
826/// If the serialized response exceeds `max_size` bytes, returns an error
827/// telling the agent to narrow its query. This mirrors how the HTTP SQL
828/// endpoint handles `max_result_size` in sql.rs — fail cleanly rather
829/// than silently truncating.
830fn format_rows_response(
831    rows: Vec<Vec<serde_json::Value>>,
832    max_size: usize,
833) -> Result<McpResult, McpRequestError> {
834    let text =
835        serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
836
837    if text.len() > max_size {
838        return Err(McpRequestError::QueryExecutionFailed(format!(
839            "Response size ({} bytes) exceeds the {} byte limit. \
840             Use LIMIT or WHERE to narrow your query.",
841            text.len(),
842            max_size,
843        )));
844    }
845
846    Ok(McpResult::ToolContent(ToolContentResult {
847        content: vec![ContentBlock {
848            content_type: "text".to_string(),
849            text,
850        }],
851        is_error: false,
852    }))
853}
854
855async fn get_data_products(
856    client: &mut AuthedClient,
857    max_response_size: usize,
858) -> Result<McpResult, McpRequestError> {
859    debug!("Executing get_data_products");
860    let rows = execute_sql(client, DISCOVERY_QUERY).await?;
861    debug!("get_data_products returned {} rows", rows.len());
862
863    format_rows_response(rows, max_response_size)
864}
865
866async fn get_data_product_details(
867    client: &mut AuthedClient,
868    name: &str,
869    max_response_size: usize,
870) -> Result<McpResult, McpRequestError> {
871    debug!(name = %name, "Executing get_data_product_details");
872
873    let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
874
875    let rows = execute_sql(client, &query).await?;
876
877    if rows.is_empty() {
878        return Err(McpRequestError::DataProductNotFound(name.to_string()));
879    }
880
881    format_rows_response(rows, max_response_size)
882}
883
884/// Parses a data product name and returns it safely quoted for SQL interpolation.
885///
886/// Uses the SQL parser to validate the name as an `UnresolvedItemName`, then
887/// formats it with `FormatMode::Stable` so every identifier part is
888/// double-quoted with proper escaping. This prevents SQL injection regardless
889/// of the input.
890fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
891    let name = name.trim();
892    if name.is_empty() {
893        return Err(McpRequestError::QueryValidationFailed(
894            "Data product name cannot be empty".to_string(),
895        ));
896    }
897
898    let parsed = parse_item_name(name).map_err(|_| {
899        McpRequestError::QueryValidationFailed(format!(
900            "Invalid data product name: {}. Expected a valid object name, \
901             e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
902            name
903        ))
904    })?;
905
906    // Stable formatting forces all identifiers to be double-quoted,
907    // so SQL keywords and special characters cannot escape.
908    Ok(parsed.to_ast_string_stable())
909}
910
911/// Read rows from a data product. Issues a single read-only query.
912///
913/// When `cluster_override` is provided, sets the cluster explicitly.
914/// Otherwise the query runs on the session's default cluster.
915///
916/// The name is expected to come from `get_data_products()` / `get_data_product_details()`.
917/// The query runs inside a READ ONLY transaction, preventing mutations.
918async fn read_data_product(
919    client: &mut AuthedClient,
920    name: &str,
921    limit: u32,
922    cluster_override: Option<&str>,
923    max_response_size: usize,
924) -> Result<McpResult, McpRequestError> {
925    debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
926
927    // Parse and safely quote the name for SQL interpolation.
928    let safe_name = safe_data_product_name(name)?;
929
930    // Lightweight existence check: verify the data product is visible in the
931    // catalog before running the read query. This gives a clean DataProductNotFound
932    // error for missing or inaccessible products (including RBAC revocations)
933    // without relying on fragile error code matching.
934    // TODO: Remove this extra round-trip once catalog errors get specific SQL
935    // error codes (see TODO in src/adapter/src/error.rs `fn code()`), then we
936    // can translate the query error directly.
937    let lookup_query = format!(
938        "SELECT 1 FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
939        escaped_string_literal(name)
940    );
941    let lookup_rows = execute_sql(client, &lookup_query).await?;
942    if lookup_rows.is_empty() {
943        return Err(McpRequestError::DataProductNotFound(name.to_string()));
944    }
945
946    let clamped_limit = limit.min(MAX_READ_LIMIT);
947
948    let read_query = match cluster_override {
949        Some(cluster) => format!(
950            "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}\n; COMMIT;",
951            escaped_string_literal(cluster),
952            safe_name,
953            clamped_limit,
954        ),
955        // Single statement — skip explicit transaction for better performance.
956        None => format!("SELECT * FROM {} LIMIT {}", safe_name, clamped_limit),
957    };
958
959    let rows = execute_sql(client, &read_query).await?;
960
961    format_rows_response(rows, max_response_size)
962}
963
964/// Validates query is a single SELECT, SHOW, or EXPLAIN statement.
965fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
966    let sql = sql.trim();
967    if sql.is_empty() {
968        return Err(McpRequestError::QueryValidationFailed(
969            "Empty query".to_string(),
970        ));
971    }
972
973    // Parse the SQL to get AST
974    let stmts = parse(sql).map_err(|e| {
975        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
976    })?;
977
978    // Only allow a single statement
979    if stmts.len() != 1 {
980        return Err(McpRequestError::QueryValidationFailed(format!(
981            "Only one query allowed at a time. Found {} statements.",
982            stmts.len()
983        )));
984    }
985
986    // Allowlist: Only SELECT, SHOW, and EXPLAIN statements permitted
987    let stmt = &stmts[0];
988    use mz_sql_parser::ast::Statement;
989
990    match &stmt.ast {
991        Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
992            // Allowed - read-only operations
993            Ok(())
994        }
995        _ => Err(McpRequestError::QueryValidationFailed(
996            "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
997        )),
998    }
999}
1000
1001async fn execute_query(
1002    client: &mut AuthedClient,
1003    cluster: &str,
1004    sql_query: &str,
1005    max_response_size: usize,
1006) -> Result<McpResult, McpRequestError> {
1007    debug!(cluster = %cluster, "Executing user query");
1008
1009    validate_readonly_query(sql_query)?;
1010
1011    // Use READ ONLY transaction to prevent modifications
1012    // Combine with SET CLUSTER (prometheus.rs:29-33 pattern)
1013    let combined_query = format!(
1014        "BEGIN READ ONLY; SET CLUSTER = {}; {}\n; COMMIT;",
1015        escaped_string_literal(cluster),
1016        sql_query
1017    );
1018
1019    let rows = execute_sql(client, &combined_query).await?;
1020
1021    format_rows_response(rows, max_response_size)
1022}
1023
1024async fn query_system_catalog(
1025    client: &mut AuthedClient,
1026    sql_query: &str,
1027    max_response_size: usize,
1028) -> Result<McpResult, McpRequestError> {
1029    debug!("Executing query_system_catalog");
1030
1031    // First validate it's a read-only query
1032    validate_readonly_query(sql_query)?;
1033
1034    // Then validate that query only references mz_* tables by parsing the SQL
1035    validate_system_catalog_query(sql_query)?;
1036
1037    // Wrap the query in a READ ONLY transaction with a tight search_path
1038    // restricted to system schemas. This prevents unqualified `mz_*` references
1039    // from resolving to user-created objects (e.g. a view `public.mz_leak`) via
1040    // the session's search_path (mirrors the `BEGIN READ ONLY; SET ...` pattern
1041    // used by the agent `query` tool).
1042    let combined_query = format!(
1043        "BEGIN READ ONLY; SET search_path = mz_catalog, mz_internal, pg_catalog, information_schema; {}; COMMIT;",
1044        sql_query
1045    );
1046
1047    let rows = execute_sql(client, &combined_query).await?;
1048
1049    format_rows_response(rows, max_response_size)
1050}
1051
/// AST visitor state that gathers the table references of a statement,
/// together with the schema each one was qualified with.
#[derive(Default)]
struct TableReferenceCollector {
    /// `(schema, table_name)` pairs; `schema` is `None` for unqualified names.
    tables: Vec<(Option<String>, String)>,
    /// Names of CTEs seen so far. References to these are not real tables and
    /// are left out of `tables`.
    cte_names: std::collections::BTreeSet<String>,
}

impl TableReferenceCollector {
    /// Creates a collector with no recorded tables or CTE names.
    fn new() -> Self {
        Self::default()
    }
}
1068
1069impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1070    fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1071        // Track CTE names so we don't treat them as table references
1072        self.cte_names
1073            .insert(cte.alias.name.as_str().to_lowercase());
1074        visit::visit_cte(self, cte);
1075    }
1076
1077    fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1078        // Only visit actual table references in FROM/JOIN clauses, not function names
1079        if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1080            match name {
1081                RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1082                    let parts = &n.0;
1083                    if !parts.is_empty() {
1084                        let table_name = parts.last().unwrap().as_str().to_lowercase();
1085
1086                        // Skip if this is a CTE reference, not a real table
1087                        if self.cte_names.contains(&table_name) {
1088                            visit::visit_table_factor(self, table_factor);
1089                            return;
1090                        }
1091
1092                        // Extract schema if qualified (e.g., mz_catalog.mz_tables)
1093                        let schema = if parts.len() >= 2 {
1094                            Some(parts[parts.len() - 2].as_str().to_lowercase())
1095                        } else {
1096                            None
1097                        };
1098                        self.tables.push((schema, table_name));
1099                    }
1100                }
1101            }
1102        }
1103        visit::visit_table_factor(self, table_factor);
1104    }
1105}
1106
1107/// Validates that a query only references system catalog tables.
1108///
1109/// For SELECT statements, all table references must be in system schemas
1110/// (from `SYSTEM_SCHEMAS`, excluding `mz_unsafe`), and at least one system
1111/// table must be referenced (constant queries like `SELECT 1` are rejected
1112/// to prevent misuse of the developer endpoint for arbitrary computation).
1113/// SHOW and EXPLAIN statements are allowed without table references.
1114fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1115    // Parse the SQL to validate it
1116    let stmts = parse(sql).map_err(|e| {
1117        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1118    })?;
1119
1120    if stmts.is_empty() {
1121        return Err(McpRequestError::QueryValidationFailed(
1122            "Empty query".to_string(),
1123        ));
1124    }
1125
1126    // Walk the AST to collect all table references
1127    let mut collector = TableReferenceCollector::new();
1128    for stmt in &stmts {
1129        collector.visit_statement(&stmt.ast);
1130    }
1131
1132    // Use the canonical system schema list, excluding mz_unsafe which contains
1133    // internal-only objects that should not be exposed to MCP clients.
1134    let is_allowed_schema =
1135        |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1136
1137    // Helper to check if a table reference is allowed. Unqualified references
1138    // are accepted here because execution uses a tight `search_path` containing
1139    // only system schemas (see `query_system_catalog`), so user-created views
1140    // like `public.mz_leak` cannot be reached by an unqualified name.
1141    let is_system_table = |(schema, table_name): &(Option<String>, String)| match schema {
1142        Some(s) => is_allowed_schema(s.as_str()),
1143        None => table_name.starts_with("mz_"),
1144    };
1145
1146    // Check that all table references are system tables
1147    let non_system_tables: Vec<String> = collector
1148        .tables
1149        .iter()
1150        .filter(|t| !is_system_table(t))
1151        .map(|(schema, table)| match schema {
1152            Some(s) => format!("{}.{}", s, table),
1153            None => table.clone(),
1154        })
1155        .collect();
1156
1157    if !non_system_tables.is_empty() {
1158        return Err(McpRequestError::QueryValidationFailed(format!(
1159            "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1160            non_system_tables.join(", ")
1161        )));
1162    }
1163
1164    // SHOW and EXPLAIN statements don't reference tables in the AST, but are safe
1165    // read-only operations. Only require system table references for SELECT.
1166    use mz_sql_parser::ast::Statement;
1167    let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1168
1169    if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1170        return Err(McpRequestError::QueryValidationFailed(
1171            "Query must reference at least one system catalog table".to_string(),
1172        ));
1173    }
1174
1175    Ok(())
1176}
1177
1178#[cfg(test)]
1179mod tests {
1180    use super::*;
1181
1182    #[mz_ore::test]
1183    fn test_validate_readonly_query_select() {
1184        assert!(validate_readonly_query("SELECT * FROM mz_tables").is_ok());
1185        assert!(validate_readonly_query("SELECT 1 + 2").is_ok());
1186        assert!(validate_readonly_query("  SELECT 1  ").is_ok());
1187    }
1188
1189    #[mz_ore::test]
1190    fn test_validate_readonly_query_subqueries() {
1191        // Simple subquery in WHERE clause
1192        assert!(
1193            validate_readonly_query(
1194                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1195            )
1196            .is_ok()
1197        );
1198
1199        // Subquery in FROM clause
1200        assert!(
1201            validate_readonly_query(
1202                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1203            )
1204            .is_ok()
1205        );
1206
1207        // Correlated subquery
1208        assert!(validate_readonly_query(
1209            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)"
1210        )
1211        .is_ok());
1212
1213        // Nested subqueries
1214        assert!(validate_readonly_query(
1215            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))"
1216        )
1217        .is_ok());
1218
1219        // Subquery with aggregation
1220        assert!(
1221            validate_readonly_query(
1222                "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)"
1223            )
1224            .is_ok()
1225        );
1226    }
1227
1228    #[mz_ore::test]
1229    fn test_validate_readonly_query_show() {
1230        assert!(validate_readonly_query("SHOW CLUSTERS").is_ok());
1231        assert!(validate_readonly_query("SHOW TABLES").is_ok());
1232    }
1233
1234    #[mz_ore::test]
1235    fn test_validate_readonly_query_explain() {
1236        assert!(validate_readonly_query("EXPLAIN SELECT 1").is_ok());
1237    }
1238
1239    #[mz_ore::test]
1240    fn test_validate_readonly_query_rejects_writes() {
1241        assert!(validate_readonly_query("INSERT INTO t VALUES (1)").is_err());
1242        assert!(validate_readonly_query("UPDATE t SET a = 1").is_err());
1243        assert!(validate_readonly_query("DELETE FROM t").is_err());
1244        assert!(validate_readonly_query("CREATE TABLE t (a INT)").is_err());
1245        assert!(validate_readonly_query("DROP TABLE t").is_err());
1246    }
1247
1248    #[mz_ore::test]
1249    fn test_validate_readonly_query_rejects_multiple() {
1250        assert!(validate_readonly_query("SELECT 1; SELECT 2").is_err());
1251    }
1252
1253    #[mz_ore::test]
1254    fn test_validate_readonly_query_rejects_empty() {
1255        assert!(validate_readonly_query("").is_err());
1256        assert!(validate_readonly_query("   ").is_err());
1257    }
1258
1259    #[mz_ore::test]
1260    fn test_validate_system_catalog_query_accepts_mz_tables() {
1261        assert!(validate_system_catalog_query("SELECT * FROM mz_tables").is_ok());
1262        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_comments").is_ok());
1263        assert!(
1264            validate_system_catalog_query(
1265                "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id"
1266            )
1267            .is_ok()
1268        );
1269    }
1270
1271    #[mz_ore::test]
1272    fn test_validate_system_catalog_query_subqueries() {
1273        // Subquery with mz_* tables
1274        assert!(
1275            validate_system_catalog_query(
1276                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1277            )
1278            .is_ok()
1279        );
1280
1281        // Nested subqueries with mz_* tables
1282        assert!(validate_system_catalog_query(
1283            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))"
1284        )
1285        .is_ok());
1286
1287        // Subquery in FROM clause
1288        assert!(
1289            validate_system_catalog_query(
1290                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1291            )
1292            .is_ok()
1293        );
1294
1295        // Reject subqueries that reference non-mz_* tables
1296        assert!(
1297            validate_system_catalog_query(
1298                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)"
1299            )
1300            .is_err()
1301        );
1302
1303        // Reject mixed references in nested subqueries
1304        assert!(validate_system_catalog_query(
1305            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)"
1306        )
1307        .is_err());
1308    }
1309
1310    #[mz_ore::test]
1311    fn test_validate_system_catalog_query_rejects_user_tables() {
1312        assert!(validate_system_catalog_query("SELECT * FROM user_data").is_err());
1313        assert!(validate_system_catalog_query("SELECT * FROM my_table").is_err());
1314        // Security: reject queries that mention mz_ in a non-table context
1315        assert!(
1316            validate_system_catalog_query("SELECT * FROM user_data WHERE 'mz_' IS NOT NULL")
1317                .is_err()
1318        );
1319    }
1320
1321    #[mz_ore::test]
1322    fn test_validate_system_catalog_query_allows_functions() {
1323        // Function names should not be treated as table references
1324        assert!(
1325            validate_system_catalog_query(
1326                "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1"
1327            )
1328            .is_ok()
1329        );
1330        assert!(validate_system_catalog_query("SELECT length(name) FROM mz_tables").is_ok());
1331        assert!(
1332            validate_system_catalog_query(
1333                "SELECT count(*) FROM mz_sources WHERE now() > created_at"
1334            )
1335            .is_ok()
1336        );
1337    }
1338
1339    #[mz_ore::test]
1340    fn test_validate_system_catalog_query_schema_qualified() {
1341        // Qualified with allowed schemas should work
1342        assert!(validate_system_catalog_query("SELECT * FROM mz_catalog.mz_tables").is_ok());
1343        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_sessions").is_ok());
1344        assert!(validate_system_catalog_query("SELECT * FROM pg_catalog.pg_type").is_ok());
1345        assert!(validate_system_catalog_query("SELECT * FROM information_schema.tables").is_ok());
1346
1347        // Qualified with disallowed schema should fail
1348        assert!(validate_system_catalog_query("SELECT * FROM public.user_table").is_err());
1349        assert!(validate_system_catalog_query("SELECT * FROM myschema.mytable").is_err());
1350
1351        // mz_unsafe is a system schema but explicitly blocked for MCP
1352        assert!(
1353            validate_system_catalog_query("SELECT * FROM mz_unsafe.mz_some_table").is_err(),
1354            "mz_unsafe schema should be blocked even though it is a system schema"
1355        );
1356
1357        // Mixed: system and user schemas should fail
1358        assert!(
1359            validate_system_catalog_query(
1360                "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"
1361            )
1362            .is_err()
1363        );
1364    }
1365
1366    #[mz_ore::test]
1367    fn test_validate_system_catalog_query_adversarial_cases() {
1368        // Try to sneak in user table via CTE
1369        assert!(
1370            validate_system_catalog_query(
1371                "WITH user_cte AS (SELECT * FROM user_data) \
1372                 SELECT * FROM mz_tables, user_cte"
1373            )
1374            .is_err(),
1375            "Should reject CTE referencing user table"
1376        );
1377
1378        // Complex multi-level CTE with user table buried deep
1379        assert!(
1380            validate_system_catalog_query(
1381                "WITH \
1382                   cte1 AS (SELECT * FROM mz_tables), \
1383                   cte2 AS (SELECT * FROM cte1), \
1384                   cte3 AS (SELECT * FROM user_data) \
1385                 SELECT * FROM cte2"
1386            )
1387            .is_err(),
1388            "Should reject CTE chain with user table"
1389        );
1390
1391        // Multiple joins - user table in the middle
1392        assert!(
1393            validate_system_catalog_query(
1394                "SELECT * FROM mz_tables t1 \
1395                 JOIN user_data u ON t1.id = u.id \
1396                 JOIN mz_sources s ON t1.id = s.id"
1397            )
1398            .is_err(),
1399            "Should reject multi-join with user table"
1400        );
1401
1402        // LEFT JOIN trying to hide user table
1403        assert!(
1404            validate_system_catalog_query(
1405                "SELECT * FROM mz_tables t \
1406                 LEFT JOIN user_data u ON t.id = u.table_id \
1407                 WHERE u.id IS NULL"
1408            )
1409            .is_err(),
1410            "Should reject LEFT JOIN with user table"
1411        );
1412
1413        // Nested subquery with user table in FROM
1414        assert!(
1415            validate_system_catalog_query(
1416                "SELECT * FROM mz_tables WHERE id IN \
1417                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)"
1418            )
1419            .is_err(),
1420            "Should reject nested subquery with user table"
1421        );
1422
1423        // UNION trying to mix system and user data
1424        assert!(
1425            validate_system_catalog_query(
1426                "SELECT name FROM mz_tables \
1427                 UNION \
1428                 SELECT name FROM user_data"
1429            )
1430            .is_err(),
1431            "Should reject UNION with user table"
1432        );
1433
1434        // UNION ALL variation
1435        assert!(
1436            validate_system_catalog_query(
1437                "SELECT id FROM mz_sources \
1438                 UNION ALL \
1439                 SELECT id FROM products"
1440            )
1441            .is_err(),
1442            "Should reject UNION ALL with user table"
1443        );
1444
1445        // Cross join with user table
1446        assert!(
1447            validate_system_catalog_query("SELECT * FROM mz_tables CROSS JOIN user_data").is_err(),
1448            "Should reject CROSS JOIN with user table"
1449        );
1450
1451        // Subquery in SELECT clause referencing user table
1452        assert!(
1453            validate_system_catalog_query(
1454                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t"
1455            )
1456            .is_err(),
1457            "Should reject subquery in SELECT with user table"
1458        );
1459
1460        // Try to use a schema name that looks similar to allowed ones
1461        assert!(
1462            validate_system_catalog_query("SELECT * FROM mz_catalogg.fake_table").is_err(),
1463            "Should reject typo-squatting schema name"
1464        );
1465        assert!(
1466            validate_system_catalog_query("SELECT * FROM mz_catalog_hack.fake_table").is_err(),
1467            "Should reject fake schema with mz_catalog prefix"
1468        );
1469
1470        // Lateral join with user table
1471        assert!(
1472            validate_system_catalog_query(
1473                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u"
1474            )
1475            .is_err(),
1476            "Should reject LATERAL join with user table"
1477        );
1478
1479        // Valid complex query - all system tables
1480        assert!(
1481            validate_system_catalog_query(
1482                "WITH \
1483                   tables AS (SELECT * FROM mz_tables), \
1484                   sources AS (SELECT * FROM mz_sources) \
1485                 SELECT t.name, s.name \
1486                 FROM tables t \
1487                 JOIN sources s ON t.id = s.id \
1488                 WHERE t.id IN (SELECT id FROM mz_columns)"
1489            )
1490            .is_ok(),
1491            "Should allow complex query with only system tables"
1492        );
1493
1494        // Valid UNION of system tables
1495        assert!(
1496            validate_system_catalog_query(
1497                "SELECT name FROM mz_tables \
1498                 UNION \
1499                 SELECT name FROM mz_sources"
1500            )
1501            .is_ok(),
1502            "Should allow UNION of system tables"
1503        );
1504    }
1505
1506    #[mz_ore::test]
1507    fn test_validate_system_catalog_query_rejects_constant_queries() {
1508        // SELECT without any table reference should be rejected — the developer
1509        // endpoint is for system catalog queries, not arbitrary computation.
1510        assert!(
1511            validate_system_catalog_query("SELECT 1").is_err(),
1512            "Should reject constant SELECT with no table references"
1513        );
1514        assert!(
1515            validate_system_catalog_query("SELECT 1 + 2, 'hello'").is_err(),
1516            "Should reject constant expression SELECT"
1517        );
1518        assert!(
1519            validate_system_catalog_query("SELECT now()").is_err(),
1520            "Should reject function-only SELECT with no table references"
1521        );
1522    }
1523
1524    #[mz_ore::test]
1525    fn test_validate_system_catalog_query_rejects_mixed_tables() {
1526        assert!(
1527            validate_system_catalog_query(
1528                "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"
1529            )
1530            .is_err()
1531        );
1532    }
1533
1534    #[mz_ore::test]
1535    fn test_validate_system_catalog_query_allows_show() {
1536        // SHOW queries don't reference tables in the AST but are safe read-only ops
1537        assert!(
1538            validate_system_catalog_query("SHOW TABLES FROM mz_internal").is_ok(),
1539            "SHOW TABLES FROM mz_internal should be allowed"
1540        );
1541        assert!(
1542            validate_system_catalog_query("SHOW TABLES FROM mz_catalog").is_ok(),
1543            "SHOW TABLES FROM mz_catalog should be allowed"
1544        );
1545        assert!(
1546            validate_system_catalog_query("SHOW CLUSTERS").is_ok(),
1547            "SHOW CLUSTERS should be allowed"
1548        );
1549        assert!(
1550            validate_system_catalog_query("SHOW SOURCES").is_ok(),
1551            "SHOW SOURCES should be allowed"
1552        );
1553        assert!(
1554            validate_system_catalog_query("SHOW TABLES").is_ok(),
1555            "SHOW TABLES should be allowed"
1556        );
1557    }
1558
1559    #[mz_ore::test]
1560    fn test_validate_system_catalog_query_allows_explain() {
1561        assert!(
1562            validate_system_catalog_query("EXPLAIN SELECT * FROM mz_tables").is_ok(),
1563            "EXPLAIN of system table query should be allowed"
1564        );
1565        assert!(
1566            validate_system_catalog_query("EXPLAIN SELECT 1").is_ok(),
1567            "EXPLAIN SELECT 1 should be allowed"
1568        );
1569    }
1570
1571    // ── Query tool feature flag tests ──────────────────────────────────────
1572
1573    #[mz_ore::test(tokio::test)]
1574    async fn test_tools_list_agent_query_tool_disabled() {
1575        let result = handle_tools_list(McpEndpointType::Agent, false, 1_000_000)
1576            .await
1577            .unwrap();
1578        let McpResult::ToolsList(list) = result else {
1579            panic!("Expected ToolsList result");
1580        };
1581        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1582        assert!(
1583            tool_names.contains(&"get_data_products"),
1584            "get_data_products should always be present"
1585        );
1586        assert!(
1587            tool_names.contains(&"get_data_product_details"),
1588            "get_data_product_details should always be present"
1589        );
1590        assert!(
1591            tool_names.contains(&"read_data_product"),
1592            "read_data_product should always be present"
1593        );
1594        assert!(
1595            !tool_names.contains(&"query"),
1596            "query tool should be hidden when disabled"
1597        );
1598    }
1599
1600    #[mz_ore::test(tokio::test)]
1601    async fn test_tools_list_agent_query_tool_enabled() {
1602        let result = handle_tools_list(McpEndpointType::Agent, true, 1_000_000)
1603            .await
1604            .unwrap();
1605        let McpResult::ToolsList(list) = result else {
1606            panic!("Expected ToolsList result");
1607        };
1608        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1609        assert!(
1610            tool_names.contains(&"get_data_products"),
1611            "get_data_products should always be present"
1612        );
1613        assert!(
1614            tool_names.contains(&"get_data_product_details"),
1615            "get_data_product_details should always be present"
1616        );
1617        assert!(
1618            tool_names.contains(&"read_data_product"),
1619            "read_data_product should always be present"
1620        );
1621        assert!(
1622            tool_names.contains(&"query"),
1623            "query tool should be present when enabled"
1624        );
1625    }
1626
1627    #[mz_ore::test(tokio::test)]
1628    async fn test_tools_list_developer_unaffected_by_query_flag() {
1629        // Developer endpoint should not be affected by the query tool flag
1630        for flag in [true, false] {
1631            let result = handle_tools_list(McpEndpointType::Developer, flag, 1_000_000)
1632                .await
1633                .unwrap();
1634            let McpResult::ToolsList(list) = result else {
1635                panic!("Expected ToolsList result");
1636            };
1637            let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1638            assert!(
1639                tool_names.contains(&"query_system_catalog"),
1640                "query_system_catalog should always be present on developer"
1641            );
1642            assert!(
1643                !tool_names.contains(&"query"),
1644                "query tool should never appear on developer"
1645            );
1646        }
1647    }
1648
1649    // ── Response size cap tests ────────────────────────────────────────
1650
1651    #[mz_ore::test]
1652    fn test_format_rows_response_within_limit() {
1653        let rows = vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
1654        let result = format_rows_response(rows, 1_000_000).unwrap();
1655        let McpResult::ToolContent(content) = result else {
1656            panic!("Expected ToolContent");
1657        };
1658        assert_eq!(content.content.len(), 1);
1659        assert!(content.content[0].text.contains("\"a\""));
1660        assert!(content.content[0].text.contains("\"b\""));
1661    }
1662
1663    #[mz_ore::test]
1664    fn test_format_rows_response_errors_when_over_limit() {
1665        let rows: Vec<Vec<serde_json::Value>> = (0..100)
1666            .map(|i| vec![json!(format!("row_{}", i)), json!(i)])
1667            .collect();
1668        let err = format_rows_response(rows, 500).unwrap_err();
1669        let msg = err.to_string();
1670        assert!(
1671            msg.contains("exceeds the 500 byte limit"),
1672            "Error should mention the size limit, got: {msg}"
1673        );
1674        assert!(
1675            msg.contains("Use LIMIT or WHERE"),
1676            "Error should suggest narrowing the query, got: {msg}"
1677        );
1678    }
1679
1680    #[mz_ore::test]
1681    fn test_format_rows_response_empty_rows() {
1682        let rows: Vec<Vec<serde_json::Value>> = vec![];
1683        let result = format_rows_response(rows, 1000).unwrap();
1684        let McpResult::ToolContent(content) = result else {
1685            panic!("Expected ToolContent");
1686        };
1687        assert_eq!(content.content.len(), 1);
1688        assert_eq!(content.content[0].text, "[]");
1689    }
1690
1691    // ── Data product name validation tests ─────────────────────────────
1692
1693    #[mz_ore::test]
1694    fn test_safe_data_product_name_valid() {
1695        // Fully qualified quoted identifiers
1696        assert_eq!(
1697            safe_data_product_name(r#""materialize"."public"."my_view""#).unwrap(),
1698            r#""materialize"."public"."my_view""#
1699        );
1700        // Two-part name
1701        assert_eq!(
1702            safe_data_product_name(r#""public"."my_view""#).unwrap(),
1703            r#""public"."my_view""#
1704        );
1705        // Unquoted name gets quoted in stable mode
1706        assert_eq!(safe_data_product_name("my_view").unwrap(), r#""my_view""#);
1707    }
1708
1709    #[mz_ore::test]
1710    fn test_safe_data_product_name_rejects_empty() {
1711        assert!(safe_data_product_name("").is_err());
1712        assert!(safe_data_product_name("   ").is_err());
1713    }
1714
1715    #[mz_ore::test]
1716    fn test_safe_data_product_name_rejects_sql_injection() {
1717        // Attempted injection via semicolon
1718        assert!(safe_data_product_name("my_view; DROP TABLE users").is_err());
1719        // Attempted injection via subquery
1720        assert!(safe_data_product_name("my_view UNION SELECT * FROM secrets").is_err());
1721        // Multiple table references via comma
1722        assert!(safe_data_product_name("my_view, secrets").is_err());
1723        // SQL keywords after name are rejected by the parser
1724        assert!(safe_data_product_name("my_view WHERE 1=1 --").is_err());
1725    }
1726
1727    #[mz_ore::test]
1728    fn test_mcp_error_codes() {
1729        assert_eq!(
1730            McpRequestError::InvalidJsonRpcVersion.error_code(),
1731            error_codes::INVALID_REQUEST
1732        );
1733        assert_eq!(
1734            McpRequestError::MethodNotFound("test".to_string()).error_code(),
1735            error_codes::METHOD_NOT_FOUND
1736        );
1737        assert_eq!(
1738            McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
1739            error_codes::INTERNAL_ERROR
1740        );
1741    }
1742}