Skip to main content

mz_environmentd/http/
mcp.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Model Context Protocol (MCP) HTTP handlers.
11//!
12//! Exposes Materialize data products to AI agents via JSON-RPC 2.0 over HTTP POST.
13//!
14//! ## Endpoints
15//!
16//! - `/api/mcp/agents` - User data products for customer AI agents
17//! - `/api/mcp/observatory` - System catalog (`mz_*`) for troubleshooting
18//!
19//! ## Tools
20//!
21//! **Agents:** `get_data_products`, `get_data_product_details`, `query`
22//! **Observatory:** `query_system_catalog`
23//!
24//! Data products are discovered via `mz_internal.mz_mcp_data_products` system view.
25
26use anyhow::anyhow;
27use axum::Json;
28use axum::response::IntoResponse;
29use http::StatusCode;
30use mz_adapter_types::dyncfgs::{ENABLE_MCP_AGENTS, ENABLE_MCP_OBSERVATORY};
31use mz_sql::parse::parse;
32use mz_sql::session::metadata::SessionMetadata;
33use mz_sql_parser::ast::display::escaped_string_literal;
34use mz_sql_parser::ast::visit::{self, Visit};
35use mz_sql_parser::ast::{Raw, RawItemName};
36use serde::{Deserialize, Serialize};
37use serde_json::json;
38use thiserror::Error;
39use tracing::{debug, warn};
40
41use crate::http::AuthedClient;
42use crate::http::sql::{SqlRequest, SqlResponse, SqlResult, execute_request};
43
// To add a new tool: add entry to tools/list, add handler function, add dispatch case.

/// SQL issued by the `get_data_products` tool: selects every row of the
/// `mz_internal.mz_mcp_data_products` view.
const DISCOVERY_QUERY: &str = "SELECT * FROM mz_internal.mz_mcp_data_products";
46
/// MCP request errors, mapped to JSON-RPC error codes.
///
/// The `#[error]` strings become the `message` field of the JSON-RPC error
/// object sent back to the client.
#[derive(Debug, Error)]
enum McpRequestError {
    /// The request's `jsonrpc` field was not exactly `"2.0"`.
    #[error("Invalid JSON-RPC version: expected 2.0")]
    InvalidJsonRpcVersion,
    /// The request named a method this server does not implement.
    #[error("Method not found: {0}")]
    #[allow(dead_code)] // NOTE(review): the variant is constructed for `McpMethod::Unknown`, so this allow looks unnecessary
    MethodNotFound(String),
    /// The requested tool is not available on the endpoint that was called.
    #[error("Tool not found: {0}")]
    ToolNotFound(String),
    /// No row in the data product view matched the requested name.
    #[error("Data product not found: {0}")]
    DataProductNotFound(String),
    /// The supplied SQL was rejected before execution (parse error, not
    /// read-only, more than one statement, or disallowed table reference).
    #[error("Query validation failed: {0}")]
    QueryValidationFailed(String),
    /// The SQL layer returned an error or produced no row-returning result.
    #[error("Query execution failed: {0}")]
    QueryExecutionFailed(String),
    /// Any other failure, e.g. serializing result rows to JSON.
    #[error("Internal error: {0}")]
    Internal(#[from] anyhow::Error),
}
66
67impl McpRequestError {
68    fn error_code(&self) -> i32 {
69        match self {
70            Self::InvalidJsonRpcVersion => error_codes::INVALID_REQUEST,
71            Self::MethodNotFound(_) => error_codes::METHOD_NOT_FOUND,
72            Self::ToolNotFound(_) => error_codes::INVALID_PARAMS,
73            Self::DataProductNotFound(_) => error_codes::INVALID_PARAMS,
74            Self::QueryValidationFailed(_) => error_codes::INVALID_PARAMS,
75            Self::QueryExecutionFailed(_) | Self::Internal(_) => error_codes::INTERNAL_ERROR,
76        }
77    }
78
79    fn error_type(&self) -> &'static str {
80        match self {
81            Self::InvalidJsonRpcVersion => "InvalidRequest",
82            Self::MethodNotFound(_) => "MethodNotFound",
83            Self::ToolNotFound(_) => "ToolNotFound",
84            Self::DataProductNotFound(_) => "DataProductNotFound",
85            Self::QueryValidationFailed(_) => "ValidationError",
86            Self::QueryExecutionFailed(_) => "ExecutionError",
87            Self::Internal(_) => "InternalError",
88        }
89    }
90}
91
/// JSON-RPC 2.0 request. Requests have `id`; notifications don't.
#[derive(Debug, Deserialize)]
pub(crate) struct McpRequest {
    /// Protocol version string; validated to be `"2.0"` when the method is dispatched.
    jsonrpc: String,
    /// Request identifier echoed back in the response. `None` marks a notification.
    id: Option<serde_json::Value>,
    /// Method name and parameters, flattened from the request's `method` and
    /// `params` fields.
    #[serde(flatten)]
    method: McpMethod,
}
100
/// MCP method variants with their associated parameters.
///
/// Adjacently tagged: the JSON `method` field selects the variant and the
/// `params` field carries its payload.
#[derive(Debug, Deserialize)]
#[serde(tag = "method", content = "params")]
enum McpMethod {
    /// Initialize method - params accepted but not currently used
    #[serde(rename = "initialize")]
    Initialize(#[allow(dead_code)] InitializeParams),
    /// Request for the list of tools offered by the endpoint.
    #[serde(rename = "tools/list")]
    ToolsList,
    /// Invocation of a single tool with its arguments.
    #[serde(rename = "tools/call")]
    ToolsCall(ToolsCallParams),
    /// Catch-all for unknown methods (e.g. `notifications/initialized`)
    #[serde(other)]
    Unknown,
}
116
117impl std::fmt::Display for McpMethod {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        match self {
120            McpMethod::Initialize(_) => write!(f, "initialize"),
121            McpMethod::ToolsList => write!(f, "tools/list"),
122            McpMethod::ToolsCall(_) => write!(f, "tools/call"),
123            McpMethod::Unknown => write!(f, "unknown"),
124        }
125    }
126}
127
128#[derive(Debug, Deserialize)]
129struct InitializeParams {
130    /// Protocol version from client. Not currently validated but accepted for MCP compliance.
131    #[serde(rename = "protocolVersion")]
132    #[allow(dead_code)]
133    protocol_version: String,
134    /// Client capabilities. Not currently used but accepted for MCP compliance.
135    #[serde(default)]
136    #[allow(dead_code)]
137    capabilities: serde_json::Value,
138    /// Client information (name, version). Not currently used but accepted for MCP compliance.
139    #[serde(rename = "clientInfo")]
140    #[allow(dead_code)]
141    client_info: Option<ClientInfo>,
142}
143
144#[derive(Debug, Deserialize)]
145struct ClientInfo {
146    #[allow(dead_code)]
147    name: String,
148    #[allow(dead_code)]
149    version: String,
150}
151
/// Tool call parameters, deserialized via adjacently tagged enum.
/// Serde maps `name` to the variant and `arguments` to the variant's data.
/// Variant names are matched in snake_case (e.g. `get_data_products`).
#[derive(Debug, Deserialize)]
#[serde(tag = "name", content = "arguments")]
#[serde(rename_all = "snake_case")]
enum ToolsCallParams {
    // Agents endpoint tools
    // Carries a unit payload because the tool takes no arguments. The intent is
    // that MCP clients sending `"arguments": {}` still deserialize.
    // NOTE(review): confirm with a deserialization test that `{}` is accepted for `()`.
    GetDataProducts(#[serde(default)] ()),
    GetDataProductDetails(GetDataProductDetailsParams),
    Query(QueryParams),
    // Observatory endpoint tools
    QuerySystemCatalog(QuerySystemCatalogParams),
}
166
167impl std::fmt::Display for ToolsCallParams {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        match self {
170            ToolsCallParams::GetDataProducts(_) => write!(f, "get_data_products"),
171            ToolsCallParams::GetDataProductDetails(_) => write!(f, "get_data_product_details"),
172            ToolsCallParams::Query(_) => write!(f, "query"),
173            ToolsCallParams::QuerySystemCatalog(_) => write!(f, "query_system_catalog"),
174        }
175    }
176}
177
/// Arguments of the `get_data_product_details` tool.
#[derive(Debug, Deserialize)]
struct GetDataProductDetailsParams {
    /// Data product name; compared against the view's `object_name` column.
    name: String,
}
182
/// Arguments of the `query` tool.
#[derive(Debug, Deserialize)]
struct QueryParams {
    /// Cluster the query runs on; applied with `SET CLUSTER` before the query.
    cluster: String,
    /// The SQL statement to run; must pass read-only validation.
    sql_query: String,
}
188
/// Arguments of the `query_system_catalog` tool.
#[derive(Debug, Deserialize)]
struct QuerySystemCatalogParams {
    /// The SQL statement to run; must be read-only and reference only system
    /// catalog tables.
    sql_query: String,
}
193
/// JSON-RPC 2.0 response envelope.
///
/// `result` and `error` are each omitted from the JSON when `None`; the
/// handlers in this file set exactly one of them.
#[derive(Debug, Serialize)]
struct McpResponse {
    /// Always `"2.0"` in responses built by this file.
    jsonrpc: String,
    /// Identifier copied from the request.
    id: serde_json::Value,
    /// Successful outcome of the method.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<McpResult>,
    /// Failure description.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<McpError>,
}
203
/// Typed MCP response results.
///
/// Untagged: the active variant's fields are serialized directly as the
/// `result` object, with no variant name in the JSON.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum McpResult {
    /// Reply to `initialize`.
    Initialize(InitializeResult),
    /// Reply to `tools/list`.
    ToolsList(ToolsListResult),
    /// Reply to `tools/call`.
    ToolContent(ToolContentResult),
}
212
213#[derive(Debug, Serialize)]
214struct InitializeResult {
215    #[serde(rename = "protocolVersion")]
216    protocol_version: String,
217    capabilities: Capabilities,
218    #[serde(rename = "serverInfo")]
219    server_info: ServerInfo,
220}
221
/// Server capabilities advertised in the `initialize` reply.
#[derive(Debug, Serialize)]
struct Capabilities {
    /// Tool capability descriptor; `handle_initialize` sends an empty object.
    tools: serde_json::Value,
}
226
/// Server identification sent in the `initialize` reply.
#[derive(Debug, Serialize)]
struct ServerInfo {
    /// Server name, e.g. `materialize-mcp-agents`.
    name: String,
    /// Crate version taken from `CARGO_PKG_VERSION` at build time.
    version: String,
}
232
/// Payload returned for `tools/list`.
#[derive(Debug, Serialize)]
struct ToolsListResult {
    /// Tools available on the endpoint that was called.
    tools: Vec<ToolDefinition>,
}
237
238#[derive(Debug, Serialize)]
239struct ToolDefinition {
240    name: String,
241    description: String,
242    #[serde(rename = "inputSchema")]
243    input_schema: serde_json::Value,
244}
245
/// Payload returned for `tools/call`.
#[derive(Debug, Serialize)]
struct ToolContentResult {
    /// Content blocks produced by the tool; the handlers here emit exactly one.
    content: Vec<ContentBlock>,
}
250
/// One piece of tool output.
#[derive(Debug, Serialize)]
struct ContentBlock {
    /// Block kind, serialized as `type`; the handlers here always use `"text"`.
    #[serde(rename = "type")]
    content_type: String,
    /// Block body; the handlers here put pretty-printed JSON rows in it.
    text: String,
}
257
/// JSON-RPC 2.0 error codes.
mod error_codes {
    /// The JSON sent is not a valid request object.
    pub const INVALID_REQUEST: i32 = -32_600;
    /// The method does not exist or is not available.
    pub const METHOD_NOT_FOUND: i32 = -32_601;
    /// Invalid method parameters.
    pub const INVALID_PARAMS: i32 = -32_602;
    /// Internal JSON-RPC error.
    pub const INTERNAL_ERROR: i32 = -32_603;
}
265
/// JSON-RPC 2.0 error object placed in the response's `error` field.
#[derive(Debug, Serialize)]
struct McpError {
    /// Numeric JSON-RPC error code.
    code: i32,
    /// Human-readable description of the failure.
    message: String,
    /// Optional structured details; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<serde_json::Value>,
}
273
274impl From<McpRequestError> for McpError {
275    fn from(err: McpRequestError) -> Self {
276        McpError {
277            code: err.error_code(),
278            message: err.to_string(),
279            data: Some(json!({
280                "error_type": err.error_type(),
281            })),
282        }
283    }
284}
285
/// Which of the two MCP endpoints a request arrived on.
#[derive(Debug, Clone, Copy)]
enum McpEndpointType {
    /// `/api/mcp/agents`
    Agents,
    /// `/api/mcp/observatory`
    Observatory,
}

/// Renders the endpoint as its lowercase path segment.
impl std::fmt::Display for McpEndpointType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            McpEndpointType::Agents => "agents",
            McpEndpointType::Observatory => "observatory",
        };
        f.write_str(label)
    }
}
300
/// Agents endpoint: exposes user data products.
///
/// Thin HTTP handler that extracts the authenticated client and the JSON body
/// and forwards them to `handle_mcp_request` tagged as the agents endpoint.
pub async fn handle_mcp_agents(
    client: AuthedClient,
    Json(request): Json<McpRequest>,
) -> impl IntoResponse {
    handle_mcp_request(client, request, McpEndpointType::Agents).await
}
308
/// Observatory endpoint: exposes system catalog (mz_*) only.
///
/// Thin HTTP handler that extracts the authenticated client and the JSON body
/// and forwards them to `handle_mcp_request` tagged as the observatory endpoint.
pub async fn handle_mcp_observatory(
    client: AuthedClient,
    Json(request): Json<McpRequest>,
) -> impl IntoResponse {
    handle_mcp_request(client, request, McpEndpointType::Observatory).await
}
316
/// Shared entry point for both MCP endpoints.
///
/// Steps, in order:
/// 1. Return `503 Service Unavailable` if the endpoint's feature flag is off.
/// 2. Return `200 OK` with an empty body for notifications (requests without `id`).
/// 3. Otherwise run the request in a spawned task and return its JSON-RPC
///    response with `200 OK`; JSON-RPC errors are carried in the body, not in
///    the HTTP status.
async fn handle_mcp_request(
    mut client: AuthedClient,
    request: McpRequest,
    endpoint_type: McpEndpointType,
) -> impl IntoResponse {
    // Check the per-endpoint feature flag via a catalog snapshot, similar to frontend_peek.rs.
    let catalog = client.client.catalog_snapshot("mcp").await;
    let dyncfgs = catalog.system_config().dyncfgs();
    let enabled = match endpoint_type {
        McpEndpointType::Agents => ENABLE_MCP_AGENTS.get(dyncfgs),
        McpEndpointType::Observatory => ENABLE_MCP_OBSERVATORY.get(dyncfgs),
    };
    if !enabled {
        debug!(endpoint = %endpoint_type, "MCP endpoint disabled by feature flag");
        return StatusCode::SERVICE_UNAVAILABLE.into_response();
    }

    // The user name is only used for the debug log below.
    let user = client.client.session().user().name.clone();
    let is_notification = request.id.is_none();

    debug!(
        method = %request.method,
        endpoint = %endpoint_type,
        user = %user,
        is_notification = is_notification,
        "MCP request received"
    );

    // Handle notifications (no response needed)
    // Note that this happens before the JSON-RPC version is validated, so a
    // notification is acknowledged regardless of its `jsonrpc` value.
    if is_notification {
        debug!(method = %request.method, "Received notification (no response will be sent)");
        return StatusCode::OK.into_response();
    }

    // Spawn task for fault isolation
    // NOTE(review): what awaiting the handle does when the task panics depends
    // on `mz_ore::task::spawn` — confirm it matches the intended isolation.
    let response = mz_ore::task::spawn(|| "mcp_request", async move {
        handle_mcp_request_inner(&mut client, request, endpoint_type).await
    })
    .await;

    (StatusCode::OK, Json(response)).into_response()
}
359
360async fn handle_mcp_request_inner(
361    client: &mut AuthedClient,
362    request: McpRequest,
363    endpoint_type: McpEndpointType,
364) -> McpResponse {
365    // Extract request ID (guaranteed to be Some since notifications are filtered earlier)
366    let request_id = request.id.clone().unwrap_or(serde_json::Value::Null);
367
368    let result = handle_mcp_method(client, &request, endpoint_type).await;
369
370    match result {
371        Ok(result_value) => McpResponse {
372            jsonrpc: "2.0".to_string(),
373            id: request_id,
374            result: Some(result_value),
375            error: None,
376        },
377        Err(e) => {
378            // Log non-trivial errors
379            if !matches!(
380                e,
381                McpRequestError::MethodNotFound(_) | McpRequestError::InvalidJsonRpcVersion
382            ) {
383                warn!(error = %e, method = %request.method, "MCP method execution failed");
384            }
385            McpResponse {
386                jsonrpc: "2.0".to_string(),
387                id: request_id,
388                result: None,
389                error: Some(e.into()),
390            }
391        }
392    }
393}
394
395async fn handle_mcp_method(
396    client: &mut AuthedClient,
397    request: &McpRequest,
398    endpoint_type: McpEndpointType,
399) -> Result<McpResult, McpRequestError> {
400    // Validate JSON-RPC version
401    if request.jsonrpc != "2.0" {
402        return Err(McpRequestError::InvalidJsonRpcVersion);
403    }
404
405    // Handle different MCP methods using pattern matching
406    match &request.method {
407        McpMethod::Initialize(_) => {
408            debug!(endpoint = %endpoint_type, "Processing initialize");
409            handle_initialize(endpoint_type).await
410        }
411        McpMethod::ToolsList => {
412            debug!(endpoint = %endpoint_type, "Processing tools/list");
413            handle_tools_list(endpoint_type).await
414        }
415        McpMethod::ToolsCall(params) => {
416            debug!(tool = %params, endpoint = %endpoint_type, "Processing tools/call");
417            handle_tools_call(client, params, endpoint_type).await
418        }
419        McpMethod::Unknown => Err(McpRequestError::MethodNotFound(
420            "unknown method".to_string(),
421        )),
422    }
423}
424
425async fn handle_initialize(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
426    Ok(McpResult::Initialize(InitializeResult {
427        protocol_version: "2024-11-05".to_string(),
428        capabilities: Capabilities { tools: json!({}) },
429        server_info: ServerInfo {
430            name: format!("materialize-mcp-{}", endpoint_type),
431            version: env!("CARGO_PKG_VERSION").to_string(),
432        },
433    }))
434}
435
436async fn handle_tools_list(endpoint_type: McpEndpointType) -> Result<McpResult, McpRequestError> {
437    let tools = match endpoint_type {
438        McpEndpointType::Agents => {
439            vec![
440                ToolDefinition {
441                    name: "get_data_products".to_string(),
442                    description: "Discover all available real-time data views (data products) that represent business entities like customers, orders, products, etc. Each data product provides fresh, queryable data with defined schemas. Use this first to see what data is available before querying specific information.".to_string(),
443                    input_schema: json!({
444                        "type": "object",
445                        "properties": {},
446                        "required": []
447                    }),
448                },
449                ToolDefinition {
450                    name: "get_data_product_details".to_string(),
451                    description: "Get the complete schema and structure of a specific data product. This shows you exactly what fields are available, their types, and what data you can query. Use this after finding a data product from get_data_products() to understand how to query it.".to_string(),
452                    input_schema: json!({
453                        "type": "object",
454                        "properties": {
455                            "name": {
456                                "type": "string",
457                                "description": "Exact name of the data product from get_data_products() list"
458                            }
459                        },
460                        "required": ["name"]
461                    }),
462                },
463                ToolDefinition {
464                    name: "query".to_string(),
465                    description: "Execute SQL queries against real-time data products to retrieve current business information. Use standard PostgreSQL syntax. You can JOIN multiple data products together, but ONLY if they are all hosted on the same cluster. Always specify the cluster parameter from the data product details. This provides fresh, up-to-date results from materialized views.".to_string(),
466                    input_schema: json!({
467                        "type": "object",
468                        "properties": {
469                            "cluster": {
470                                "type": "string",
471                                "description": "Exact cluster name from the data product details - required for query execution"
472                            },
473                            "sql_query": {
474                                "type": "string",
475                                "description": "PostgreSQL-compatible SELECT statement to retrieve data. Use the fully qualified data product name exactly as provided (with double quotes). You can JOIN multiple data products, but only those on the same cluster."
476                            }
477                        },
478                        "required": ["cluster", "sql_query"]
479                    }),
480                },
481            ]
482        }
483        McpEndpointType::Observatory => {
484            vec![
485                ToolDefinition {
486                    name: "query_system_catalog".to_string(),
487                    description: "Query Materialize system catalog tables (mz_*) for troubleshooting and observability. Only mz_* tables are accessible.".to_string(),
488                    input_schema: json!({
489                        "type": "object",
490                        "properties": {
491                            "sql_query": {
492                                "type": "string",
493                                "description": "SQL query restricted to mz_* system tables"
494                            }
495                        },
496                        "required": ["sql_query"]
497                    }),
498                },
499            ]
500        }
501    };
502
503    Ok(McpResult::ToolsList(ToolsListResult { tools }))
504}
505
506async fn handle_tools_call(
507    client: &mut AuthedClient,
508    params: &ToolsCallParams,
509    endpoint_type: McpEndpointType,
510) -> Result<McpResult, McpRequestError> {
511    match (endpoint_type, params) {
512        (McpEndpointType::Agents, ToolsCallParams::GetDataProducts(_)) => {
513            get_data_products(client).await
514        }
515        (McpEndpointType::Agents, ToolsCallParams::GetDataProductDetails(p)) => {
516            get_data_product_details(client, &p.name).await
517        }
518        (McpEndpointType::Agents, ToolsCallParams::Query(p)) => {
519            execute_query(client, &p.cluster, &p.sql_query).await
520        }
521        (McpEndpointType::Observatory, ToolsCallParams::QuerySystemCatalog(p)) => {
522            query_system_catalog(client, &p.sql_query).await
523        }
524        // Tool called on wrong endpoint
525        (endpoint, tool) => Err(McpRequestError::ToolNotFound(format!(
526            "{} is not available on {} endpoint",
527            tool, endpoint
528        ))),
529    }
530}
531
532/// Execute SQL via `execute_request` from sql.rs.
533async fn execute_sql(
534    client: &mut AuthedClient,
535    query: &str,
536) -> Result<Vec<Vec<serde_json::Value>>, McpRequestError> {
537    let mut response = SqlResponse::new();
538
539    execute_request(
540        client,
541        SqlRequest::Simple {
542            query: query.to_string(),
543        },
544        &mut response,
545    )
546    .await
547    .map_err(|e| McpRequestError::QueryExecutionFailed(e.to_string()))?;
548
549    // Extract the result with rows (the user's single SELECT/SHOW query)
550    // Other results will be OK (from BEGIN, SET, COMMIT) or Err
551    for result in response.results {
552        match result {
553            SqlResult::Rows { rows, .. } => return Ok(rows),
554            SqlResult::Err { error, .. } => {
555                return Err(McpRequestError::QueryExecutionFailed(error.message));
556            }
557            SqlResult::Ok { .. } => continue,
558        }
559    }
560
561    Err(McpRequestError::QueryExecutionFailed(
562        "Query did not return any results".to_string(),
563    ))
564}
565
566async fn get_data_products(client: &mut AuthedClient) -> Result<McpResult, McpRequestError> {
567    debug!("Executing get_data_products");
568    let rows = execute_sql(client, DISCOVERY_QUERY).await?;
569    debug!("get_data_products returned {} rows", rows.len());
570    if rows.is_empty() {
571        warn!("No data products found - indexes must have comments");
572    }
573
574    let text =
575        serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
576
577    Ok(McpResult::ToolContent(ToolContentResult {
578        content: vec![ContentBlock {
579            content_type: "text".to_string(),
580            text,
581        }],
582    }))
583}
584
585async fn get_data_product_details(
586    client: &mut AuthedClient,
587    name: &str,
588) -> Result<McpResult, McpRequestError> {
589    debug!(name = %name, "Executing get_data_product_details");
590
591    let query = format!(
592        "SELECT * FROM mz_internal.mz_mcp_data_products WHERE object_name = {}",
593        escaped_string_literal(name)
594    );
595
596    let rows = execute_sql(client, &query).await?;
597
598    if rows.is_empty() {
599        return Err(McpRequestError::DataProductNotFound(name.to_string()));
600    }
601
602    let text =
603        serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
604
605    Ok(McpResult::ToolContent(ToolContentResult {
606        content: vec![ContentBlock {
607            content_type: "text".to_string(),
608            text,
609        }],
610    }))
611}
612
613/// Validates query is a single SELECT, SHOW, or EXPLAIN statement.
614fn validate_readonly_query(sql: &str) -> Result<(), McpRequestError> {
615    let sql = sql.trim();
616    if sql.is_empty() {
617        return Err(McpRequestError::QueryValidationFailed(
618            "Empty query".to_string(),
619        ));
620    }
621
622    // Parse the SQL to get AST
623    let stmts = parse(sql).map_err(|e| {
624        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
625    })?;
626
627    // Only allow a single statement
628    if stmts.len() != 1 {
629        return Err(McpRequestError::QueryValidationFailed(format!(
630            "Only one query allowed at a time. Found {} statements.",
631            stmts.len()
632        )));
633    }
634
635    // Allowlist: Only SELECT, SHOW, and EXPLAIN statements permitted
636    let stmt = &stmts[0];
637    use mz_sql_parser::ast::Statement;
638
639    match &stmt.ast {
640        Statement::Select(_) | Statement::Show(_) | Statement::ExplainPlan(_) => {
641            // Allowed - read-only operations
642            Ok(())
643        }
644        _ => Err(McpRequestError::QueryValidationFailed(
645            "Only SELECT, SHOW, and EXPLAIN statements are allowed".to_string(),
646        )),
647    }
648}
649
650async fn execute_query(
651    client: &mut AuthedClient,
652    cluster: &str,
653    sql_query: &str,
654) -> Result<McpResult, McpRequestError> {
655    debug!(cluster = %cluster, "Executing user query");
656
657    validate_readonly_query(sql_query)?;
658
659    // Use READ ONLY transaction to prevent modifications
660    // Combine with SET CLUSTER (prometheus.rs:29-33 pattern)
661    let combined_query = format!(
662        "BEGIN READ ONLY; SET CLUSTER = {}; {}; COMMIT;",
663        escaped_string_literal(cluster),
664        sql_query
665    );
666
667    let rows = execute_sql(client, &combined_query).await?;
668
669    let text =
670        serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
671
672    Ok(McpResult::ToolContent(ToolContentResult {
673        content: vec![ContentBlock {
674            content_type: "text".to_string(),
675            text,
676        }],
677    }))
678}
679
680async fn query_system_catalog(
681    client: &mut AuthedClient,
682    sql_query: &str,
683) -> Result<McpResult, McpRequestError> {
684    debug!("Executing query_system_catalog");
685
686    // First validate it's a read-only query
687    validate_readonly_query(sql_query)?;
688
689    // Then validate that query only references mz_* tables by parsing the SQL
690    validate_system_catalog_query(sql_query)?;
691
692    // Use READ ONLY transaction for defense-in-depth
693    let wrapped_query = format!("BEGIN READ ONLY; {}; COMMIT;", sql_query);
694    let rows = execute_sql(client, &wrapped_query).await?;
695
696    let text =
697        serde_json::to_string_pretty(&rows).map_err(|e| McpRequestError::Internal(anyhow!(e)))?;
698
699    Ok(McpResult::ToolContent(ToolContentResult {
700        content: vec![ContentBlock {
701            content_type: "text".to_string(),
702            text,
703        }],
704    }))
705}
706
/// AST visitor state that accumulates the table references found in a query,
/// together with their schema qualification.
struct TableReferenceCollector {
    /// One `(schema, table_name)` entry per collected reference; the schema is
    /// `None` when the reference was unqualified.
    tables: Vec<(Option<String>, String)>,
    /// Names introduced by CTEs; references to these are not real tables and
    /// are left out of `tables`.
    cte_names: std::collections::BTreeSet<String>,
}

impl TableReferenceCollector {
    /// Creates a collector with no recorded tables and no known CTE names.
    fn new() -> Self {
        let tables = Vec::new();
        let cte_names = std::collections::BTreeSet::new();
        Self { tables, cte_names }
    }
}
723
724impl<'ast> Visit<'ast, Raw> for TableReferenceCollector {
725    fn visit_cte(&mut self, cte: &'ast mz_sql_parser::ast::Cte<Raw>) {
726        // Track CTE names so we don't treat them as table references
727        self.cte_names
728            .insert(cte.alias.name.as_str().to_lowercase());
729        visit::visit_cte(self, cte);
730    }
731
732    fn visit_table_factor(&mut self, table_factor: &'ast mz_sql_parser::ast::TableFactor<Raw>) {
733        // Only visit actual table references in FROM/JOIN clauses, not function names
734        if let mz_sql_parser::ast::TableFactor::Table { name, .. } = table_factor {
735            match name {
736                RawItemName::Name(n) | RawItemName::Id(_, n, _) => {
737                    let parts = &n.0;
738                    if !parts.is_empty() {
739                        let table_name = parts.last().unwrap().as_str().to_lowercase();
740
741                        // Skip if this is a CTE reference, not a real table
742                        if self.cte_names.contains(&table_name) {
743                            visit::visit_table_factor(self, table_factor);
744                            return;
745                        }
746
747                        // Extract schema if qualified (e.g., mz_catalog.mz_tables)
748                        let schema = if parts.len() >= 2 {
749                            Some(parts[parts.len() - 2].as_str().to_lowercase())
750                        } else {
751                            None
752                        };
753                        self.tables.push((schema, table_name));
754                    }
755                }
756            }
757        }
758        visit::visit_table_factor(self, table_factor);
759    }
760}
761
762/// Validates query references only mz_* system catalog tables.
763fn validate_system_catalog_query(sql: &str) -> Result<(), McpRequestError> {
764    // Parse the SQL to validate it
765    let stmts = parse(sql).map_err(|e| {
766        McpRequestError::QueryValidationFailed(format!("Failed to parse SQL: {}", e))
767    })?;
768
769    if stmts.is_empty() {
770        return Err(McpRequestError::QueryValidationFailed(
771            "Empty query".to_string(),
772        ));
773    }
774
775    // Walk the AST to collect all table references
776    let mut collector = TableReferenceCollector::new();
777    for stmt in &stmts {
778        collector.visit_statement(&stmt.ast);
779    }
780
781    // Allowed system schemas
782    const ALLOWED_SCHEMAS: &[&str] = &[
783        "mz_catalog",
784        "mz_internal",
785        "pg_catalog",
786        "information_schema",
787    ];
788
789    // Helper to check if a table reference is allowed
790    let is_system_table = |(schema, table_name): &(Option<String>, String)| {
791        match schema {
792            // Explicitly qualified with allowed schema
793            Some(s) => ALLOWED_SCHEMAS.contains(&s.as_str()),
794            // Unqualified: allow if starts with mz_ (common Materialize system tables)
795            None => table_name.starts_with("mz_"),
796        }
797    };
798
799    // Check that all table references are system tables
800    let non_system_tables: Vec<String> = collector
801        .tables
802        .iter()
803        .filter(|t| !is_system_table(t))
804        .map(|(schema, table)| match schema {
805            Some(s) => format!("{}.{}", s, table),
806            None => table.clone(),
807        })
808        .collect();
809
810    if !non_system_tables.is_empty() {
811        return Err(McpRequestError::QueryValidationFailed(format!(
812            "Query references non-system tables: {}. Only system catalog tables (mz_*, pg_catalog, information_schema) are allowed.",
813            non_system_tables.join(", ")
814        )));
815    }
816
817    // Ensure at least one system table is referenced
818    if collector.tables.is_empty() || !collector.tables.iter().any(is_system_table) {
819        return Err(McpRequestError::QueryValidationFailed(
820            "Query must reference at least one system catalog table".to_string(),
821        ));
822    }
823
824    Ok(())
825}
826
827#[cfg(test)]
828mod tests {
829    use super::*;
830
831    #[mz_ore::test]
832    fn test_validate_readonly_query_select() {
833        assert!(validate_readonly_query("SELECT * FROM mz_tables").is_ok());
834        assert!(validate_readonly_query("SELECT 1 + 2").is_ok());
835        assert!(validate_readonly_query("  SELECT 1  ").is_ok());
836    }
837
838    #[mz_ore::test]
839    fn test_validate_readonly_query_subqueries() {
840        // Simple subquery in WHERE clause
841        assert!(
842            validate_readonly_query(
843                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
844            )
845            .is_ok()
846        );
847
848        // Subquery in FROM clause
849        assert!(
850            validate_readonly_query(
851                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
852            )
853            .is_ok()
854        );
855
856        // Correlated subquery
857        assert!(validate_readonly_query(
858            "SELECT * FROM mz_tables t WHERE EXISTS (SELECT 1 FROM mz_columns c WHERE c.id = t.id)"
859        )
860        .is_ok());
861
862        // Nested subqueries
863        assert!(validate_readonly_query(
864            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns WHERE type IN (SELECT name FROM mz_types))"
865        )
866        .is_ok());
867
868        // Subquery with aggregation
869        assert!(
870            validate_readonly_query(
871                "SELECT * FROM mz_tables WHERE id = (SELECT MAX(id) FROM mz_columns)"
872            )
873            .is_ok()
874        );
875    }
876
877    #[mz_ore::test]
878    fn test_validate_readonly_query_show() {
879        assert!(validate_readonly_query("SHOW CLUSTERS").is_ok());
880        assert!(validate_readonly_query("SHOW TABLES").is_ok());
881    }
882
883    #[mz_ore::test]
884    fn test_validate_readonly_query_explain() {
885        assert!(validate_readonly_query("EXPLAIN SELECT 1").is_ok());
886    }
887
888    #[mz_ore::test]
889    fn test_validate_readonly_query_rejects_writes() {
890        assert!(validate_readonly_query("INSERT INTO t VALUES (1)").is_err());
891        assert!(validate_readonly_query("UPDATE t SET a = 1").is_err());
892        assert!(validate_readonly_query("DELETE FROM t").is_err());
893        assert!(validate_readonly_query("CREATE TABLE t (a INT)").is_err());
894        assert!(validate_readonly_query("DROP TABLE t").is_err());
895    }
896
897    #[mz_ore::test]
898    fn test_validate_readonly_query_rejects_multiple() {
899        assert!(validate_readonly_query("SELECT 1; SELECT 2").is_err());
900    }
901
902    #[mz_ore::test]
903    fn test_validate_readonly_query_rejects_empty() {
904        assert!(validate_readonly_query("").is_err());
905        assert!(validate_readonly_query("   ").is_err());
906    }
907
908    #[mz_ore::test]
909    fn test_validate_system_catalog_query_accepts_mz_tables() {
910        assert!(validate_system_catalog_query("SELECT * FROM mz_tables").is_ok());
911        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_comments").is_ok());
912        assert!(
913            validate_system_catalog_query(
914                "SELECT * FROM mz_tables t JOIN mz_columns c ON t.id = c.id"
915            )
916            .is_ok()
917        );
918    }
919
920    #[mz_ore::test]
921    fn test_validate_system_catalog_query_subqueries() {
922        // Subquery with mz_* tables
923        assert!(
924            validate_system_catalog_query(
925                "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM mz_columns)"
926            )
927            .is_ok()
928        );
929
930        // Nested subqueries with mz_* tables
931        assert!(validate_system_catalog_query(
932            "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM mz_columns WHERE type IN (SELECT id FROM mz_types))"
933        )
934        .is_ok());
935
936        // Subquery in FROM clause
937        assert!(
938            validate_system_catalog_query(
939                "SELECT * FROM (SELECT * FROM mz_tables WHERE name LIKE 'test%') AS t"
940            )
941            .is_ok()
942        );
943
944        // Reject subqueries that reference non-mz_* tables
945        assert!(
946            validate_system_catalog_query(
947                "SELECT * FROM mz_tables WHERE id IN (SELECT table_id FROM user_data)"
948            )
949            .is_err()
950        );
951
952        // Reject mixed references in nested subqueries
953        assert!(validate_system_catalog_query(
954            "SELECT * FROM mz_tables WHERE id IN (SELECT id FROM (SELECT id FROM user_table) AS t)"
955        )
956        .is_err());
957    }
958
959    #[mz_ore::test]
960    fn test_validate_system_catalog_query_rejects_user_tables() {
961        assert!(validate_system_catalog_query("SELECT * FROM user_data").is_err());
962        assert!(validate_system_catalog_query("SELECT * FROM my_table").is_err());
963        // Security: reject queries that mention mz_ in a non-table context
964        assert!(
965            validate_system_catalog_query("SELECT * FROM user_data WHERE 'mz_' IS NOT NULL")
966                .is_err()
967        );
968    }
969
970    #[mz_ore::test]
971    fn test_validate_system_catalog_query_allows_functions() {
972        // Function names should not be treated as table references
973        assert!(
974            validate_system_catalog_query(
975                "SELECT date_part('year', now())::int4 AS y FROM mz_tables LIMIT 1"
976            )
977            .is_ok()
978        );
979        assert!(validate_system_catalog_query("SELECT length(name) FROM mz_tables").is_ok());
980        assert!(
981            validate_system_catalog_query(
982                "SELECT count(*) FROM mz_sources WHERE now() > created_at"
983            )
984            .is_ok()
985        );
986    }
987
988    #[mz_ore::test]
989    fn test_validate_system_catalog_query_schema_qualified() {
990        // Qualified with allowed schemas should work
991        assert!(validate_system_catalog_query("SELECT * FROM mz_catalog.mz_tables").is_ok());
992        assert!(validate_system_catalog_query("SELECT * FROM mz_internal.mz_sessions").is_ok());
993        assert!(validate_system_catalog_query("SELECT * FROM pg_catalog.pg_type").is_ok());
994        assert!(validate_system_catalog_query("SELECT * FROM information_schema.tables").is_ok());
995
996        // Qualified with disallowed schema should fail
997        assert!(validate_system_catalog_query("SELECT * FROM public.user_table").is_err());
998        assert!(validate_system_catalog_query("SELECT * FROM myschema.mytable").is_err());
999
1000        // Mixed: system and user schemas should fail
1001        assert!(
1002            validate_system_catalog_query(
1003                "SELECT * FROM mz_catalog.mz_tables JOIN public.user_data ON true"
1004            )
1005            .is_err()
1006        );
1007    }
1008
1009    #[mz_ore::test]
1010    fn test_validate_system_catalog_query_adversarial_cases() {
1011        // Try to sneak in user table via CTE
1012        assert!(
1013            validate_system_catalog_query(
1014                "WITH user_cte AS (SELECT * FROM user_data) \
1015                 SELECT * FROM mz_tables, user_cte"
1016            )
1017            .is_err(),
1018            "Should reject CTE referencing user table"
1019        );
1020
1021        // Complex multi-level CTE with user table buried deep
1022        assert!(
1023            validate_system_catalog_query(
1024                "WITH \
1025                   cte1 AS (SELECT * FROM mz_tables), \
1026                   cte2 AS (SELECT * FROM cte1), \
1027                   cte3 AS (SELECT * FROM user_data) \
1028                 SELECT * FROM cte2"
1029            )
1030            .is_err(),
1031            "Should reject CTE chain with user table"
1032        );
1033
1034        // Multiple joins - user table in the middle
1035        assert!(
1036            validate_system_catalog_query(
1037                "SELECT * FROM mz_tables t1 \
1038                 JOIN user_data u ON t1.id = u.id \
1039                 JOIN mz_sources s ON t1.id = s.id"
1040            )
1041            .is_err(),
1042            "Should reject multi-join with user table"
1043        );
1044
1045        // LEFT JOIN trying to hide user table
1046        assert!(
1047            validate_system_catalog_query(
1048                "SELECT * FROM mz_tables t \
1049                 LEFT JOIN user_data u ON t.id = u.table_id \
1050                 WHERE u.id IS NULL"
1051            )
1052            .is_err(),
1053            "Should reject LEFT JOIN with user table"
1054        );
1055
1056        // Nested subquery with user table in FROM
1057        assert!(
1058            validate_system_catalog_query(
1059                "SELECT * FROM mz_tables WHERE id IN \
1060                 (SELECT table_id FROM (SELECT * FROM user_data) AS u)"
1061            )
1062            .is_err(),
1063            "Should reject nested subquery with user table"
1064        );
1065
1066        // UNION trying to mix system and user data
1067        assert!(
1068            validate_system_catalog_query(
1069                "SELECT name FROM mz_tables \
1070                 UNION \
1071                 SELECT name FROM user_data"
1072            )
1073            .is_err(),
1074            "Should reject UNION with user table"
1075        );
1076
1077        // UNION ALL variation
1078        assert!(
1079            validate_system_catalog_query(
1080                "SELECT id FROM mz_sources \
1081                 UNION ALL \
1082                 SELECT id FROM products"
1083            )
1084            .is_err(),
1085            "Should reject UNION ALL with user table"
1086        );
1087
1088        // Cross join with user table
1089        assert!(
1090            validate_system_catalog_query("SELECT * FROM mz_tables CROSS JOIN user_data").is_err(),
1091            "Should reject CROSS JOIN with user table"
1092        );
1093
1094        // Subquery in SELECT clause referencing user table
1095        assert!(
1096            validate_system_catalog_query(
1097                "SELECT t.*, (SELECT COUNT(*) FROM user_data) AS cnt FROM mz_tables t"
1098            )
1099            .is_err(),
1100            "Should reject subquery in SELECT with user table"
1101        );
1102
1103        // Try to use a schema name that looks similar to allowed ones
1104        assert!(
1105            validate_system_catalog_query("SELECT * FROM mz_catalogg.fake_table").is_err(),
1106            "Should reject typo-squatting schema name"
1107        );
1108        assert!(
1109            validate_system_catalog_query("SELECT * FROM mz_catalog_hack.fake_table").is_err(),
1110            "Should reject fake schema with mz_catalog prefix"
1111        );
1112
1113        // Lateral join with user table
1114        assert!(
1115            validate_system_catalog_query(
1116                "SELECT * FROM mz_tables t, LATERAL (SELECT * FROM user_data WHERE id = t.id) u"
1117            )
1118            .is_err(),
1119            "Should reject LATERAL join with user table"
1120        );
1121
1122        // Valid complex query - all system tables
1123        assert!(
1124            validate_system_catalog_query(
1125                "WITH \
1126                   tables AS (SELECT * FROM mz_tables), \
1127                   sources AS (SELECT * FROM mz_sources) \
1128                 SELECT t.name, s.name \
1129                 FROM tables t \
1130                 JOIN sources s ON t.id = s.id \
1131                 WHERE t.id IN (SELECT id FROM mz_columns)"
1132            )
1133            .is_ok(),
1134            "Should allow complex query with only system tables"
1135        );
1136
1137        // Valid UNION of system tables
1138        assert!(
1139            validate_system_catalog_query(
1140                "SELECT name FROM mz_tables \
1141                 UNION \
1142                 SELECT name FROM mz_sources"
1143            )
1144            .is_ok(),
1145            "Should allow UNION of system tables"
1146        );
1147    }
1148
1149    #[mz_ore::test]
1150    fn test_validate_system_catalog_query_rejects_mixed_tables() {
1151        assert!(
1152            validate_system_catalog_query(
1153                "SELECT * FROM mz_tables t JOIN user_data u ON t.id = u.table_id"
1154            )
1155            .is_err()
1156        );
1157    }
1158
1159    #[mz_ore::test]
1160    fn test_mcp_error_codes() {
1161        assert_eq!(
1162            McpRequestError::InvalidJsonRpcVersion.error_code(),
1163            error_codes::INVALID_REQUEST
1164        );
1165        assert_eq!(
1166            McpRequestError::MethodNotFound("test".to_string()).error_code(),
1167            error_codes::METHOD_NOT_FOUND
1168        );
1169        assert_eq!(
1170            McpRequestError::QueryExecutionFailed("test".to_string()).error_code(),
1171            error_codes::INTERNAL_ERROR
1172        );
1173    }
1174}