Skip to main content

mz_environmentd/http/
mcp.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Model Context Protocol (MCP) HTTP handlers.
11//!
12//! Exposes Materialize data products to AI agents via JSON-RPC 2.0 over HTTP POST.
13//!
14//! ## Endpoints
15//!
16//! - `/api/mcp/agent` - User data products for customer AI agents
17//! - `/api/mcp/developer` - System catalog (`mz_*`) for troubleshooting
18//!
19//! ## Tools
20//!
21//! **Agent:** `get_data_products`, `get_data_product_details`, `read_data_product`, `query`
22//! **Developer:** `query_system_catalog`
23//!
24//! Data products are discovered via `mz_internal.mz_mcp_data_products` system view.
25
26use std::sync::Arc;
27use std::time::Duration;
28
29use anyhow::anyhow;
30use axum::Extension;
31use axum::Json;
32use axum::response::IntoResponse;
33use http::{HeaderMap, HeaderValue, StatusCode};
34use mz_adapter_types::dyncfgs::{
35    ENABLE_MCP_AGENT, ENABLE_MCP_AGENT_QUERY_TOOL, ENABLE_MCP_DEVELOPER, MCP_MAX_RESPONSE_SIZE,
36};
37use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
38use mz_sql::parse::parse;
39use mz_sql::session::metadata::SessionMetadata;
40use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
41use mz_sql_parser::ast::visit::{self, Visit};
42use mz_sql_parser::ast::{Raw, RawItemName};
43use mz_sql_parser::parser::parse_item_name;
44use serde::{Deserialize, Serialize};
45use serde_json::json;
46use thiserror::Error;
47use tracing::{debug, warn};
48
49use crate::http::AuthedClient;
50use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
51
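// To add a new tool: add a `ToolsCallParams` variant (plus its `Display` arm), add its
// definition to `handle_tools_list`, write a handler function, and add a dispatch case in
// `handle_tools_call`.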
53
/// JSON-RPC protocol version used in all MCP requests and responses.
///
/// Requests whose `jsonrpc` field differs are rejected with
/// `McpRequestError::InvalidJsonRpcVersion`.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol version returned in the `initialize` response.
/// Spec: <https://modelcontextprotocol.io/specification/2025-11-25>
///
/// The client's requested `protocolVersion` is not negotiated: this value is
/// always returned.
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Maximum time a single MCP request (including any tool call it performs) can
/// run before the HTTP response is returned.
/// Note: this returns a clean JSON-RPC error to the caller, but the underlying
/// query may continue running on the cluster until it completes or is cancelled
/// separately (see database-issues#9947 for SELECT timeout gaps).
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

// Discovery uses the lightweight view (no JSON schema computation).
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
// Details uses the full view with JSON schema. Callers append the product name
// as an escaped SQL string literal to complete the `WHERE` clause.
const DETAILS_QUERY_PREFIX: &str =
    "SELECT * FROM mz_internal.mz_mcp_data_product_details WHERE object_name = ";
72
73/// MCP request errors, mapped to JSON-RPC error codes.
74#[derive(Debug, Error)]
75enum McpRequestError {
76    #[error("Invalid JSON-RPC version: expected 2.0")]
77    InvalidJsonRpcVersion,
78    #[error("Method not found: {0}")]
79    #[allow(dead_code)] // Handled by serde deserialization, kept for error mapping
80    MethodNotFound(String),
81    #[error("Tool not found: {0}")]
82    ToolNotFound(String),
83    #[error("Data product not found: {0}")]
84    DataProductNotFound(String),
85    #[error("Query validation failed: {0}")]
86    QueryValidationFailed(String),
87    #[error("Query execution failed: {0}")]
88    QueryExecutionFailed(String),
89    #[error("Internal error: {0}")]
90    Internal(#[from] anyhow::Error),
91}
92
93impl McpRequestError {
94    fn error_code(&self) -> i32 {
95        match self {
96            Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
97            Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
98            Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
99            Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
100            Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
101            Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
102        }
103    }
104
105    fn error_type(&self) -> &'static str {
106        match self {
107            Self::InvalidJsonRpcVersion => "InvalidRequest",
108            Self::MethodNotFound(_) => "MethodNotFound",
109            Self::ToolNotFound(_) => "ToolNotFound",
110            Self::DataProductNotFound(_) => "DataProductNotFound",
111            Self::QueryValidationFailed(_) => "ValidationError",
112            Self::QueryExecutionFailed(_) => "ExecutionError",
113            Self::Internal(_) => "InternalError",
114        }
115    }
116}
117
/// JSON-RPC 2.0 request. Requests have `id`; notifications don't.
///
/// The JSON `method` and `params` members are flattened into [`McpMethod`].
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    /// Protocol version string; must equal `"2.0"` (checked at dispatch time).
    jsonrpc: String,
    /// Request identifier echoed back in the response. `None` marks a
    /// notification, which receives no JSON-RPC response body.
    id: Option<serde_json::Value>,
    #[serde(flatten)]
    method: McpMethod,
}

/// MCP method variants with their associated parameters.
///
/// Adjacently tagged: the JSON `method` string selects the variant and the
/// `params` member carries that variant's data.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// Initialize method - params accepted but not currently used
    #[serde(rename = "initialize")]
    Initialize(#[allow(dead_code)] InitializeParams),
    /// `tools/list`: enumerate the tools offered by the receiving endpoint.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// `tools/call`: invoke a named tool with its arguments.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Catch-all for unknown methods (e.g. `notifications/initialized`).
    /// The original method name is not retained.
    #[serde(other)]
    Unknown,
}
142
143impl std::fmt::Display for McpMethod {
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        match self {
146            McpMethod::Initialize(_) => write!(f, "initialize"),
147            McpMethod::ToolsList => write!(f, "tools/list"),
148            McpMethod::ToolsCall(_) => write!(f, "tools/call"),
149            McpMethod::Unknown => write!(f, "unknown"),
150        }
151    }
152}
153
/// Parameters of the `initialize` request.
///
/// None of the fields are read today. `protocol_version` has no serde default,
/// so an `initialize` request that omits `protocolVersion` fails to deserialize.
#[derive(Debug, Deserialize)]
struct InitializeParams {
    /// Protocol version from client. Not currently validated but accepted for MCP compliance.
    #[serde(rename = "protocolVersion")]
    #[allow(dead_code)]
    protocol_version: String,
    /// Client capabilities. Not currently used but accepted for MCP compliance.
    #[serde(default)]
    #[allow(dead_code)]
    capabilities: serde_json::Value,
    /// Client information (name, version). Not currently used but accepted for MCP compliance.
    #[serde(rename = "clientInfo")]
    #[allow(dead_code)]
    client_info: Option<ClientInfo>,
}

/// The `clientInfo` object sent with `initialize`. When present, both `name`
/// and `version` are required.
#[derive(Debug, Deserialize)]
struct ClientInfo {
    // Deserialized for shape validation only; never read.
    #[allow(dead_code)]
    name: String,
    // Deserialized for shape validation only; never read.
    #[allow(dead_code)]
    version: String,
}
177
/// Tool call parameters, deserialized via adjacently tagged enum.
/// Serde maps `name` to the variant and `arguments` to the variant's data.
/// Variant names are converted to snake_case, so e.g. `ReadDataProduct`
/// matches the tool name `read_data_product`.
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    // Agent endpoint tools
    // Wraps an ignored unit value `()` (not a struct) marked `#[serde(default)]`,
    // so MCP clients sending `"arguments": {}` can deserialize.
    GetDataProducts(#[serde(default)] ()),
    /// `get_data_product_details`: schema and metadata for one data product.
    GetDataProductDetails(GetDataProductDetailsParams),
    /// `read_data_product`: read up to `limit` rows from one data product.
    ReadDataProduct(ReadDataProductParams),
    /// `query`: arbitrary SQL on a cluster (agent endpoint, gated by a feature flag).
    Query(QueryParams),
    // Developer endpoint tools
    QuerySystemCatalog(QuerySystemCatalogParams),
}
193
194impl std::fmt::Display for ToolsCallParams {
195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196        match self {
197            ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
198            ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
199            ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
200            ToolsCallParams::Query(_) => write!(f, "query"),
201            ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
202        }
203    }
204}
205
/// Arguments for the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    /// Data product name, as listed by `get_data_products`.
    name: String,
}

/// Arguments for the `read_data_product` tool.
#[derive(Debug, Deserialize)]
struct ReadDataProductParams {
    /// Fully-qualified data product name.
    name: String,
    /// Requested row count; `default_read_limit()` when omitted.
    #[serde(default = "default_read_limit")]
    limit: u32,
    /// Optional cluster override; `None` when the caller omits it.
    cluster: Option<String>,
}
218
/// Serde default for `ReadDataProductParams::limit` when the caller omits `limit`.
fn default_read_limit() -> u32 {
    const DEFAULT_READ_LIMIT: u32 = 500;
    DEFAULT_READ_LIMIT
}

/// Maximum number of rows that can be returned by read_data_product.
const MAX_READ_LIMIT: u32 = 1_000;
225
/// Arguments for the agent-endpoint `query` tool.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Cluster the query runs on.
    cluster: String,
    /// SQL text supplied by the agent.
    sql_query: String,
}

/// Arguments for the developer-endpoint `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    /// SQL text supplied by the caller.
    sql_query: String,
}
236
/// JSON-RPC 2.0 response envelope.
///
/// The handlers set exactly one of `result` or `error`. The unset one is
/// omitted from the serialized JSON.
#[derive(Debug, Serialize)]
struct McpResponse {
    /// Always `JSONRPC_VERSION`.
    jsonrpc: String,
    /// Echo of the request's `id`.
    id: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}

/// Typed MCP response results.
///
/// `#[serde(untagged)]`: each variant serializes as its inner struct, with no
/// wrapper or discriminator.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    Initialize(InitializeResult),
    ToolsList(ToolsListResult),
    ToolContent(ToolContentResult),
}
255
/// Result payload of the `initialize` method.
#[derive(Debug, Serialize)]
struct InitializeResult {
    #[serde(rename = "protocolVersion")]
    protocol_version: String,
    capabilities: Capabilities,
    #[serde(rename = "serverInfo")]
    server_info: ServerInfo,
    /// Optional usage instructions for the client; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<String>,
}

/// Server capabilities advertised during `initialize`.
#[derive(Debug, Serialize)]
struct Capabilities {
    /// Presence of this key advertises tool support; its contents are free-form JSON.
    tools: serde_json::Value,
}

/// Server name and version advertised during `initialize`.
#[derive(Debug, Serialize)]
struct ServerInfo {
    name: String,
    version: String,
}

/// Result payload of the `tools/list` method.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    tools: Vec<ToolDefinition>,
}

/// One tool entry in a `tools/list` result.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    /// Tool identifier the client passes back as `name` in `tools/call`.
    name: String,
    /// Optional human-readable display name.
    #[serde(skip_serializing_if = "Option::is_none")]
    title: Option<String>,
    /// Natural-language description shown to the model.
    description: String,
    /// JSON Schema describing the tool's `arguments` object.
    #[serde(rename = "inputSchema")]
    input_schema: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    annotations: Option<ToolAnnotations>,
}
294
/// MCP 2025-11-25 tool annotations that describe tool behavior.
/// These hints help clients make trust and safety decisions.
/// Each hint is omitted from the JSON when `None`.
#[derive(Debug, Serialize)]
struct ToolAnnotations {
    #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
    read_only_hint: Option<bool>,
    #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
    destructive_hint: Option<bool>,
    #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
    idempotent_hint: Option<bool>,
    #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
    open_world_hint: Option<bool>,
}

/// Annotations for all MCP tools: read-only, non-destructive, idempotent, and
/// closed-world (`openWorldHint: false`).
const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
    read_only_hint: Some(true),
    destructive_hint: Some(false),
    idempotent_hint: Some(true),
    open_world_hint: Some(false),
};
316
/// Result payload of a `tools/call` invocation.
#[derive(Debug, Serialize)]
struct ToolContentResult {
    content: Vec<ContentBlock>,
    /// Tool-level error flag. Failures in this module are reported as JSON-RPC
    /// errors instead, so visible call sites set this to `false`.
    #[serde(rename = "isError")]
    is_error: bool,
}

/// A single content item in a tool result, serialized as `{"type": ..., "text": ...}`.
#[derive(Debug, Serialize)]
struct ContentBlock {
    #[serde(rename = "type")]
    content_type: String,
    text: String,
}

/// JSON-RPC 2.0 error codes.
///
/// These are reserved codes defined by the JSON-RPC 2.0 specification.
mod error_codes {
    /// The request is not a valid JSON-RPC request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// Invalid method parameters.
    pub const INVALID_PARAMS: i32 = -32602;
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: i32 = -32603;
}

/// JSON-RPC 2.0 error object.
#[derive(Debug, Serialize)]
struct McpError {
    code: i32,
    message: String,
    /// Extra structured detail; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
346
347impl From<McpRequestError> for McpError {
348    fn from(err: McpRequestError) -> Self {
349        McpError {
350            code: err.error_code(),
351            message: err.to_string(),
352            data: Some(json!({
353                "error_type": err.error_type(),
354            })),
355        }
356    }
357}
358
/// Which MCP endpoint received a request. This selects the feature flag that
/// gates the endpoint and the set of tools it offers.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// `/api/mcp/agent`: user data products.
    Agent,
    /// `/api/mcp/developer`: system catalog.
    Developer,
}

/// Lower-case endpoint name, used in logs, error messages, and `serverInfo.name`.
impl std::fmt::Display for McpEndpointType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            McpEndpointType::Agent => "agent",
            McpEndpointType::Developer => "developer",
        })
    }
}
373
374/// MCP 2025-11-25 requires servers to return 405 for GET requests
375/// on endpoints that only support POST.
376pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
377    StatusCode::METHOD_NOT_ALLOWED
378}
379
380/// Agent endpoint: exposes user data products.
381pub async fn handle_mcp_agent(
382    headers: HeaderMap,
383    Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
384    client: AuthedClient,
385    Json(body): Json<McpRequest>,
386) -> axum::response::Response {
387    if let Some(resp) = validate_origin(&headers, &allowed_origins) {
388        return resp;
389    }
390    handle_mcp_request(client, body, McpEndpointType::Agent)
391        .await
392        .into_response()
393}
394
395/// Developer endpoint: exposes system catalog (mz_*) only.
396pub async fn handle_mcp_developer(
397    headers: HeaderMap,
398    Extension(allowed_origins): Extension<Arc<Vec<HeaderValue>>>,
399    client: AuthedClient,
400    Json(body): Json<McpRequest>,
401) -> axum::response::Response {
402    if let Some(resp) = validate_origin(&headers, &allowed_origins) {
403        return resp;
404    }
405    handle_mcp_request(client, body, McpEndpointType::Developer)
406        .await
407        .into_response()
408}
409
410/// Validates the Origin header against the CORS allowlist to prevent DNS
411/// rebinding attacks (MCP spec 2025-11-25). Returns Some(403) if Origin is
412/// present but not on the allowlist. Returns None if absent (non-browser
413/// client) or allowed.
414///
415/// Note: this server-side check is required in addition to the CorsLayer.
416/// CorsLayer only controls response headers and can be bypassed when the
417/// attacker arranges same-origin DNS rebinding (no preflight fires).
418fn validate_origin(
419    headers: &HeaderMap,
420    allowed: &[HeaderValue],
421) -> Option<axum::response::Response> {
422    let origin = headers.get(http::header::ORIGIN)?;
423    if mz_http_util::origin_is_allowed(origin, allowed) {
424        return None;
425    }
426    warn!(
427        origin = ?origin,
428        "MCP request rejected: origin not in allowlist",
429    );
430    Some(StatusCode::FORBIDDEN.into_response())
431}
432
433async fn handle_mcp_request(
434    mut client: AuthedClient,
435    request: McpRequest,
436    endpoint_type: McpEndpointType,
437) -> impl IntoResponse {
438    // Check the per-endpoint feature flag via a catalog snapshot, similar to frontend_peek.rs.
439    let catalog = client.client.catalog_snapshot("mcp").await;
440    let dyncfgs = catalog.system_config().dyncfgs();
441    let enabled = match endpoint_type {
442        McpEndpointType::Agent => ENABLE_MCP_AGENT.get(dyncfgs),
443        McpEndpointType::Developer => ENABLE_MCP_DEVELOPER.get(dyncfgs),
444    };
445    if !enabled {
446        debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
447        return StatusCode::SERVICE_UNAVAILABLE.into_response();
448    }
449
450    let query_tool_enabled = ENABLE_MCP_AGENT_QUERY_TOOL.get(dyncfgs);
451    let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
452
453    let user = client.client.session().user().name.clone();
454    let is_notification = request.id.is_none();
455
456    debug!(
457        method = %request.method,
458        endpoint = %endpoint_type,
459        user = %user,
460        is_notification = is_notification,
461        "MCP request received"
462    );
463
464    // Handle notifications (no response needed)
465    if is_notification {
466        debug!(method = %request.method, "Received notification (no response will be sent)");
467        return StatusCode::OK.into_response();
468    }
469
470    let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
471
472    // Spawn task for fault isolation, with a timeout safety net.
473    let result = tokio::time::timeout(
474        MCP_REQUEST_TIMEOUT,
475        mz_ore::task::spawn(|| "mcp_request", async move {
476            handle_mcp_request_inner(
477                &mut client,
478                request,
479                endpoint_type,
480                query_tool_enabled,
481                max_response_size,
482            )
483            .await
484        }),
485    )
486    .await;
487
488    let response = match result {
489        Ok(inner) => inner,
490        Err(_elapsed) => {
491            warn!(
492                endpoint = %endpoint_type,
493                timeout = ?MCP_REQUEST_TIMEOUT,
494                "MCP request timed out",
495            );
496            McpResponse {
497                jsonrpc: JSONRPC_VERSION.to_string(),
498                id: request_id,
499                result: None,
500                error: Some(
501                    McpRequestError::QueryExecutionFailed(format!(
502                        "Request timed out after {} seconds.",
503                        MCP_REQUEST_TIMEOUT.as_secs(),
504                    ))
505                    .into(),
506                ),
507            }
508        }
509    };
510
511    (StatusCode::OK, Json(response)).into_response()
512}
513
514async fn handle_mcp_request_inner(
515    client: &mut AuthedClient,
516    request: McpRequest,
517    endpoint_type: McpEndpointType,
518    query_tool_enabled: bool,
519    max_response_size: usize,
520) -> McpResponse {
521    // Extract request ID (guaranteed to be Some since notifications are filtered earlier)
522    let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
523
524    let result = handle_mcp_method(
525        client,
526        &request,
527        endpoint_type,
528        query_tool_enabled,
529        max_response_size,
530    )
531    .await;
532
533    match result {
534        Ok(result_value) => McpResponse {
535            jsonrpc: JSONRPC_VERSION.to_string(),
536            id: request_id,
537            result: Some(result_value),
538            error: None,
539        },
540        Err(e) => {
541            // Log non-trivial errors
542            if !matches!(
543                e,
544                McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
545            ) {
546                warn!(error = %e, method = %request.method, "MCP method execution failed");
547            }
548            McpResponse {
549                jsonrpc: JSONRPC_VERSION.to_string(),
550                id: request_id,
551                result: None,
552                error: Some(e.into()),
553            }
554        }
555    }
556}
557
558async fn handle_mcp_method(
559    client: &mut AuthedClient,
560    request: &McpRequest,
561    endpoint_type: McpEndpointType,
562    query_tool_enabled: bool,
563    max_response_size: usize,
564) -> Result<McpResult, McpRequestError> {
565    // Validate JSON-RPC version
566    if request.jsonrpc != JSONRPC_VERSION {
567        return Err(McpRequestError::InvalidJsonRpcVersion);
568    }
569
570    // Handle different MCP methods using pattern matching
571    match &request.method {
572        McpMethod::Initialize(_) => {
573            debug!(endpoint = %endpoint_type, "Processing initialize");
574            handle_initialize(endpoint_type).await
575        }
576        McpMethod::ToolsList => {
577            debug!(endpoint = %endpoint_type, "Processing tools/list");
578            handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
579        }
580        McpMethod::ToolsCall(params) => {
581            debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
582            handle_tools_call(
583                client,
584                params,
585                endpoint_type,
586                query_tool_enabled,
587                max_response_size,
588            )
589            .await
590        }
591        McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
592            "unknown method".to_string(),
593        )),
594    }
595}
596
597async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
598    Ok(McpResult::Initialize(InitializeResult {
599        protocol_version: MCP_PROTOCOL_VERSION.to_string(),
600        capabilities: Capabilities { tools: json!({}) },
601        server_info: ServerInfo {
602            name: format!("materialize-mcp-{}", endpoint_type),
603            version: env!("CARGO_PKG_VERSION").to_string(),
604        },
605        instructions: None,
606    }))
607}
608
609async fn handle_tools_list(
610    endpoint_type: McpEndpointType,
611    query_tool_enabled: bool,
612    max_response_size: usize,
613) -> Result<McpResult, McpRequestError> {
614    let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
615
616    let tools = match endpoint_type {
617        McpEndpointType::Agent => {
618            let mut tools = vec![
619                ToolDefinition {
620                    name: "get_data_products".to_string(),
621                    title: Some("List Data Products".to_string()),
622                    description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
623                    input_schema: json!({
624                        "type": "object",
625                        "properties": {},
626                        "required": []
627                    }),
628                    annotations: Some(READ_ONLY_ANNOTATIONS),
629                },
630                ToolDefinition {
631                    name: "get_data_product_details".to_string(),
632                    title: Some("Get Data Product Details".to_string()),
633                    description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
634                    input_schema: json!({
635                        "type": "object",
636                        "properties": {
637                            "name": {
638                                "type": "string",
639                                "description": "Exact name of the data product from get_data_products() list"
640                            }
641                        },
642                        "required": ["name"]
643                    }),
644                    annotations: Some(READ_ONLY_ANNOTATIONS),
645                },
646                ToolDefinition {
647                    name: "read_data_product".to_string(),
648                    title: Some("Read Data Product".to_string()),
649                    description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500, max 1000). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
650                    input_schema: json!({
651                        "type": "object",
652                        "properties": {
653                            "name": {
654                                "type": "string",
655                                "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
656                            },
657                            "limit": {
658                                "type": "integer",
659                                "description": "Maximum number of rows to return (default 500, max 1000)",
660                                "default": 500
661                            },
662                            "cluster": {
663                                "type": "string",
664                                "description": "Optional cluster override. If omitted, uses the cluster from the data product catalog."
665                            }
666                        },
667                        "required": ["name"]
668                    }),
669                    annotations: Some(READ_ONLY_ANNOTATIONS),
670                },
671            ];
672            if query_tool_enabled {
673                tools.push(ToolDefinition {
674                    name: "query".to_string(),
675                    title: Some("Query Data Products".to_string()),
676                    description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
677                    input_schema: json!({
678                        "type": "object",
679                        "properties": {
680                            "cluster": {
681                                "type": "string",
682                                "description": "Exact cluster name from the data product details - required for query execution"
683                            },
684                            "sql_query": {
685                                "type": "string",
686                                "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
687                            }
688                        },
689                        "required": ["cluster", "sql_query"]
690                    }),
691                    annotations: Some(READ_ONLY_ANNOTATIONS),
692                });
693            }
694            tools
695        }
696        McpEndpointType::Developer => {
697            vec![ToolDefinition {
698                name: "query_system_catalog".to_string(),
699                title: Some("Query System Catalog".to_string()),
700                description: concat!(
701                    "Query Materialize system catalog tables for troubleshooting and observability. ",
702                    "Only mz_*, pg_catalog, and information_schema tables are accessible.\n\n",
703                    "Key tables by scenario:\n",
704                    "- Freshness: mz_internal.mz_wallclock_global_lag_recent_history, mz_internal.mz_materialization_lag, mz_internal.mz_hydration_statuses\n",
705                    "- Memory: mz_internal.mz_cluster_replica_utilization, mz_internal.mz_cluster_replica_metrics, mz_internal.mz_dataflow_arrangement_sizes\n",
706                    "- Cluster health: mz_internal.mz_cluster_replica_statuses, mz_catalog.mz_cluster_replicas\n",
707                    "- Source/Sink health: mz_internal.mz_source_statuses, mz_internal.mz_sink_statuses, mz_internal.mz_source_statistics, mz_internal.mz_sink_statistics\n",
708                    "- Object catalog: mz_catalog.mz_objects (all objects), mz_catalog.mz_tables, mz_catalog.mz_materialized_views, mz_catalog.mz_sources, mz_catalog.mz_sinks\n\n",
709                    "Use SHOW TABLES FROM mz_internal or mz_catalog to discover more tables.",
710                ).to_owned() + &format!(" {size_hint}"),
711                input_schema: json!({
712                    "type": "object",
713                    "properties": {
714                        "sql_query": {
715                            "type": "string",
716                            "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
717                        }
718                    },
719                    "required": ["sql_query"]
720                }),
721                annotations: Some(READ_ONLY_ANNOTATIONS),
722            }]
723        }
724    };
725
726    Ok(McpResult::ToolsList(ToolsListResult { tools }))
727}
728
729async fn handle_tools_call(
730    client: &mut AuthedClient,
731    params: &ToolsCallParams,
732    endpoint_type: McpEndpointType,
733    query_tool_enabled: bool,
734    max_response_size: usize,
735) -> Result<McpResult, McpRequestError> {
736    match (endpoint_type, params) {
737        (McpEndpointType::Agent, ToolsCallParams::GetDataProducts(_)) => {
738            get_data_products(client, max_response_size).await
739        }
740        (McpEndpointType::Agent, ToolsCallParams::GetDataProductDetails(p)) => {
741            get_data_product_details(client, &p.name, max_response_size).await
742        }
743        (McpEndpointType::Agent, ToolsCallParams::ReadDataProduct(p)) => {
744            read_data_product(
745                client,
746                &p.name,
747                p.limit,
748                p.cluster.as_deref(),
749                max_response_size,
750            )
751            .await
752        }
753        (McpEndpointType::Agent, ToolsCallParams::Query(_)) if !query_tool_enabled => {
754            Err(McpRequestError::ToolNotFound(
755                "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
756            ))
757        }
758        (McpEndpointType::Agent, ToolsCallParams::Query(p)) => {
759            execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
760        }
761        (McpEndpointType::Developer, ToolsCallParams::QuerySystemCatalog(p)) => {
762            query_system_catalog(client, &p.sql_query, max_response_size).await
763        }
764        // Tool called on wrong endpoint
765        (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
766            "{} is not available on {} endpoint",
767            tool, endpoint
768        ))),
769    }
770}
771
772/// Execute SQL via `execute_request` from sql.rs.
773async fn execute_sql(
774    client: &mut AuthedClient,
775    query: &str,
776) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
777    let mut response = SqlResponse::new();
778
779    execute_request(
780        client,
781        SqlRequest::Simple {
782            query: query.to_string(),
783        },
784        &mut response,
785    )
786    .await
787    .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
788
789    // Extract the result with rows (the user's single SELECT/SHOW query)
790    // Other results will be OK (from BEGIN, SET, COMMIT) or Err
791    for result in response.results {
792        match result {
793            SqlResult::Rows { rows, .. } => return Ok(rows),
794            SqlResult::Err { error, .. } => {
795                return Err(McpRequestError::QueryExecutionFailed(error.message));
796            }
797            SqlResult::Ok { .. } => continue,
798        }
799    }
800
801    Err(McpRequestError::QueryExecutionFailed(
802        "Query did not return any results".to_string(),
803    ))
804}
805
806/// Serialize rows to JSON and enforce the response size cap.
807///
808/// If the serialized response exceeds `max_size` bytes, returns an error
809/// telling the agent to narrow its query. This mirrors how the HTTP SQL
810/// endpoint handles `max_result_size` in sql.rs — fail cleanly rather
811/// than silently truncating.
812fn format_rows_response(
813    rows: Vec<Vec<serde_json::Value>>,
814    max_size: usize,
815) -> Result<McpResult, McpRequestError> {
816    let text =
817        serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
818
819    if text.len() > max_size {
820        return Err(McpRequestError::QueryExecutionFailed(format!(
821            "Response size ({} bytes) exceeds the {} byte limit. \
822             Use LIMIT or WHERE to narrow your query.",
823            text.len(),
824            max_size,
825        )));
826    }
827
828    Ok(McpResult::ToolContent(ToolContentResult {
829        content: vec![ContentBlock {
830            content_type: "text".to_string(),
831            text,
832        }],
833        is_error: false,
834    }))
835}
836
837async fn get_data_products(
838    client: &mut AuthedClient,
839    max_response_size: usize,
840) -> Result<McpResult, McpRequestError> {
841    debug!("Executing get_data_products");
842    let rows = execute_sql(client, DISCOVERY_QUERY).await?;
843    debug!("get_data_products returned {} rows", rows.len());
844
845    format_rows_response(rows, max_response_size)
846}
847
848async fn get_data_product_details(
849    client: &mut AuthedClient,
850    name: &str,
851    max_response_size: usize,
852) -> Result<McpResult, McpRequestError> {
853    debug!(name = %name, "Executing get_data_product_details");
854
855    let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
856
857    let rows = execute_sql(client, &query).await?;
858
859    if rows.is_empty() {
860        return Err(McpRequestError::DataProductNotFound(name.to_string()));
861    }
862
863    format_rows_response(rows, max_response_size)
864}
865
866/// Parses a data product name and returns it safely quoted for SQL interpolation.
867///
868/// Uses the SQL parser to validate the name as an `UnresolvedItemName`, then
869/// formats it with `FormatMode::Stable` so every identifier part is
870/// double-quoted with proper escaping. This prevents SQL injection regardless
871/// of the input.
872fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
873    let name = name.trim();
874    if name.is_empty() {
875        return Err(McpRequestError::QueryValidationFailed(
876            "Data product name cannot be empty".to_string(),
877        ));
878    }
879
880    let parsed = parse_item_name(name).map_err(|_| {
881        McpRequestError::QueryValidationFailed(format!(
882            "Invalid data product name: {}. Expected a valid object name, \
883             e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
884            name
885        ))
886    })?;
887
888    // Stable formatting forces all identifiers to be double-quoted,
889    // so SQL keywords and special characters cannot escape.
890    Ok(parsed.to_ast_string_stable())
891}
892
893/// Read rows from a data product. Issues a single read-only query.
894///
895/// When `cluster_override` is provided, sets the cluster explicitly.
896/// Otherwise the query runs on the session's default cluster.
897///
898/// The name is expected to come from `get_data_products()` / `get_data_product_details()`.
899/// The query runs inside a READ ONLY transaction, preventing mutations.
900async fn read_data_product(
901    client: &mut AuthedClient,
902    name: &str,
903    limit: u32,
904    cluster_override: Option<&str>,
905    max_response_size: usize,
906) -> Result<McpResult, McpRequestError> {
907    debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
908
909    // Parse and safely quote the name for SQL interpolation.
910    let safe_name = safe_data_product_name(name)?;
911
912    // Lightweight existence check: verify the data product is visible in the
913    // catalog before running the read query. This gives a clean DataProductNotFound
914    // error for missing or inaccessible products (including RBAC revocations)
915    // without relying on fragile error code matching.
916    // TODO: Remove this extra round-trip once catalog errors get specific SQL
917    // error codes (see TODO in src/adapter/src/error.rs `fn code()`), then we
918    // can translate the query error directly.
919    let lookup_query = format!(
920        "SELECT 1 FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
921        escaped_string_literal(name)
922    );
923    let lookup_rows = execute_sql(client, &lookup_query).await?;
924    if lookup_rows.is_empty() {
925        return Err(McpRequestError::DataProductNotFound(name.to_string()));
926    }
927
928    let clamped_limit = limit.min(MAX_READ_LIMIT);
929
930    let read_query = match cluster_override {
931        Some(cluster) => format!(
932            "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}\n; COMMIT;",
933            escaped_string_literal(cluster),
934            safe_name,
935            clamped_limit,
936        ),
937        // Single statement — skip explicit transaction for better performance.
938        None => format!("SELECT * FROM {} LIMIT {}", safe_name, clamped_limit),
939    };
940
941    let rows = execute_sql(client, &read_query).await?;
942
943    format_rows_response(rows, max_response_size)
944}
945
946/// Validates query is a single SELECT, SHOW, or EXPLAIN statement.
947fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
948    let sql = sql.trim();
949    if sql.is_empty() {
950        return Err(McpRequestError::QueryValidationFailed(
951            "Empty query".to_string(),
952        ));
953    }
954
955    // Parse the SQL to get AST
956    let stmts = parse(sql).map_err(|e| {
957        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
958    })?;
959
960    // Only allow a single statement
961    if stmts.len() != 1 {
962        return Err(McpRequestError::QueryValidationFailed(format!(
963            "Only one query allowed at a time. Found {} statements.",
964            stmts.len()
965        )));
966    }
967
968    // Allowlist: Only SELECT, SHOW, and EXPLAIN statements permitted
969    let stmt = &stmts[0];
970    use mz_sql_parser::ast::Statement;
971
972    match &stmt.ast {
973        Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
974            // Allowed - read-only operations
975            Ok(())
976        }
977        _ => Err(McpRequestError::QueryValidationFailed(
978            "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
979        )),
980    }
981}
982
983async fn execute_query(
984    client: &mut AuthedClient,
985    cluster: &str,
986    sql_query: &str,
987    max_response_size: usize,
988) -> Result<McpResult, McpRequestError> {
989    debug!(cluster = %cluster, "Executing user query");
990
991    validate_readonly_query(sql_query)?;
992
993    // Use READ ONLY transaction to prevent modifications
994    // Combine with SET CLUSTER (prometheus.rs:29-33 pattern)
995    let combined_query = format!(
996        "BEGIN READ ONLY; SET CLUSTER = {}; {}\n; COMMIT;",
997        escaped_string_literal(cluster),
998        sql_query
999    );
1000
1001    let rows = execute_sql(client, &combined_query).await?;
1002
1003    format_rows_response(rows, max_response_size)
1004}
1005
1006async fn query_system_catalog(
1007    client: &mut AuthedClient,
1008    sql_query: &str,
1009    max_response_size: usize,
1010) -> Result<McpResult, McpRequestError> {
1011    debug!("Executing query_system_catalog");
1012
1013    // First validate it's a read-only query
1014    validate_readonly_query(sql_query)?;
1015
1016    // Then validate that query only references mz_* tables by parsing the SQL
1017    validate_system_catalog_query(sql_query)?;
1018
1019    // Wrap the query in a READ ONLY transaction with a tight search_path
1020    // restricted to system schemas. This prevents unqualified `mz_*` references
1021    // from resolving to user-created objects (e.g. a view `public.mz_leak`) via
1022    // the session's search_path (mirrors the `BEGIN READ ONLY; SET ...` pattern
1023    // used by the agent `query` tool).
1024    let combined_query = format!(
1025        "BEGIN READ ONLY; SET search_path = mz_catalog, mz_internal, pg_catalog, information_schema; {}; COMMIT;",
1026        sql_query
1027    );
1028
1029    let rows = execute_sql(client, &combined_query).await?;
1030
1031    format_rows_response(rows, max_response_size)
1032}
1033
1034/// Collects table references from SQL AST with their schema qualification.
1035struct TableReferenceCollector {
1036    /// Stores (schema, table_name) tuples. Schema is None if unqualified.
1037    tables: Vec<(Option<String>, String)>,
1038    /// CTE names to exclude from validation (they're not real tables)
1039    cte_names: std::collections::BTreeSet<String>,
1040}
1041
1042impl TableReferenceCollector {
1043    fn new() -> Self {
1044        Self {
1045            tables: Vec::new(),
1046            cte_names: std::collections::BTreeSet::new(),
1047        }
1048    }
1049}
1050
1051impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1052    fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1053        // Track CTE names so we don't treat them as table references
1054        self.cte_names
1055            .insert(cte.alias.name.as_str().to_lowercase());
1056        visit::visit_cte(self, cte);
1057    }
1058
1059    fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1060        // Only visit actual table references in FROM/JOIN clauses, not function names
1061        if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1062            match name {
1063                RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1064                    let parts = &n.0;
1065                    if !parts.is_empty() {
1066                        let table_name = parts.last().unwrap().as_str().to_lowercase();
1067
1068                        // Skip if this is a CTE reference, not a real table
1069                        if self.cte_names.contains(&table_name) {
1070                            visit::visit_table_factor(self, table_factor);
1071                            return;
1072                        }
1073
1074                        // Extract schema if qualified (e.g., mz_catalog.mz_tables)
1075                        let schema = if parts.len() >= 2 {
1076                            Some(parts[parts.len() - 2].as_str().to_lowercase())
1077                        } else {
1078                            None
1079                        };
1080                        self.tables.push((schema, table_name));
1081                    }
1082                }
1083            }
1084        }
1085        visit::visit_table_factor(self, table_factor);
1086    }
1087}
1088
1089/// Validates that a query only references system catalog tables.
1090///
1091/// For SELECT statements, all table references must be in system schemas
1092/// (from `SYSTEM_SCHEMAS`, excluding `mz_unsafe`), and at least one system
1093/// table must be referenced (constant queries like `SELECT 1` are rejected
1094/// to prevent misuse of the developer endpoint for arbitrary computation).
1095/// SHOW and EXPLAIN statements are allowed without table references.
1096fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1097    // Parse the SQL to validate it
1098    let stmts = parse(sql).map_err(|e| {
1099        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1100    })?;
1101
1102    if stmts.is_empty() {
1103        return Err(McpRequestError::QueryValidationFailed(
1104            "Empty query".to_string(),
1105        ));
1106    }
1107
1108    // Walk the AST to collect all table references
1109    let mut collector = TableReferenceCollector::new();
1110    for stmt in &stmts {
1111        collector.visit_statement(&stmt.ast);
1112    }
1113
1114    // Use the canonical system schema list, excluding mz_unsafe which contains
1115    // internal-only objects that should not be exposed to MCP clients.
1116    let is_allowed_schema =
1117        |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1118
1119    // Helper to check if a table reference is allowed. Unqualified references
1120    // are accepted here because execution uses a tight `search_path` containing
1121    // only system schemas (see `query_system_catalog`), so user-created views
1122    // like `public.mz_leak` cannot be reached by an unqualified name.
1123    let is_system_table = |(schema, table_name): &(Option<String>, String)| match schema {
1124        Some(s) => is_allowed_schema(s.as_str()),
1125        None => table_name.starts_with("mz_"),
1126    };
1127
1128    // Check that all table references are system tables
1129    let non_system_tables: Vec<String> = collector
1130        .tables
1131        .iter()
1132        .filter(|t| !is_system_table(t))
1133        .map(|(schema, table)| match schema {
1134            Some(s) => format!("{}.{}", s, table),
1135            None => table.clone(),
1136        })
1137        .collect();
1138
1139    if !non_system_tables.is_empty() {
1140        return Err(McpRequestError::QueryValidationFailed(format!(
1141            "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1142            non_system_tables.join(", ")
1143        )));
1144    }
1145
1146    // SHOW and EXPLAIN statements don't reference tables in the AST, but are safe
1147    // read-only operations. Only require system table references for SELECT.
1148    use mz_sql_parser::ast::Statement;
1149    let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1150
1151    if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1152        return Err(McpRequestError::QueryValidationFailed(
1153            "Query must reference at least one system catalog table".to_string(),
1154        ));
1155    }
1156
1157    Ok(())
1158}
1159
1160#[cfg(test)]
1161mod tests {
1162    use super::*;
1163
1164    #[mz_ore::test]
1165    fn test_validate_readonly_query_select() {
1166        assert!(validate_readonly_query("SELECT * FROM mz_tables").is_ok());
1167        assert!(validate_readonly_query("SELECT 1 + 2").is_ok());
1168        assert!(validate_readonly_query("  SELECT 1  ").is_ok());
1169    }
1170
1171    #[mz_ore::test]
1172    fn test_validate_readonly_query_subqueries() {
1173        // Simple subquery in WHERE clause
1174        assert!(
1175            validate_readonly_query(
1176                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1177            )
1178            .is_ok()
1179        );
1180
1181        // Subquery in FROM clause
1182        assert!(
1183            validate_readonly_query(
1184                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1185            )
1186            .is_ok()
1187        );
1188
1189        // Correlated subquery
1190        assert!(validate_readonly_query(
1191            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)"
1192        )
1193        .is_ok());
1194
1195        // Nested subqueries
1196        assert!(validate_readonly_query(
1197            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))"
1198        )
1199        .is_ok());
1200
1201        // Subquery with aggregation
1202        assert!(
1203            validate_readonly_query(
1204                "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)"
1205            )
1206            .is_ok()
1207        );
1208    }
1209
1210    #[mz_ore::test]
1211    fn test_validate_readonly_query_show() {
1212        assert!(validate_readonly_query("SHOW CLUSTERS").is_ok());
1213        assert!(validate_readonly_query("SHOW TABLES").is_ok());
1214    }
1215
1216    #[mz_ore::test]
1217    fn test_validate_readonly_query_explain() {
1218        assert!(validate_readonly_query("EXPLAIN SELECT 1").is_ok());
1219    }
1220
1221    #[mz_ore::test]
1222    fn test_validate_readonly_query_rejects_writes() {
1223        assert!(validate_readonly_query("INSERT INTO t VALUES (1)").is_err());
1224        assert!(validate_readonly_query("UPDATE t SET a = 1").is_err());
1225        assert!(validate_readonly_query("DELETE FROM t").is_err());
1226        assert!(validate_readonly_query("CREATE TABLE t (a INT)").is_err());
1227        assert!(validate_readonly_query("DROP TABLE t").is_err());
1228    }
1229
1230    #[mz_ore::test]
1231    fn test_validate_readonly_query_rejects_multiple() {
1232        assert!(validate_readonly_query("SELECT 1; SELECT 2").is_err());
1233    }
1234
1235    #[mz_ore::test]
1236    fn test_validate_readonly_query_rejects_empty() {
1237        assert!(validate_readonly_query("").is_err());
1238        assert!(validate_readonly_query("   ").is_err());
1239    }
1240
1241    #[mz_ore::test]
1242    fn test_validate_system_catalog_query_accepts_mz_tables() {
1243        assert!(validate_system_catalog_query("SELECT * FROM mz_tables").is_ok());
1244        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_comments").is_ok());
1245        assert!(
1246            validate_system_catalog_query(
1247                "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id"
1248            )
1249            .is_ok()
1250        );
1251    }
1252
1253    #[mz_ore::test]
1254    fn test_validate_system_catalog_query_subqueries() {
1255        // Subquery with mz_* tables
1256        assert!(
1257            validate_system_catalog_query(
1258                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1259            )
1260            .is_ok()
1261        );
1262
1263        // Nested subqueries with mz_* tables
1264        assert!(validate_system_catalog_query(
1265            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))"
1266        )
1267        .is_ok());
1268
1269        // Subquery in FROM clause
1270        assert!(
1271            validate_system_catalog_query(
1272                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1273            )
1274            .is_ok()
1275        );
1276
1277        // Reject subqueries that reference non-mz_* tables
1278        assert!(
1279            validate_system_catalog_query(
1280                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)"
1281            )
1282            .is_err()
1283        );
1284
1285        // Reject mixed references in nested subqueries
1286        assert!(validate_system_catalog_query(
1287            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)"
1288        )
1289        .is_err());
1290    }
1291
1292    #[mz_ore::test]
1293    fn test_validate_system_catalog_query_rejects_user_tables() {
1294        assert!(validate_system_catalog_query("SELECT * FROM user_data").is_err());
1295        assert!(validate_system_catalog_query("SELECT * FROM my_table").is_err());
1296        // Security: reject queries that mention mz_ in a non-table context
1297        assert!(
1298            validate_system_catalog_query("SELECT * FROM user_data WHERE 'mz_' IS NOT NULL")
1299                .is_err()
1300        );
1301    }
1302
1303    #[mz_ore::test]
1304    fn test_validate_system_catalog_query_allows_functions() {
1305        // Function names should not be treated as table references
1306        assert!(
1307            validate_system_catalog_query(
1308                "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1"
1309            )
1310            .is_ok()
1311        );
1312        assert!(validate_system_catalog_query("SELECT length(name) FROM mz_tables").is_ok());
1313        assert!(
1314            validate_system_catalog_query(
1315                "SELECT count(*) FROM mz_sources WHERE now() > created_at"
1316            )
1317            .is_ok()
1318        );
1319    }
1320
1321    #[mz_ore::test]
1322    fn test_validate_system_catalog_query_schema_qualified() {
1323        // Qualified with allowed schemas should work
1324        assert!(validate_system_catalog_query("SELECT * FROM mz_catalog.mz_tables").is_ok());
1325        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_sessions").is_ok());
1326        assert!(validate_system_catalog_query("SELECT * FROM pg_catalog.pg_type").is_ok());
1327        assert!(validate_system_catalog_query("SELECT * FROM information_schema.tables").is_ok());
1328
1329        // Qualified with disallowed schema should fail
1330        assert!(validate_system_catalog_query("SELECT * FROM public.user_table").is_err());
1331        assert!(validate_system_catalog_query("SELECT * FROM myschema.mytable").is_err());
1332
1333        // mz_unsafe is a system schema but explicitly blocked for MCP
1334        assert!(
1335            validate_system_catalog_query("SELECT * FROM mz_unsafe.mz_some_table").is_err(),
1336            "mz_unsafe schema should be blocked even though it is a system schema"
1337        );
1338
1339        // Mixed: system and user schemas should fail
1340        assert!(
1341            validate_system_catalog_query(
1342                "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"
1343            )
1344            .is_err()
1345        );
1346    }
1347
1348    #[mz_ore::test]
1349    fn test_validate_system_catalog_query_adversarial_cases() {
1350        // Try to sneak in user table via CTE
1351        assert!(
1352            validate_system_catalog_query(
1353                "WITH user_cte AS (SELECT * FROM user_data) \
1354                 SELECT * FROM mz_tables, user_cte"
1355            )
1356            .is_err(),
1357            "Should reject CTE referencing user table"
1358        );
1359
1360        // Complex multi-level CTE with user table buried deep
1361        assert!(
1362            validate_system_catalog_query(
1363                "WITH \
1364                   cte1 AS (SELECT * FROM mz_tables), \
1365                   cte2 AS (SELECT * FROM cte1), \
1366                   cte3 AS (SELECT * FROM user_data) \
1367                 SELECT * FROM cte2"
1368            )
1369            .is_err(),
1370            "Should reject CTE chain with user table"
1371        );
1372
1373        // Multiple joins - user table in the middle
1374        assert!(
1375            validate_system_catalog_query(
1376                "SELECT * FROM mz_tables t1 \
1377                 JOIN user_data u ON t1.id = u.id \
1378                 JOIN mz_sources s ON t1.id = s.id"
1379            )
1380            .is_err(),
1381            "Should reject multi-join with user table"
1382        );
1383
1384        // LEFT JOIN trying to hide user table
1385        assert!(
1386            validate_system_catalog_query(
1387                "SELECT * FROM mz_tables t \
1388                 LEFT JOIN user_data u ON t.id = u.table_id \
1389                 WHERE u.id IS NULL"
1390            )
1391            .is_err(),
1392            "Should reject LEFT JOIN with user table"
1393        );
1394
1395        // Nested subquery with user table in FROM
1396        assert!(
1397            validate_system_catalog_query(
1398                "SELECT * FROM mz_tables WHERE id IN \
1399                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)"
1400            )
1401            .is_err(),
1402            "Should reject nested subquery with user table"
1403        );
1404
1405        // UNION trying to mix system and user data
1406        assert!(
1407            validate_system_catalog_query(
1408                "SELECT name FROM mz_tables \
1409                 UNION \
1410                 SELECT name FROM user_data"
1411            )
1412            .is_err(),
1413            "Should reject UNION with user table"
1414        );
1415
1416        // UNION ALL variation
1417        assert!(
1418            validate_system_catalog_query(
1419                "SELECT id FROM mz_sources \
1420                 UNION ALL \
1421                 SELECT id FROM products"
1422            )
1423            .is_err(),
1424            "Should reject UNION ALL with user table"
1425        );
1426
1427        // Cross join with user table
1428        assert!(
1429            validate_system_catalog_query("SELECT * FROM mz_tables CROSS JOIN user_data").is_err(),
1430            "Should reject CROSS JOIN with user table"
1431        );
1432
1433        // Subquery in SELECT clause referencing user table
1434        assert!(
1435            validate_system_catalog_query(
1436                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t"
1437            )
1438            .is_err(),
1439            "Should reject subquery in SELECT with user table"
1440        );
1441
1442        // Try to use a schema name that looks similar to allowed ones
1443        assert!(
1444            validate_system_catalog_query("SELECT * FROM mz_catalogg.fake_table").is_err(),
1445            "Should reject typo-squatting schema name"
1446        );
1447        assert!(
1448            validate_system_catalog_query("SELECT * FROM mz_catalog_hack.fake_table").is_err(),
1449            "Should reject fake schema with mz_catalog prefix"
1450        );
1451
1452        // Lateral join with user table
1453        assert!(
1454            validate_system_catalog_query(
1455                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u"
1456            )
1457            .is_err(),
1458            "Should reject LATERAL join with user table"
1459        );
1460
1461        // Valid complex query - all system tables
1462        assert!(
1463            validate_system_catalog_query(
1464                "WITH \
1465                   tables AS (SELECT * FROM mz_tables), \
1466                   sources AS (SELECT * FROM mz_sources) \
1467                 SELECT t.name, s.name \
1468                 FROM tables t \
1469                 JOIN sources s ON t.id = s.id \
1470                 WHERE t.id IN (SELECT id FROM mz_columns)"
1471            )
1472            .is_ok(),
1473            "Should allow complex query with only system tables"
1474        );
1475
1476        // Valid UNION of system tables
1477        assert!(
1478            validate_system_catalog_query(
1479                "SELECT name FROM mz_tables \
1480                 UNION \
1481                 SELECT name FROM mz_sources"
1482            )
1483            .is_ok(),
1484            "Should allow UNION of system tables"
1485        );
1486    }
1487
1488    #[mz_ore::test]
1489    fn test_validate_system_catalog_query_rejects_constant_queries() {
1490        // SELECT without any table reference should be rejected — the developer
1491        // endpoint is for system catalog queries, not arbitrary computation.
1492        assert!(
1493            validate_system_catalog_query("SELECT 1").is_err(),
1494            "Should reject constant SELECT with no table references"
1495        );
1496        assert!(
1497            validate_system_catalog_query("SELECT 1 + 2, 'hello'").is_err(),
1498            "Should reject constant expression SELECT"
1499        );
1500        assert!(
1501            validate_system_catalog_query("SELECT now()").is_err(),
1502            "Should reject function-only SELECT with no table references"
1503        );
1504    }
1505
1506    #[mz_ore::test]
1507    fn test_validate_system_catalog_query_rejects_mixed_tables() {
1508        assert!(
1509            validate_system_catalog_query(
1510                "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"
1511            )
1512            .is_err()
1513        );
1514    }
1515
1516    #[mz_ore::test]
1517    fn test_validate_system_catalog_query_allows_show() {
1518        // SHOW queries don't reference tables in the AST but are safe read-only ops
1519        assert!(
1520            validate_system_catalog_query("SHOW TABLES FROM mz_internal").is_ok(),
1521            "SHOW TABLES FROM mz_internal should be allowed"
1522        );
1523        assert!(
1524            validate_system_catalog_query("SHOW TABLES FROM mz_catalog").is_ok(),
1525            "SHOW TABLES FROM mz_catalog should be allowed"
1526        );
1527        assert!(
1528            validate_system_catalog_query("SHOW CLUSTERS").is_ok(),
1529            "SHOW CLUSTERS should be allowed"
1530        );
1531        assert!(
1532            validate_system_catalog_query("SHOW SOURCES").is_ok(),
1533            "SHOW SOURCES should be allowed"
1534        );
1535        assert!(
1536            validate_system_catalog_query("SHOW TABLES").is_ok(),
1537            "SHOW TABLES should be allowed"
1538        );
1539    }
1540
1541    #[mz_ore::test]
1542    fn test_validate_system_catalog_query_allows_explain() {
1543        assert!(
1544            validate_system_catalog_query("EXPLAIN SELECT * FROM mz_tables").is_ok(),
1545            "EXPLAIN of system table query should be allowed"
1546        );
1547        assert!(
1548            validate_system_catalog_query("EXPLAIN SELECT 1").is_ok(),
1549            "EXPLAIN SELECT 1 should be allowed"
1550        );
1551    }
1552
1553    // ── Query tool feature flag tests ──────────────────────────────────────
1554
1555    #[mz_ore::test(tokio::test)]
1556    async fn test_tools_list_agent_query_tool_disabled() {
1557        let result = handle_tools_list(McpEndpointType::Agent, false, 1_000_000)
1558            .await
1559            .unwrap();
1560        let McpResult::ToolsList(list) = result else {
1561            panic!("Expected ToolsList result");
1562        };
1563        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1564        assert!(
1565            tool_names.contains(&"get_data_products"),
1566            "get_data_products should always be present"
1567        );
1568        assert!(
1569            tool_names.contains(&"get_data_product_details"),
1570            "get_data_product_details should always be present"
1571        );
1572        assert!(
1573            tool_names.contains(&"read_data_product"),
1574            "read_data_product should always be present"
1575        );
1576        assert!(
1577            !tool_names.contains(&"query"),
1578            "query tool should be hidden when disabled"
1579        );
1580    }
1581
1582    #[mz_ore::test(tokio::test)]
1583    async fn test_tools_list_agent_query_tool_enabled() {
1584        let result = handle_tools_list(McpEndpointType::Agent, true, 1_000_000)
1585            .await
1586            .unwrap();
1587        let McpResult::ToolsList(list) = result else {
1588            panic!("Expected ToolsList result");
1589        };
1590        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1591        assert!(
1592            tool_names.contains(&"get_data_products"),
1593            "get_data_products should always be present"
1594        );
1595        assert!(
1596            tool_names.contains(&"get_data_product_details"),
1597            "get_data_product_details should always be present"
1598        );
1599        assert!(
1600            tool_names.contains(&"read_data_product"),
1601            "read_data_product should always be present"
1602        );
1603        assert!(
1604            tool_names.contains(&"query"),
1605            "query tool should be present when enabled"
1606        );
1607    }
1608
1609    #[mz_ore::test(tokio::test)]
1610    async fn test_tools_list_developer_unaffected_by_query_flag() {
1611        // Developer endpoint should not be affected by the query tool flag
1612        for flag in [true, false] {
1613            let result = handle_tools_list(McpEndpointType::Developer, flag, 1_000_000)
1614                .await
1615                .unwrap();
1616            let McpResult::ToolsList(list) = result else {
1617                panic!("Expected ToolsList result");
1618            };
1619            let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1620            assert!(
1621                tool_names.contains(&"query_system_catalog"),
1622                "query_system_catalog should always be present on developer"
1623            );
1624            assert!(
1625                !tool_names.contains(&"query"),
1626                "query tool should never appear on developer"
1627            );
1628        }
1629    }
1630
1631    // ── Response size cap tests ────────────────────────────────────────
1632
1633    #[mz_ore::test]
1634    fn test_format_rows_response_within_limit() {
1635        let rows = vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
1636        let result = format_rows_response(rows, 1_000_000).unwrap();
1637        let McpResult::ToolContent(content) = result else {
1638            panic!("Expected ToolContent");
1639        };
1640        assert_eq!(content.content.len(), 1);
1641        assert!(content.content[0].text.contains("\"a\""));
1642        assert!(content.content[0].text.contains("\"b\""));
1643    }
1644
1645    #[mz_ore::test]
1646    fn test_format_rows_response_errors_when_over_limit() {
1647        let rows: Vec<Vec<serde_json::Value>> = (0..100)
1648            .map(|i| vec![json!(format!("row_{}", i)), json!(i)])
1649            .collect();
1650        let err = format_rows_response(rows, 500).unwrap_err();
1651        let msg = err.to_string();
1652        assert!(
1653            msg.contains("exceeds the 500 byte limit"),
1654            "Error should mention the size limit, got: {msg}"
1655        );
1656        assert!(
1657            msg.contains("Use LIMIT or WHERE"),
1658            "Error should suggest narrowing the query, got: {msg}"
1659        );
1660    }
1661
1662    #[mz_ore::test]
1663    fn test_format_rows_response_empty_rows() {
1664        let rows: Vec<Vec<serde_json::Value>> = vec![];
1665        let result = format_rows_response(rows, 1000).unwrap();
1666        let McpResult::ToolContent(content) = result else {
1667            panic!("Expected ToolContent");
1668        };
1669        assert_eq!(content.content.len(), 1);
1670        assert_eq!(content.content[0].text, "[]");
1671    }
1672
1673    // ── Data product name validation tests ─────────────────────────────
1674
1675    #[mz_ore::test]
1676    fn test_safe_data_product_name_valid() {
1677        // Fully qualified quoted identifiers
1678        assert_eq!(
1679            safe_data_product_name(r#""materialize"."public"."my_view""#).unwrap(),
1680            r#""materialize"."public"."my_view""#
1681        );
1682        // Two-part name
1683        assert_eq!(
1684            safe_data_product_name(r#""public"."my_view""#).unwrap(),
1685            r#""public"."my_view""#
1686        );
1687        // Unquoted name gets quoted in stable mode
1688        assert_eq!(safe_data_product_name("my_view").unwrap(), r#""my_view""#);
1689    }
1690
1691    #[mz_ore::test]
1692    fn test_safe_data_product_name_rejects_empty() {
1693        assert!(safe_data_product_name("").is_err());
1694        assert!(safe_data_product_name("   ").is_err());
1695    }
1696
1697    #[mz_ore::test]
1698    fn test_safe_data_product_name_rejects_sql_injection() {
1699        // Attempted injection via semicolon
1700        assert!(safe_data_product_name("my_view; DROP TABLE users").is_err());
1701        // Attempted injection via subquery
1702        assert!(safe_data_product_name("my_view UNION SELECT * FROM secrets").is_err());
1703        // Multiple table references via comma
1704        assert!(safe_data_product_name("my_view, secrets").is_err());
1705        // SQL keywords after name are rejected by the parser
1706        assert!(safe_data_product_name("my_view WHERE 1=1 --").is_err());
1707    }
1708
1709    #[mz_ore::test]
1710    fn test_mcp_error_codes() {
1711        assert_eq!(
1712            McpRequestError::InvalidJsonRpcVersion.error_code(),
1713            error_codes::INVALID_REQUEST
1714        );
1715        assert_eq!(
1716            McpRequestError::MethodNotFound("test".to_string()).error_code(),
1717            error_codes::METHOD_NOT_FOUND
1718        );
1719        assert_eq!(
1720            McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
1721            error_codes::INTERNAL_ERROR
1722        );
1723    }
1724}