Skip to main content

mz_environmentd/http/
mcp.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Model Context Protocol (MCP) HTTP handlers.
11//!
12//! Exposes Materialize data products to AI agents via JSON-RPC 2.0 over HTTP POST.
13//!
14//! ## Endpoints
15//!
16//! - `/api/mcp/agents` - User data products for customer AI agents
17//! - `/api/mcp/observatory` - System catalog (`mz_*`) for troubleshooting
18//!
19//! ## Tools
20//!
//! **Agents:** `get_data_products`, `get_data_product_details`, `read_data_product`, `query` (only when enabled by feature flag)
22//! **Observatory:** `query_system_catalog`
23//!
24//! Data products are discovered via `mz_internal.mz_mcp_data_products` system view.
25
26use std::time::Duration;
27
28use anyhow::anyhow;
29use axum::Json;
30use axum::response::IntoResponse;
31use http::{HeaderMap, StatusCode};
32use mz_adapter_types::dyncfgs::{
33    ENABLE_MCP_AGENTS, ENABLE_MCP_AGENTS_QUERY_TOOL, ENABLE_MCP_OBSERVATORY, MCP_MAX_RESPONSE_SIZE,
34};
35use mz_repr::namespaces::{self, SYSTEM_SCHEMAS};
36use mz_sql::parse::parse;
37use mz_sql::session::metadata::SessionMetadata;
38use mz_sql_parser::ast::display::{AstDisplay, escaped_string_literal};
39use mz_sql_parser::ast::visit::{self, Visit};
40use mz_sql_parser::ast::{Raw, RawItemName};
41use mz_sql_parser::parser::parse_item_name;
42use serde::{Deserialize, Serialize};
43use serde_json::json;
44use thiserror::Error;
45use tracing::{debug, warn};
46
47use crate::http::AuthedClient;
48use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
49
50// To add a new tool: add entry to tools/list, add handler function, add dispatch case.
51
/// JSON-RPC protocol version used in all MCP requests and responses.
///
/// Incoming requests whose `jsonrpc` field differs from this value are rejected.
const JSONRPC_VERSION: &str = "2.0";

/// MCP protocol version returned in the `initialize` response.
/// Spec: <https://modelcontextprotocol.io/specification/2025-11-25>
const MCP_PROTOCOL_VERSION: &str = "2025-11-25";

/// Maximum time an MCP request can run before the HTTP response is returned.
/// The timeout wraps the whole request handler, not only tool calls.
/// Note: this returns a clean JSON-RPC error to the caller, but the underlying
/// query may continue running on the cluster until it completes or is cancelled
/// separately (see database-issues#9947 for SELECT timeout gaps).
const MCP_REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// SQL run by the `get_data_products` tool.
/// Discovery uses the lightweight view (no JSON schema computation).
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
/// SQL prefix for looking up a single data product.
/// Details uses the full view with JSON schema.
/// NOTE(review): the object name is presumably appended as an escaped string
/// literal by the caller — confirm at the use site.
const DETAILS_QUERY_PREFIX: &str =
    "SELECT * FROM mz_internal.mz_mcp_data_product_details WHERE object_name = ";
70
71/// MCP request errors, mapped to JSON-RPC error codes.
72#[derive(Debug, Error)]
73enum McpRequestError {
74    #[error("Invalid JSON-RPC version: expected 2.0")]
75    InvalidJsonRpcVersion,
76    #[error("Method not found: {0}")]
77    #[allow(dead_code)] // Handled by serde deserialization, kept for error mapping
78    MethodNotFound(String),
79    #[error("Tool not found: {0}")]
80    ToolNotFound(String),
81    #[error("Data product not found: {0}")]
82    DataProductNotFound(String),
83    #[error("Query validation failed: {0}")]
84    QueryValidationFailed(String),
85    #[error("Query execution failed: {0}")]
86    QueryExecutionFailed(String),
87    #[error("Internal error: {0}")]
88    Internal(#[from] anyhow::Error),
89}
90
91impl McpRequestError {
92    fn error_code(&self) -> i32 {
93        match self {
94            Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
95            Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
96            Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
97            Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
98            Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
99            Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
100        }
101    }
102
103    fn error_type(&self) -> &'static str {
104        match self {
105            Self::InvalidJsonRpcVersion => "InvalidRequest",
106            Self::MethodNotFound(_) => "MethodNotFound",
107            Self::ToolNotFound(_) => "ToolNotFound",
108            Self::DataProductNotFound(_) => "DataProductNotFound",
109            Self::QueryValidationFailed(_) => "ValidationError",
110            Self::QueryExecutionFailed(_) => "ExecutionError",
111            Self::Internal(_) => "InternalError",
112        }
113    }
114}
115
/// JSON-RPC 2.0 request. Requests have `id`; notifications don't.
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    /// Protocol version string sent by the client; must equal `"2.0"`.
    jsonrpc: String,
    /// Request identifier echoed back in the response. `None` marks a notification.
    id: Option<serde_json::Value>,
    /// The `method` and `params` keys of the request object, flattened into
    /// a single typed enum.
    #[serde(flatten)]
    method: McpMethod,
}
124
/// MCP method variants with their associated parameters.
///
/// Adjacently tagged: the `method` key selects the variant and the `params`
/// key carries the variant's payload.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// Initialize method - params accepted but not currently used
    #[serde(rename = "initialize")]
    Initialize(#[allow(dead_code)] InitializeParams),
    /// Lists the tools available on the endpoint.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// Invokes a tool by name with its arguments.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Catch-all for unknown methods (e.g. `notifications/initialized`)
    #[serde(other)]
    Unknown,
}
140
141impl std::fmt::Display for McpMethod {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        match self {
144            McpMethod::Initialize(_) => write!(f, "initialize"),
145            McpMethod::ToolsList => write!(f, "tools/list"),
146            McpMethod::ToolsCall(_) => write!(f, "tools/call"),
147            McpMethod::Unknown => write!(f, "unknown"),
148        }
149    }
150}
151
152#[derive(Debug, Deserialize)]
153struct InitializeParams {
154    /// Protocol version from client. Not currently validated but accepted for MCP compliance.
155    #[serde(rename = "protocolVersion")]
156    #[allow(dead_code)]
157    protocol_version: String,
158    /// Client capabilities. Not currently used but accepted for MCP compliance.
159    #[serde(default)]
160    #[allow(dead_code)]
161    capabilities: serde_json::Value,
162    /// Client information (name, version). Not currently used but accepted for MCP compliance.
163    #[serde(rename = "clientInfo")]
164    #[allow(dead_code)]
165    client_info: Option<ClientInfo>,
166}
167
/// Client identification sent in `initialize`. Parsed but not otherwise used.
#[derive(Debug, Deserialize)]
struct ClientInfo {
    /// Client name as reported by the client.
    #[allow(dead_code)]
    name: String,
    /// Client version as reported by the client.
    #[allow(dead_code)]
    version: String,
}
175
/// Tool call parameters, deserialized via adjacently tagged enum.
/// Serde maps `name` to the variant and `arguments` to the variant's data.
/// Variant names are matched in snake_case (e.g. `get_data_products`).
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    // Agents endpoint tools
    // Carries a defaulted unit payload; intended to let MCP clients that send
    // `"arguments": {}` deserialize.
    GetDataProducts(#[serde(default)] ()),
    /// Fetch the details of one data product by name.
    GetDataProductDetails(GetDataProductDetailsParams),
    /// Read rows from one data product.
    ReadDataProduct(ReadDataProductParams),
    /// Run a SQL query on a given cluster (gated by a feature flag at dispatch time).
    Query(QueryParams),
    // Observatory endpoint tools
    /// Run a SQL query against the system catalog.
    QuerySystemCatalog(QuerySystemCatalogParams),
}
191
192impl std::fmt::Display for ToolsCallParams {
193    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194        match self {
195            ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
196            ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
197            ToolsCallParams::ReadDataProduct(_) => write!(f, "read_data_product"),
198            ToolsCallParams::Query(_) => write!(f, "query"),
199            ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
200        }
201    }
202}
203
/// Arguments of the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    /// Name of the data product to describe.
    name: String,
}
208
/// Arguments of the `read_data_product` tool.
#[derive(Debug, Deserialize)]
struct ReadDataProductParams {
    /// Name of the data product to read.
    name: String,
    /// Requested row limit; falls back to `default_read_limit()` when omitted.
    /// NOTE(review): presumably clamped to `MAX_READ_LIMIT` by the handler — confirm.
    #[serde(default = "default_read_limit")]
    limit: u32,
    /// Optional cluster override.
    cluster: Option<String>,
}
216
/// Row limit applied by `read_data_product` when the caller omits `limit`.
fn default_read_limit() -> u32 {
    const DEFAULT_READ_LIMIT: u32 = 500;
    DEFAULT_READ_LIMIT
}

/// Maximum number of rows that can be returned by read_data_product.
const MAX_READ_LIMIT: u32 = 1_000;
223
/// Arguments of the `query` tool.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Name of the cluster the query should run on.
    cluster: String,
    /// SQL text supplied by the caller.
    sql_query: String,
}
229
/// Arguments of the `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    /// SQL text supplied by the caller.
    sql_query: String,
}
234
/// JSON-RPC 2.0 response envelope. `result` and `error` are each omitted
/// from the serialized output when `None`.
#[derive(Debug, Serialize)]
struct McpResponse {
    /// Always the JSON-RPC version string.
    jsonrpc: String,
    /// Identifier copied from the request.
    id: serde_json::Value,
    /// Successful result payload, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    /// Error payload, if the request failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}
244
/// Typed MCP response results.
///
/// Untagged: each variant serializes as its inner value with no wrapper key.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    /// Result of the `initialize` method.
    Initialize(InitializeResult),
    /// Result of the `tools/list` method.
    ToolsList(ToolsListResult),
    /// Result of a `tools/call` invocation.
    ToolContent(ToolContentResult),
}
253
254#[derive(Debug, Serialize)]
255struct InitializeResult {
256    #[serde(rename = "protocolVersion")]
257    protocol_version: String,
258    capabilities: Capabilities,
259    #[serde(rename = "serverInfo")]
260    server_info: ServerInfo,
261    #[serde(skip_serializing_if = "Option::is_none")]
262    instructions: Option<String>,
263}
264
/// Server capabilities advertised in the `initialize` response.
#[derive(Debug, Serialize)]
struct Capabilities {
    /// Tool capability object, serialized as-is.
    tools: serde_json::Value,
}
269
/// Server identification returned in the `initialize` response.
#[derive(Debug, Serialize)]
struct ServerInfo {
    /// Server name.
    name: String,
    /// Server version string.
    version: String,
}
275
/// Payload returned from `tools/list`.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    /// Definitions of every tool available on the endpoint.
    tools: Vec<ToolDefinition>,
}
280
281#[derive(Debug, Serialize)]
282struct ToolDefinition {
283    name: String,
284    #[serde(skip_serializing_if = "Option::is_none")]
285    title: Option<String>,
286    description: String,
287    #[serde(rename = "inputSchema")]
288    input_schema: serde_json::Value,
289    #[serde(skip_serializing_if = "Option::is_none")]
290    annotations: Option<ToolAnnotations>,
291}
292
293/// MCP 2025-11-25 tool annotations that describe tool behavior.
294/// These hints help clients make trust and safety decisions.
295#[derive(Debug, Serialize)]
296struct ToolAnnotations {
297    #[serde(rename = "readOnlyHint", skip_serializing_if = "Option::is_none")]
298    read_only_hint: Option<bool>,
299    #[serde(rename = "destructiveHint", skip_serializing_if = "Option::is_none")]
300    destructive_hint: Option<bool>,
301    #[serde(rename = "idempotentHint", skip_serializing_if = "Option::is_none")]
302    idempotent_hint: Option<bool>,
303    #[serde(rename = "openWorldHint", skip_serializing_if = "Option::is_none")]
304    open_world_hint: Option<bool>,
305}
306
/// Annotations for all MCP tools: read-only, non-destructive, idempotent,
/// and not open-world (`openWorldHint` is `false`).
const READ_ONLY_ANNOTATIONS: ToolAnnotations = ToolAnnotations {
    read_only_hint: Some(true),
    destructive_hint: Some(false),
    idempotent_hint: Some(true),
    open_world_hint: Some(false),
};
314
315#[derive(Debug, Serialize)]
316struct ToolContentResult {
317    content: Vec<ContentBlock>,
318    #[serde(rename = "isError")]
319    is_error: bool,
320}
321
/// A single block of tool output.
#[derive(Debug, Serialize)]
struct ContentBlock {
    /// Content kind, serialized under the key `type` (e.g. `"text"`).
    #[serde(rename = "type")]
    content_type: String,
    /// The textual payload of the block.
    text: String,
}
328
/// JSON-RPC 2.0 error codes.
mod error_codes {
    /// The request is not a valid JSON-RPC request object.
    pub const INVALID_REQUEST: i32 = -32600;
    /// The requested method does not exist.
    pub const METHOD_NOT_FOUND: i32 = -32601;
    /// The method parameters are invalid.
    pub const INVALID_PARAMS: i32 = -32602;
    /// An internal error occurred while handling the request.
    pub const INTERNAL_ERROR: i32 = -32603;
}
336
/// JSON-RPC 2.0 error object.
#[derive(Debug, Serialize)]
struct McpError {
    /// Numeric JSON-RPC error code (see `error_codes`).
    code: i32,
    /// Human-readable error message.
    message: String,
    /// Optional structured details; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
344
345impl From<McpRequestError> for McpError {
346    fn from(err: McpRequestError) -> Self {
347        McpError {
348            code: err.error_code(),
349            message: err.to_string(),
350            data: Some(json!({
351                "error_type": err.error_type(),
352            })),
353        }
354    }
355}
356
/// Which MCP endpoint a request arrived on; decides the feature flag that is
/// checked and the set of tools that may be called.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// The agents endpoint (user data products).
    Agents,
    /// The observatory endpoint (system catalog).
    Observatory,
}
362
363impl std::fmt::Display for McpEndpointType {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        match self {
366            McpEndpointType::Agents => write!(f, "agents"),
367            McpEndpointType::Observatory => write!(f, "observatory"),
368        }
369    }
370}
371
/// MCP 2025-11-25 requires servers to return 405 for GET requests
/// on endpoints that only support POST.
///
/// Always responds with `405 Method Not Allowed` and no body.
pub async fn handle_mcp_method_not_allowed() -> impl IntoResponse {
    StatusCode::METHOD_NOT_ALLOWED
}
377
378/// Agents endpoint: exposes user data products.
379pub async fn handle_mcp_agents(
380    headers: HeaderMap,
381    client: AuthedClient,
382    Json(body): Json<McpRequest>,
383) -> axum::response::Response {
384    if let Some(resp) = validate_origin(&headers) {
385        return resp;
386    }
387    handle_mcp_request(client, body, McpEndpointType::Agents)
388        .await
389        .into_response()
390}
391
392/// Observatory endpoint: exposes system catalog (mz_*) only.
393pub async fn handle_mcp_observatory(
394    headers: HeaderMap,
395    client: AuthedClient,
396    Json(body): Json<McpRequest>,
397) -> axum::response::Response {
398    if let Some(resp) = validate_origin(&headers) {
399        return resp;
400    }
401    handle_mcp_request(client, body, McpEndpointType::Observatory)
402        .await
403        .into_response()
404}
405
406/// Validates the Origin header to prevent DNS rebinding attacks (MCP 2025-11-25).
407/// Returns Some(403) if the Origin is present but doesn't match the Host.
408/// Returns None if the Origin is absent (non-browser client) or valid.
409fn validate_origin(headers: &HeaderMap) -> Option<axum::response::Response> {
410    let origin = match headers.get(http::header::ORIGIN) {
411        Some(o) => o,
412        None => return None, // No Origin header = non-browser client, allow
413    };
414
415    let host = headers
416        .get(http::header::HOST)
417        .and_then(|h| h.to_str().ok());
418
419    let origin_str = origin.to_str().ok().unwrap_or("");
420
421    // Extract host portion from Origin (strip scheme)
422    let origin_host = origin_str
423        .strip_prefix("https://")
424        .or_else(|| origin_str.strip_prefix("http://"))
425        .unwrap_or(origin_str);
426
427    match host {
428        Some(h) if h == origin_host => None, // Origin matches Host
429        _ => {
430            warn!(
431                origin = origin_str,
432                host = ?host,
433                "MCP request rejected: Origin does not match Host",
434            );
435            Some(StatusCode::FORBIDDEN.into_response())
436        }
437    }
438}
439
440async fn handle_mcp_request(
441    mut client: AuthedClient,
442    request: McpRequest,
443    endpoint_type: McpEndpointType,
444) -> impl IntoResponse {
445    // Check the per-endpoint feature flag via a catalog snapshot, similar to frontend_peek.rs.
446    let catalog = client.client.catalog_snapshot("mcp").await;
447    let dyncfgs = catalog.system_config().dyncfgs();
448    let enabled = match endpoint_type {
449        McpEndpointType::Agents => ENABLE_MCP_AGENTS.get(dyncfgs),
450        McpEndpointType::Observatory => ENABLE_MCP_OBSERVATORY.get(dyncfgs),
451    };
452    if !enabled {
453        debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
454        return StatusCode::SERVICE_UNAVAILABLE.into_response();
455    }
456
457    let query_tool_enabled = ENABLE_MCP_AGENTS_QUERY_TOOL.get(dyncfgs);
458    let max_response_size = MCP_MAX_RESPONSE_SIZE.get(dyncfgs);
459
460    let user = client.client.session().user().name.clone();
461    let is_notification = request.id.is_none();
462
463    debug!(
464        method = %request.method,
465        endpoint = %endpoint_type,
466        user = %user,
467        is_notification = is_notification,
468        "MCP request received"
469    );
470
471    // Handle notifications (no response needed)
472    if is_notification {
473        debug!(method = %request.method, "Received notification (no response will be sent)");
474        return StatusCode::OK.into_response();
475    }
476
477    let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
478
479    // Spawn task for fault isolation, with a timeout safety net.
480    let result = tokio::time::timeout(
481        MCP_REQUEST_TIMEOUT,
482        mz_ore::task::spawn(|| "mcp_request", async move {
483            handle_mcp_request_inner(
484                &mut client,
485                request,
486                endpoint_type,
487                query_tool_enabled,
488                max_response_size,
489            )
490            .await
491        }),
492    )
493    .await;
494
495    let response = match result {
496        Ok(inner) => inner,
497        Err(_elapsed) => {
498            warn!(
499                endpoint = %endpoint_type,
500                timeout = ?MCP_REQUEST_TIMEOUT,
501                "MCP request timed out",
502            );
503            McpResponse {
504                jsonrpc: JSONRPC_VERSION.to_string(),
505                id: request_id,
506                result: None,
507                error: Some(
508                    McpRequestError::QueryExecutionFailed(format!(
509                        "Request timed out after {} seconds.",
510                        MCP_REQUEST_TIMEOUT.as_secs(),
511                    ))
512                    .into(),
513                ),
514            }
515        }
516    };
517
518    (StatusCode::OK, Json(response)).into_response()
519}
520
521async fn handle_mcp_request_inner(
522    client: &mut AuthedClient,
523    request: McpRequest,
524    endpoint_type: McpEndpointType,
525    query_tool_enabled: bool,
526    max_response_size: usize,
527) -> McpResponse {
528    // Extract request ID (guaranteed to be Some since notifications are filtered earlier)
529    let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
530
531    let result = handle_mcp_method(
532        client,
533        &request,
534        endpoint_type,
535        query_tool_enabled,
536        max_response_size,
537    )
538    .await;
539
540    match result {
541        Ok(result_value) => McpResponse {
542            jsonrpc: JSONRPC_VERSION.to_string(),
543            id: request_id,
544            result: Some(result_value),
545            error: None,
546        },
547        Err(e) => {
548            // Log non-trivial errors
549            if !matches!(
550                e,
551                McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
552            ) {
553                warn!(error = %e, method = %request.method, "MCP method execution failed");
554            }
555            McpResponse {
556                jsonrpc: JSONRPC_VERSION.to_string(),
557                id: request_id,
558                result: None,
559                error: Some(e.into()),
560            }
561        }
562    }
563}
564
565async fn handle_mcp_method(
566    client: &mut AuthedClient,
567    request: &McpRequest,
568    endpoint_type: McpEndpointType,
569    query_tool_enabled: bool,
570    max_response_size: usize,
571) -> Result<McpResult, McpRequestError> {
572    // Validate JSON-RPC version
573    if request.jsonrpc != JSONRPC_VERSION {
574        return Err(McpRequestError::InvalidJsonRpcVersion);
575    }
576
577    // Handle different MCP methods using pattern matching
578    match &request.method {
579        McpMethod::Initialize(_) => {
580            debug!(endpoint = %endpoint_type, "Processing initialize");
581            handle_initialize(endpoint_type).await
582        }
583        McpMethod::ToolsList => {
584            debug!(endpoint = %endpoint_type, "Processing tools/list");
585            handle_tools_list(endpoint_type, query_tool_enabled, max_response_size).await
586        }
587        McpMethod::ToolsCall(params) => {
588            debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
589            handle_tools_call(
590                client,
591                params,
592                endpoint_type,
593                query_tool_enabled,
594                max_response_size,
595            )
596            .await
597        }
598        McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
599            "unknown method".to_string(),
600        )),
601    }
602}
603
604async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
605    Ok(McpResult::Initialize(InitializeResult {
606        protocol_version: MCP_PROTOCOL_VERSION.to_string(),
607        capabilities: Capabilities { tools: json!({}) },
608        server_info: ServerInfo {
609            name: format!("materialize-mcp-{}", endpoint_type),
610            version: env!("CARGO_PKG_VERSION").to_string(),
611        },
612        instructions: None,
613    }))
614}
615
616async fn handle_tools_list(
617    endpoint_type: McpEndpointType,
618    query_tool_enabled: bool,
619    max_response_size: usize,
620) -> Result<McpResult, McpRequestError> {
621    let size_hint = format!("Response limit: {} MB.", max_response_size / 1_000_000);
622
623    let tools = match endpoint_type {
624        McpEndpointType::Agents => {
625            let mut tools = vec![
626                ToolDefinition {
627                    name: "get_data_products".to_string(),
628                    title: Some("List Data Products".to_string()),
629                    description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
630                    input_schema: json!({
631                        "type": "object",
632                        "properties": {},
633                        "required": []
634                    }),
635                    annotations: Some(READ_ONLY_ANNOTATIONS),
636                },
637                ToolDefinition {
638                    name: "get_data_product_details".to_string(),
639                    title: Some("Get Data Product Details".to_string()),
640                    description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
641                    input_schema: json!({
642                        "type": "object",
643                        "properties": {
644                            "name": {
645                                "type": "string",
646                                "description": "Exact name of the data product from get_data_products() list"
647                            }
648                        },
649                        "required": ["name"]
650                    }),
651                    annotations: Some(READ_ONLY_ANNOTATIONS),
652                },
653                ToolDefinition {
654                    name: "read_data_product".to_string(),
655                    title: Some("Read Data Product".to_string()),
656                    description: format!("Read rows from a specific data product. Returns up to `limit` rows (default 500, max 1000). The data product must exist in the catalog (use get_data_products() to discover available products). Use this to retrieve actual data from a known data product. {size_hint}"),
657                    input_schema: json!({
658                        "type": "object",
659                        "properties": {
660                            "name": {
661                                "type": "string",
662                                "description": "Exact fully-qualified name of the data product (e.g. '\"materialize\".\"schema\".\"view_name\"')"
663                            },
664                            "limit": {
665                                "type": "integer",
666                                "description": "Maximum number of rows to return (default 500, max 1000)",
667                                "default": 500
668                            },
669                            "cluster": {
670                                "type": "string",
671                                "description": "Optional cluster override. If omitted, uses the cluster from the data product catalog."
672                            }
673                        },
674                        "required": ["name"]
675                    }),
676                    annotations: Some(READ_ONLY_ANNOTATIONS),
677                },
678            ];
679            if query_tool_enabled {
680                tools.push(ToolDefinition {
681                    name: "query".to_string(),
682                    title: Some("Query Data Products".to_string()),
683                    description: format!("Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views. {size_hint}"),
684                    input_schema: json!({
685                        "type": "object",
686                        "properties": {
687                            "cluster": {
688                                "type": "string",
689                                "description": "Exact cluster name from the data product details - required for query execution"
690                            },
691                            "sql_query": {
692                                "type": "string",
693                                "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
694                            }
695                        },
696                        "required": ["cluster", "sql_query"]
697                    }),
698                    annotations: Some(READ_ONLY_ANNOTATIONS),
699                });
700            }
701            tools
702        }
703        McpEndpointType::Observatory => {
704            vec![ToolDefinition {
705                name: "query_system_catalog".to_string(),
706                title: Some("Query System Catalog".to_string()),
707                description: concat!(
708                    "Query Materialize system catalog tables for troubleshooting and observability. ",
709                    "Only mz_*, pg_catalog, and information_schema tables are accessible.\n\n",
710                    "Key tables by scenario:\n",
711                    "- Freshness: mz_internal.mz_wallclock_global_lag_recent_history, mz_internal.mz_materialization_lag, mz_internal.mz_hydration_statuses\n",
712                    "- Memory: mz_internal.mz_cluster_replica_utilization, mz_internal.mz_cluster_replica_metrics, mz_internal.mz_dataflow_arrangement_sizes\n",
713                    "- Cluster health: mz_internal.mz_cluster_replica_statuses, mz_catalog.mz_cluster_replicas\n",
714                    "- Source/Sink health: mz_internal.mz_source_statuses, mz_internal.mz_sink_statuses, mz_internal.mz_source_statistics, mz_internal.mz_sink_statistics\n",
715                    "- Object catalog: mz_catalog.mz_objects (all objects), mz_catalog.mz_tables, mz_catalog.mz_materialized_views, mz_catalog.mz_sources, mz_catalog.mz_sinks\n\n",
716                    "Use SHOW TABLES FROM mz_internal or mz_catalog to discover more tables.",
717                ).to_owned() + &format!(" {size_hint}"),
718                input_schema: json!({
719                    "type": "object",
720                    "properties": {
721                        "sql_query": {
722                            "type": "string",
723                            "description": "PostgreSQL-compatible SELECT, SHOW, or EXPLAIN query referencing mz_* system catalog tables"
724                        }
725                    },
726                    "required": ["sql_query"]
727                }),
728                annotations: Some(READ_ONLY_ANNOTATIONS),
729            }]
730        }
731    };
732
733    Ok(McpResult::ToolsList(ToolsListResult { tools }))
734}
735
736async fn handle_tools_call(
737    client: &mut AuthedClient,
738    params: &ToolsCallParams,
739    endpoint_type: McpEndpointType,
740    query_tool_enabled: bool,
741    max_response_size: usize,
742) -> Result<McpResult, McpRequestError> {
743    match (endpoint_type, params) {
744        (McpEndpointType::Agents, ToolsCallParams::GetDataProducts(_)) => {
745            get_data_products(client, max_response_size).await
746        }
747        (McpEndpointType::Agents, ToolsCallParams::GetDataProductDetails(p)) => {
748            get_data_product_details(client, &p.name, max_response_size).await
749        }
750        (McpEndpointType::Agents, ToolsCallParams::ReadDataProduct(p)) => {
751            read_data_product(
752                client,
753                &p.name,
754                p.limit,
755                p.cluster.as_deref(),
756                max_response_size,
757            )
758            .await
759        }
760        (McpEndpointType::Agents, ToolsCallParams::Query(_)) if !query_tool_enabled => {
761            Err(McpRequestError::ToolNotFound(
762                "query tool is not available. Use get_data_products, get_data_product_details, and read_data_product instead.".to_string(),
763            ))
764        }
765        (McpEndpointType::Agents, ToolsCallParams::Query(p)) => {
766            execute_query(client, &p.cluster, &p.sql_query, max_response_size).await
767        }
768        (McpEndpointType::Observatory, ToolsCallParams::QuerySystemCatalog(p)) => {
769            query_system_catalog(client, &p.sql_query, max_response_size).await
770        }
771        // Tool called on wrong endpoint
772        (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
773            "{} is not available on {} endpoint",
774            tool, endpoint
775        ))),
776    }
777}
778
779/// Execute SQL via `execute_request` from sql.rs.
780async fn execute_sql(
781    client: &mut AuthedClient,
782    query: &str,
783) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
784    let mut response = SqlResponse::new();
785
786    execute_request(
787        client,
788        SqlRequest::Simple {
789            query: query.to_string(),
790        },
791        &mut response,
792    )
793    .await
794    .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
795
796    // Extract the result with rows (the user's single SELECT/SHOW query)
797    // Other results will be OK (from BEGIN, SET, COMMIT) or Err
798    for result in response.results {
799        match result {
800            SqlResult::Rows { rows, .. } => return Ok(rows),
801            SqlResult::Err { error, .. } => {
802                return Err(McpRequestError::QueryExecutionFailed(error.message));
803            }
804            SqlResult::Ok { .. } => continue,
805        }
806    }
807
808    Err(McpRequestError::QueryExecutionFailed(
809        "Query did not return any results".to_string(),
810    ))
811}
812
813/// Serialize rows to JSON and enforce the response size cap.
814///
815/// If the serialized response exceeds `max_size` bytes, returns an error
816/// telling the agent to narrow its query. This mirrors how the HTTP SQL
817/// endpoint handles `max_result_size` in sql.rs — fail cleanly rather
818/// than silently truncating.
819fn format_rows_response(
820    rows: Vec<Vec<serde_json::Value>>,
821    max_size: usize,
822) -> Result<McpResult, McpRequestError> {
823    let text =
824        serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
825
826    if text.len() > max_size {
827        return Err(McpRequestError::QueryExecutionFailed(format!(
828            "Response size ({} bytes) exceeds the {} byte limit. \
829             Use LIMIT or WHERE to narrow your query.",
830            text.len(),
831            max_size,
832        )));
833    }
834
835    Ok(McpResult::ToolContent(ToolContentResult {
836        content: vec![ContentBlock {
837            content_type: "text".to_string(),
838            text,
839        }],
840        is_error: false,
841    }))
842}
843
844async fn get_data_products(
845    client: &mut AuthedClient,
846    max_response_size: usize,
847) -> Result<McpResult, McpRequestError> {
848    debug!("Executing get_data_products");
849    let rows = execute_sql(client, DISCOVERY_QUERY).await?;
850    debug!("get_data_products returned {} rows", rows.len());
851
852    format_rows_response(rows, max_response_size)
853}
854
855async fn get_data_product_details(
856    client: &mut AuthedClient,
857    name: &str,
858    max_response_size: usize,
859) -> Result<McpResult, McpRequestError> {
860    debug!(name = %name, "Executing get_data_product_details");
861
862    let query = format!("{}{}", DETAILS_QUERY_PREFIX, escaped_string_literal(name));
863
864    let rows = execute_sql(client, &query).await?;
865
866    if rows.is_empty() {
867        return Err(McpRequestError::DataProductNotFound(name.to_string()));
868    }
869
870    format_rows_response(rows, max_response_size)
871}
872
873/// Parses a data product name and returns it safely quoted for SQL interpolation.
874///
875/// Uses the SQL parser to validate the name as an `UnresolvedItemName`, then
876/// formats it with `FormatMode::Stable` so every identifier part is
877/// double-quoted with proper escaping. This prevents SQL injection regardless
878/// of the input.
879fn safe_data_product_name(name: &str) -> Result<String, McpRequestError> {
880    let name = name.trim();
881    if name.is_empty() {
882        return Err(McpRequestError::QueryValidationFailed(
883            "Data product name cannot be empty".to_string(),
884        ));
885    }
886
887    let parsed = parse_item_name(name).map_err(|_| {
888        McpRequestError::QueryValidationFailed(format!(
889            "Invalid data product name: {}. Expected a valid object name, \
890             e.g. '\"database\".\"schema\".\"name\"' or 'my_view'",
891            name
892        ))
893    })?;
894
895    // Stable formatting forces all identifiers to be double-quoted,
896    // so SQL keywords and special characters cannot escape.
897    Ok(parsed.to_ast_string_stable())
898}
899
900/// Read rows from a data product. Issues a single read-only query.
901///
902/// When `cluster_override` is provided, sets the cluster explicitly.
903/// Otherwise the query runs on the session's default cluster.
904///
905/// The name is expected to come from `get_data_products()` / `get_data_product_details()`.
906/// The query runs inside a READ ONLY transaction, preventing mutations.
907async fn read_data_product(
908    client: &mut AuthedClient,
909    name: &str,
910    limit: u32,
911    cluster_override: Option<&str>,
912    max_response_size: usize,
913) -> Result<McpResult, McpRequestError> {
914    debug!(name = %name, limit = limit, cluster_override = ?cluster_override, "Executing read_data_product");
915
916    // Parse and safely quote the name for SQL interpolation.
917    let safe_name = safe_data_product_name(name)?;
918
919    // Lightweight existence check: verify the data product is visible in the
920    // catalog before running the read query. This gives a clean DataProductNotFound
921    // error for missing or inaccessible products (including RBAC revocations)
922    // without relying on fragile error code matching.
923    // TODO: Remove this extra round-trip once catalog errors get specific SQL
924    // error codes (see TODO in src/adapter/src/error.rs `fn code()`), then we
925    // can translate the query error directly.
926    let lookup_query = format!(
927        "SELECT 1 FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
928        escaped_string_literal(name)
929    );
930    let lookup_rows = execute_sql(client, &lookup_query).await?;
931    if lookup_rows.is_empty() {
932        return Err(McpRequestError::DataProductNotFound(name.to_string()));
933    }
934
935    let clamped_limit = limit.min(MAX_READ_LIMIT);
936
937    let read_query = match cluster_override {
938        Some(cluster) => format!(
939            "BEGIN READ ONLY; SET CLUSTER = {}; SELECT * FROM {} LIMIT {}; COMMIT;",
940            escaped_string_literal(cluster),
941            safe_name,
942            clamped_limit,
943        ),
944        // Single statement — skip explicit transaction for better performance.
945        None => format!("SELECT * FROM {} LIMIT {}", safe_name, clamped_limit),
946    };
947
948    let rows = execute_sql(client, &read_query).await?;
949
950    format_rows_response(rows, max_response_size)
951}
952
953/// Validates query is a single SELECT, SHOW, or EXPLAIN statement.
954fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
955    let sql = sql.trim();
956    if sql.is_empty() {
957        return Err(McpRequestError::QueryValidationFailed(
958            "Empty query".to_string(),
959        ));
960    }
961
962    // Parse the SQL to get AST
963    let stmts = parse(sql).map_err(|e| {
964        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
965    })?;
966
967    // Only allow a single statement
968    if stmts.len() != 1 {
969        return Err(McpRequestError::QueryValidationFailed(format!(
970            "Only one query allowed at a time. Found {} statements.",
971            stmts.len()
972        )));
973    }
974
975    // Allowlist: Only SELECT, SHOW, and EXPLAIN statements permitted
976    let stmt = &stmts[0];
977    use mz_sql_parser::ast::Statement;
978
979    match &stmt.ast {
980        Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
981            // Allowed - read-only operations
982            Ok(())
983        }
984        _ => Err(McpRequestError::QueryValidationFailed(
985            "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
986        )),
987    }
988}
989
990async fn execute_query(
991    client: &mut AuthedClient,
992    cluster: &str,
993    sql_query: &str,
994    max_response_size: usize,
995) -> Result<McpResult, McpRequestError> {
996    debug!(cluster = %cluster, "Executing user query");
997
998    validate_readonly_query(sql_query)?;
999
1000    // Use READ ONLY transaction to prevent modifications
1001    // Combine with SET CLUSTER (prometheus.rs:29-33 pattern)
1002    let combined_query = format!(
1003        "BEGIN READ ONLY; SET CLUSTER = {}; {}; COMMIT;",
1004        escaped_string_literal(cluster),
1005        sql_query
1006    );
1007
1008    let rows = execute_sql(client, &combined_query).await?;
1009
1010    format_rows_response(rows, max_response_size)
1011}
1012
1013async fn query_system_catalog(
1014    client: &mut AuthedClient,
1015    sql_query: &str,
1016    max_response_size: usize,
1017) -> Result<McpResult, McpRequestError> {
1018    debug!("Executing query_system_catalog");
1019
1020    // First validate it's a read-only query
1021    validate_readonly_query(sql_query)?;
1022
1023    // Then validate that query only references mz_* tables by parsing the SQL
1024    validate_system_catalog_query(sql_query)?;
1025
1026    // Single statement — skip explicit transaction for better performance.
1027    // Read-only safety is enforced by validate_readonly_query above.
1028    let rows = execute_sql(client, sql_query).await?;
1029
1030    format_rows_response(rows, max_response_size)
1031}
1032
/// Accumulates the table references found while walking a SQL AST.
struct TableReferenceCollector {
    /// One `(schema, table_name)` entry per collected reference; the schema is
    /// `None` when the reference was not schema-qualified.
    tables: Vec<(Option<String>, String)>,
    /// Names introduced by CTEs. References to these are not real tables and
    /// are left out of `tables`.
    cte_names: std::collections::BTreeSet<String>,
}

impl TableReferenceCollector {
    /// Creates a collector with no recorded tables and no known CTE names.
    fn new() -> Self {
        let tables = Vec::new();
        let cte_names = std::collections::BTreeSet::new();
        Self { tables, cte_names }
    }
}
1049
1050impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
1051    fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
1052        // Track CTE names so we don't treat them as table references
1053        self.cte_names
1054            .insert(cte.alias.name.as_str().to_lowercase());
1055        visit::visit_cte(self, cte);
1056    }
1057
1058    fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
1059        // Only visit actual table references in FROM/JOIN clauses, not function names
1060        if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
1061            match name {
1062                RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
1063                    let parts = &n.0;
1064                    if !parts.is_empty() {
1065                        let table_name = parts.last().unwrap().as_str().to_lowercase();
1066
1067                        // Skip if this is a CTE reference, not a real table
1068                        if self.cte_names.contains(&table_name) {
1069                            visit::visit_table_factor(self, table_factor);
1070                            return;
1071                        }
1072
1073                        // Extract schema if qualified (e.g., mz_catalog.mz_tables)
1074                        let schema = if parts.len() >= 2 {
1075                            Some(parts[parts.len() - 2].as_str().to_lowercase())
1076                        } else {
1077                            None
1078                        };
1079                        self.tables.push((schema, table_name));
1080                    }
1081                }
1082            }
1083        }
1084        visit::visit_table_factor(self, table_factor);
1085    }
1086}
1087
1088/// Validates that a query only references system catalog tables.
1089///
1090/// For SELECT statements, all table references must be in system schemas
1091/// (from `SYSTEM_SCHEMAS`, excluding `mz_unsafe`), and at least one system
1092/// table must be referenced (constant queries like `SELECT 1` are rejected
1093/// to prevent misuse of the observatory endpoint for arbitrary computation).
1094/// SHOW and EXPLAIN statements are allowed without table references.
1095fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
1096    // Parse the SQL to validate it
1097    let stmts = parse(sql).map_err(|e| {
1098        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
1099    })?;
1100
1101    if stmts.is_empty() {
1102        return Err(McpRequestError::QueryValidationFailed(
1103            "Empty query".to_string(),
1104        ));
1105    }
1106
1107    // Walk the AST to collect all table references
1108    let mut collector = TableReferenceCollector::new();
1109    for stmt in &stmts {
1110        collector.visit_statement(&stmt.ast);
1111    }
1112
1113    // Use the canonical system schema list, excluding mz_unsafe which contains
1114    // internal-only objects that should not be exposed to MCP clients.
1115    let is_allowed_schema =
1116        |s: &str| SYSTEM_SCHEMAS.contains(&s) && s != namespaces::MZ_UNSAFE_SCHEMA;
1117
1118    // Helper to check if a table reference is allowed
1119    let is_system_table = |(schema, table_name): &(Option<String>, String)| {
1120        match schema {
1121            // Explicitly qualified with allowed schema
1122            Some(s) => is_allowed_schema(s.as_str()),
1123            // Unqualified: allow if starts with mz_ (common Materialize system tables).
1124            // Note: user tables can also start with mz_, but unqualified references
1125            // resolve to the user's search_path first, so this is a best-effort check.
1126            None => table_name.starts_with("mz_"),
1127        }
1128    };
1129
1130    // Check that all table references are system tables
1131    let non_system_tables: Vec<String> = collector
1132        .tables
1133        .iter()
1134        .filter(|t| !is_system_table(t))
1135        .map(|(schema, table)| match schema {
1136            Some(s) => format!("{}.{}", s, table),
1137            None => table.clone(),
1138        })
1139        .collect();
1140
1141    if !non_system_tables.is_empty() {
1142        return Err(McpRequestError::QueryValidationFailed(format!(
1143            "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
1144            non_system_tables.join(", ")
1145        )));
1146    }
1147
1148    // SHOW and EXPLAIN statements don't reference tables in the AST, but are safe
1149    // read-only operations. Only require system table references for SELECT.
1150    use mz_sql_parser::ast::Statement;
1151    let is_select = stmts.iter().any(|s| matches!(&s.ast, Statement::Select(_)));
1152
1153    if is_select && (collector.tables.is_empty() || !collector.tables.iter().any(is_system_table)) {
1154        return Err(McpRequestError::QueryValidationFailed(
1155            "Query must reference at least one system catalog table".to_string(),
1156        ));
1157    }
1158
1159    Ok(())
1160}
1161
1162#[cfg(test)]
1163mod tests {
1164    use super::*;
1165
1166    #[mz_ore::test]
1167    fn test_validate_readonly_query_select() {
1168        assert!(validate_readonly_query("SELECT * FROM mz_tables").is_ok());
1169        assert!(validate_readonly_query("SELECT 1 + 2").is_ok());
1170        assert!(validate_readonly_query("  SELECT 1  ").is_ok());
1171    }
1172
1173    #[mz_ore::test]
1174    fn test_validate_readonly_query_subqueries() {
1175        // Simple subquery in WHERE clause
1176        assert!(
1177            validate_readonly_query(
1178                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1179            )
1180            .is_ok()
1181        );
1182
1183        // Subquery in FROM clause
1184        assert!(
1185            validate_readonly_query(
1186                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1187            )
1188            .is_ok()
1189        );
1190
1191        // Correlated subquery
1192        assert!(validate_readonly_query(
1193            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)"
1194        )
1195        .is_ok());
1196
1197        // Nested subqueries
1198        assert!(validate_readonly_query(
1199            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))"
1200        )
1201        .is_ok());
1202
1203        // Subquery with aggregation
1204        assert!(
1205            validate_readonly_query(
1206                "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)"
1207            )
1208            .is_ok()
1209        );
1210    }
1211
1212    #[mz_ore::test]
1213    fn test_validate_readonly_query_show() {
1214        assert!(validate_readonly_query("SHOW CLUSTERS").is_ok());
1215        assert!(validate_readonly_query("SHOW TABLES").is_ok());
1216    }
1217
1218    #[mz_ore::test]
1219    fn test_validate_readonly_query_explain() {
1220        assert!(validate_readonly_query("EXPLAIN SELECT 1").is_ok());
1221    }
1222
1223    #[mz_ore::test]
1224    fn test_validate_readonly_query_rejects_writes() {
1225        assert!(validate_readonly_query("INSERT INTO t VALUES (1)").is_err());
1226        assert!(validate_readonly_query("UPDATE t SET a = 1").is_err());
1227        assert!(validate_readonly_query("DELETE FROM t").is_err());
1228        assert!(validate_readonly_query("CREATE TABLE t (a INT)").is_err());
1229        assert!(validate_readonly_query("DROP TABLE t").is_err());
1230    }
1231
1232    #[mz_ore::test]
1233    fn test_validate_readonly_query_rejects_multiple() {
1234        assert!(validate_readonly_query("SELECT 1; SELECT 2").is_err());
1235    }
1236
1237    #[mz_ore::test]
1238    fn test_validate_readonly_query_rejects_empty() {
1239        assert!(validate_readonly_query("").is_err());
1240        assert!(validate_readonly_query("   ").is_err());
1241    }
1242
1243    #[mz_ore::test]
1244    fn test_validate_system_catalog_query_accepts_mz_tables() {
1245        assert!(validate_system_catalog_query("SELECT * FROM mz_tables").is_ok());
1246        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_comments").is_ok());
1247        assert!(
1248            validate_system_catalog_query(
1249                "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id"
1250            )
1251            .is_ok()
1252        );
1253    }
1254
1255    #[mz_ore::test]
1256    fn test_validate_system_catalog_query_subqueries() {
1257        // Subquery with mz_* tables
1258        assert!(
1259            validate_system_catalog_query(
1260                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
1261            )
1262            .is_ok()
1263        );
1264
1265        // Nested subqueries with mz_* tables
1266        assert!(validate_system_catalog_query(
1267            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))"
1268        )
1269        .is_ok());
1270
1271        // Subquery in FROM clause
1272        assert!(
1273            validate_system_catalog_query(
1274                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
1275            )
1276            .is_ok()
1277        );
1278
1279        // Reject subqueries that reference non-mz_* tables
1280        assert!(
1281            validate_system_catalog_query(
1282                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)"
1283            )
1284            .is_err()
1285        );
1286
1287        // Reject mixed references in nested subqueries
1288        assert!(validate_system_catalog_query(
1289            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)"
1290        )
1291        .is_err());
1292    }
1293
1294    #[mz_ore::test]
1295    fn test_validate_system_catalog_query_rejects_user_tables() {
1296        assert!(validate_system_catalog_query("SELECT * FROM user_data").is_err());
1297        assert!(validate_system_catalog_query("SELECT * FROM my_table").is_err());
1298        // Security: reject queries that mention mz_ in a non-table context
1299        assert!(
1300            validate_system_catalog_query("SELECT * FROM user_data WHERE 'mz_' IS NOT NULL")
1301                .is_err()
1302        );
1303    }
1304
1305    #[mz_ore::test]
1306    fn test_validate_system_catalog_query_allows_functions() {
1307        // Function names should not be treated as table references
1308        assert!(
1309            validate_system_catalog_query(
1310                "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1"
1311            )
1312            .is_ok()
1313        );
1314        assert!(validate_system_catalog_query("SELECT length(name) FROM mz_tables").is_ok());
1315        assert!(
1316            validate_system_catalog_query(
1317                "SELECT count(*) FROM mz_sources WHERE now() > created_at"
1318            )
1319            .is_ok()
1320        );
1321    }
1322
1323    #[mz_ore::test]
1324    fn test_validate_system_catalog_query_schema_qualified() {
1325        // Qualified with allowed schemas should work
1326        assert!(validate_system_catalog_query("SELECT * FROM mz_catalog.mz_tables").is_ok());
1327        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_sessions").is_ok());
1328        assert!(validate_system_catalog_query("SELECT * FROM pg_catalog.pg_type").is_ok());
1329        assert!(validate_system_catalog_query("SELECT * FROM information_schema.tables").is_ok());
1330
1331        // Qualified with disallowed schema should fail
1332        assert!(validate_system_catalog_query("SELECT * FROM public.user_table").is_err());
1333        assert!(validate_system_catalog_query("SELECT * FROM myschema.mytable").is_err());
1334
1335        // mz_unsafe is a system schema but explicitly blocked for MCP
1336        assert!(
1337            validate_system_catalog_query("SELECT * FROM mz_unsafe.mz_some_table").is_err(),
1338            "mz_unsafe schema should be blocked even though it is a system schema"
1339        );
1340
1341        // Mixed: system and user schemas should fail
1342        assert!(
1343            validate_system_catalog_query(
1344                "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"
1345            )
1346            .is_err()
1347        );
1348    }
1349
1350    #[mz_ore::test]
1351    fn test_validate_system_catalog_query_adversarial_cases() {
1352        // Try to sneak in user table via CTE
1353        assert!(
1354            validate_system_catalog_query(
1355                "WITH user_cte AS (SELECT * FROM user_data) \
1356                 SELECT * FROM mz_tables, user_cte"
1357            )
1358            .is_err(),
1359            "Should reject CTE referencing user table"
1360        );
1361
1362        // Complex multi-level CTE with user table buried deep
1363        assert!(
1364            validate_system_catalog_query(
1365                "WITH \
1366                   cte1 AS (SELECT * FROM mz_tables), \
1367                   cte2 AS (SELECT * FROM cte1), \
1368                   cte3 AS (SELECT * FROM user_data) \
1369                 SELECT * FROM cte2"
1370            )
1371            .is_err(),
1372            "Should reject CTE chain with user table"
1373        );
1374
1375        // Multiple joins - user table in the middle
1376        assert!(
1377            validate_system_catalog_query(
1378                "SELECT * FROM mz_tables t1 \
1379                 JOIN user_data u ON t1.id = u.id \
1380                 JOIN mz_sources s ON t1.id = s.id"
1381            )
1382            .is_err(),
1383            "Should reject multi-join with user table"
1384        );
1385
1386        // LEFT JOIN trying to hide user table
1387        assert!(
1388            validate_system_catalog_query(
1389                "SELECT * FROM mz_tables t \
1390                 LEFT JOIN user_data u ON t.id = u.table_id \
1391                 WHERE u.id IS NULL"
1392            )
1393            .is_err(),
1394            "Should reject LEFT JOIN with user table"
1395        );
1396
1397        // Nested subquery with user table in FROM
1398        assert!(
1399            validate_system_catalog_query(
1400                "SELECT * FROM mz_tables WHERE id IN \
1401                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)"
1402            )
1403            .is_err(),
1404            "Should reject nested subquery with user table"
1405        );
1406
1407        // UNION trying to mix system and user data
1408        assert!(
1409            validate_system_catalog_query(
1410                "SELECT name FROM mz_tables \
1411                 UNION \
1412                 SELECT name FROM user_data"
1413            )
1414            .is_err(),
1415            "Should reject UNION with user table"
1416        );
1417
1418        // UNION ALL variation
1419        assert!(
1420            validate_system_catalog_query(
1421                "SELECT id FROM mz_sources \
1422                 UNION ALL \
1423                 SELECT id FROM products"
1424            )
1425            .is_err(),
1426            "Should reject UNION ALL with user table"
1427        );
1428
1429        // Cross join with user table
1430        assert!(
1431            validate_system_catalog_query("SELECT * FROM mz_tables CROSS JOIN user_data").is_err(),
1432            "Should reject CROSS JOIN with user table"
1433        );
1434
1435        // Subquery in SELECT clause referencing user table
1436        assert!(
1437            validate_system_catalog_query(
1438                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t"
1439            )
1440            .is_err(),
1441            "Should reject subquery in SELECT with user table"
1442        );
1443
1444        // Try to use a schema name that looks similar to allowed ones
1445        assert!(
1446            validate_system_catalog_query("SELECT * FROM mz_catalogg.fake_table").is_err(),
1447            "Should reject typo-squatting schema name"
1448        );
1449        assert!(
1450            validate_system_catalog_query("SELECT * FROM mz_catalog_hack.fake_table").is_err(),
1451            "Should reject fake schema with mz_catalog prefix"
1452        );
1453
1454        // Lateral join with user table
1455        assert!(
1456            validate_system_catalog_query(
1457                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u"
1458            )
1459            .is_err(),
1460            "Should reject LATERAL join with user table"
1461        );
1462
1463        // Valid complex query - all system tables
1464        assert!(
1465            validate_system_catalog_query(
1466                "WITH \
1467                   tables AS (SELECT * FROM mz_tables), \
1468                   sources AS (SELECT * FROM mz_sources) \
1469                 SELECT t.name, s.name \
1470                 FROM tables t \
1471                 JOIN sources s ON t.id = s.id \
1472                 WHERE t.id IN (SELECT id FROM mz_columns)"
1473            )
1474            .is_ok(),
1475            "Should allow complex query with only system tables"
1476        );
1477
1478        // Valid UNION of system tables
1479        assert!(
1480            validate_system_catalog_query(
1481                "SELECT name FROM mz_tables \
1482                 UNION \
1483                 SELECT name FROM mz_sources"
1484            )
1485            .is_ok(),
1486            "Should allow UNION of system tables"
1487        );
1488    }
1489
1490    #[mz_ore::test]
1491    fn test_validate_system_catalog_query_rejects_constant_queries() {
1492        // SELECT without any table reference should be rejected — the observatory
1493        // endpoint is for system catalog queries, not arbitrary computation.
1494        assert!(
1495            validate_system_catalog_query("SELECT 1").is_err(),
1496            "Should reject constant SELECT with no table references"
1497        );
1498        assert!(
1499            validate_system_catalog_query("SELECT 1 + 2, 'hello'").is_err(),
1500            "Should reject constant expression SELECT"
1501        );
1502        assert!(
1503            validate_system_catalog_query("SELECT now()").is_err(),
1504            "Should reject function-only SELECT with no table references"
1505        );
1506    }
1507
1508    #[mz_ore::test]
1509    fn test_validate_system_catalog_query_rejects_mixed_tables() {
1510        assert!(
1511            validate_system_catalog_query(
1512                "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"
1513            )
1514            .is_err()
1515        );
1516    }
1517
1518    #[mz_ore::test]
1519    fn test_validate_system_catalog_query_allows_show() {
1520        // SHOW queries don't reference tables in the AST but are safe read-only ops
1521        assert!(
1522            validate_system_catalog_query("SHOW TABLES FROM mz_internal").is_ok(),
1523            "SHOW TABLES FROM mz_internal should be allowed"
1524        );
1525        assert!(
1526            validate_system_catalog_query("SHOW TABLES FROM mz_catalog").is_ok(),
1527            "SHOW TABLES FROM mz_catalog should be allowed"
1528        );
1529        assert!(
1530            validate_system_catalog_query("SHOW CLUSTERS").is_ok(),
1531            "SHOW CLUSTERS should be allowed"
1532        );
1533        assert!(
1534            validate_system_catalog_query("SHOW SOURCES").is_ok(),
1535            "SHOW SOURCES should be allowed"
1536        );
1537        assert!(
1538            validate_system_catalog_query("SHOW TABLES").is_ok(),
1539            "SHOW TABLES should be allowed"
1540        );
1541    }
1542
1543    #[mz_ore::test]
1544    fn test_validate_system_catalog_query_allows_explain() {
1545        assert!(
1546            validate_system_catalog_query("EXPLAIN SELECT * FROM mz_tables").is_ok(),
1547            "EXPLAIN of system table query should be allowed"
1548        );
1549        assert!(
1550            validate_system_catalog_query("EXPLAIN SELECT 1").is_ok(),
1551            "EXPLAIN SELECT 1 should be allowed"
1552        );
1553    }
1554
1555    // ── Query tool feature flag tests ──────────────────────────────────────
1556
1557    #[mz_ore::test(tokio::test)]
1558    async fn test_tools_list_agents_query_tool_disabled() {
1559        let result = handle_tools_list(McpEndpointType::Agents, false, 1_000_000)
1560            .await
1561            .unwrap();
1562        let McpResult::ToolsList(list) = result else {
1563            panic!("Expected ToolsList result");
1564        };
1565        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1566        assert!(
1567            tool_names.contains(&"get_data_products"),
1568            "get_data_products should always be present"
1569        );
1570        assert!(
1571            tool_names.contains(&"get_data_product_details"),
1572            "get_data_product_details should always be present"
1573        );
1574        assert!(
1575            tool_names.contains(&"read_data_product"),
1576            "read_data_product should always be present"
1577        );
1578        assert!(
1579            !tool_names.contains(&"query"),
1580            "query tool should be hidden when disabled"
1581        );
1582    }
1583
1584    #[mz_ore::test(tokio::test)]
1585    async fn test_tools_list_agents_query_tool_enabled() {
1586        let result = handle_tools_list(McpEndpointType::Agents, true, 1_000_000)
1587            .await
1588            .unwrap();
1589        let McpResult::ToolsList(list) = result else {
1590            panic!("Expected ToolsList result");
1591        };
1592        let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1593        assert!(
1594            tool_names.contains(&"get_data_products"),
1595            "get_data_products should always be present"
1596        );
1597        assert!(
1598            tool_names.contains(&"get_data_product_details"),
1599            "get_data_product_details should always be present"
1600        );
1601        assert!(
1602            tool_names.contains(&"read_data_product"),
1603            "read_data_product should always be present"
1604        );
1605        assert!(
1606            tool_names.contains(&"query"),
1607            "query tool should be present when enabled"
1608        );
1609    }
1610
1611    #[mz_ore::test(tokio::test)]
1612    async fn test_tools_list_observatory_unaffected_by_query_flag() {
1613        // Observatory endpoint should not be affected by the query tool flag
1614        for flag in [true, false] {
1615            let result = handle_tools_list(McpEndpointType::Observatory, flag, 1_000_000)
1616                .await
1617                .unwrap();
1618            let McpResult::ToolsList(list) = result else {
1619                panic!("Expected ToolsList result");
1620            };
1621            let tool_names: Vec<&str> = list.tools.iter().map(|t| t.name.as_str()).collect();
1622            assert!(
1623                tool_names.contains(&"query_system_catalog"),
1624                "query_system_catalog should always be present on observatory"
1625            );
1626            assert!(
1627                !tool_names.contains(&"query"),
1628                "query tool should never appear on observatory"
1629            );
1630        }
1631    }
1632
1633    // ── Data product name validation tests ─────────────────────────────
1634
1635    #[mz_ore::test]
1636    fn test_safe_data_product_name_valid() {
1637        // Fully qualified quoted identifiers
1638        assert_eq!(
1639            safe_data_product_name(r#""materialize"."public"."my_view""#).unwrap(),
1640            r#""materialize"."public"."my_view""#
1641        );
1642        // Two-part name
1643        assert_eq!(
1644            safe_data_product_name(r#""public"."my_view""#).unwrap(),
1645            r#""public"."my_view""#
1646        );
1647        // Unquoted name gets quoted in stable mode
1648        assert_eq!(safe_data_product_name("my_view").unwrap(), r#""my_view""#);
1649    }
1650
1651    #[mz_ore::test]
1652    fn test_safe_data_product_name_rejects_empty() {
1653        assert!(safe_data_product_name("").is_err());
1654        assert!(safe_data_product_name("   ").is_err());
1655    }
1656
1657    #[mz_ore::test]
1658    fn test_safe_data_product_name_rejects_sql_injection() {
1659        // Attempted injection via semicolon
1660        assert!(safe_data_product_name("my_view; DROP TABLE users").is_err());
1661        // Attempted injection via subquery
1662        assert!(safe_data_product_name("my_view UNION SELECT * FROM secrets").is_err());
1663        // Multiple table references via comma
1664        assert!(safe_data_product_name("my_view, secrets").is_err());
1665        // SQL keywords after name are rejected by the parser
1666        assert!(safe_data_product_name("my_view WHERE 1=1 --").is_err());
1667    }
1668
1669    // ── Origin validation tests ────────────────────────────────────────
1670
1671    #[mz_ore::test]
1672    fn test_validate_origin_no_header() {
1673        let headers = HeaderMap::new();
1674        assert!(validate_origin(&headers).is_none(), "No Origin = allow");
1675    }
1676
1677    #[mz_ore::test]
1678    fn test_validate_origin_matching() {
1679        let mut headers = HeaderMap::new();
1680        headers.insert(http::header::ORIGIN, "https://example.com".parse().unwrap());
1681        headers.insert(http::header::HOST, "example.com".parse().unwrap());
1682        assert!(
1683            validate_origin(&headers).is_none(),
1684            "Matching Origin = allow"
1685        );
1686    }
1687
1688    #[mz_ore::test]
1689    fn test_validate_origin_mismatch() {
1690        let mut headers = HeaderMap::new();
1691        headers.insert(http::header::ORIGIN, "https://evil.com".parse().unwrap());
1692        headers.insert(http::header::HOST, "example.com".parse().unwrap());
1693        assert!(
1694            validate_origin(&headers).is_some(),
1695            "Mismatched Origin = reject"
1696        );
1697    }
1698
1699    #[mz_ore::test]
1700    fn test_validate_origin_no_host() {
1701        let mut headers = HeaderMap::new();
1702        headers.insert(http::header::ORIGIN, "https://example.com".parse().unwrap());
1703        assert!(
1704            validate_origin(&headers).is_some(),
1705            "Origin with no Host = reject"
1706        );
1707    }
1708
1709    #[mz_ore::test]
1710    fn test_mcp_error_codes() {
1711        assert_eq!(
1712            McpRequestError::InvalidJsonRpcVersion.error_code(),
1713            error_codes::INVALID_REQUEST
1714        );
1715        assert_eq!(
1716            McpRequestError::MethodNotFound("test".to_string()).error_code(),
1717            error_codes::METHOD_NOT_FOUND
1718        );
1719        assert_eq!(
1720            McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
1721            error_codes::INTERNAL_ERROR
1722        );
1723    }
1724}