// mz_lsp_server/backend.rs

// Copyright (c) 2023 Eyal Kalderon
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
//
// Portions of this file are derived from the tower-lsp project. The original source
// code was retrieved on 10/02/2023 from:
//
//     https://github.com/ebkalderon/tower-lsp/blob/cc4c858/examples/stdio.rs
//
// The original source code is subject to the terms of the <APACHE|MIT> license, a copy
// of which can be found in the LICENSE file at the root of this repository.

19use ::serde::Deserialize;
20use mz_ore::collections::HashMap;
21use mz_sql_lexer::keywords::Keyword;
22use mz_sql_lexer::lexer::{self, Token};
23use mz_sql_parser::ast::display::FormatMode;
24use mz_sql_parser::ast::{Raw, Statement, statement_kind_label_value};
25use mz_sql_parser::parser::parse_statements;
26use mz_sql_pretty::PrettyConfig;
27use regex::Regex;
28use ropey::Rope;
29use serde::Serialize;
30use serde_json::{Value, json};
31use tokio::sync::Mutex;
32use tower_lsp::jsonrpc::{Error, ErrorCode, Result};
33use tower_lsp::lsp_types::*;
34use tower_lsp::{Client, LanguageServer};
35
36use crate::{PKG_NAME, PKG_VERSION};
37
/// Line width that the [LanguageServer::formatting] implementation falls back to
/// when pretty-printing SQL.
pub const DEFAULT_FORMATTING_WIDTH: usize = 100;

41/// This is a re-implementation of [mz_sql_parser::parser::StatementParseResult]
42/// but replacing the sql code with a rope.
43#[derive(Debug)]
44pub struct ParseResult {
45    /// Abstract Syntax Trees (AST) for each of the SQL statements
46    /// in a file.
47    pub asts: Vec<Statement<Raw>>,
48    /// Text handler for big files.
49    pub rope: Rope,
50}
51
52#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
53/// Represents the structure a client uses to understand
54/// statement's kind and sql content.
55pub struct ExecuteCommandParseStatement {
56    /// The sql content in the statement
57    pub sql: String,
58    /// The type of statement.
59    /// Represents the String version of [Statement].
60    pub kind: String,
61}
62
63#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
64/// Represents the response from the parse command.
65pub struct ExecuteCommandParseResponse {
66    /// Contains all the valid SQL statements.
67    pub statements: Vec<ExecuteCommandParseStatement>,
68}
69
70/// Represents the completion items that will
71/// be returned to the client when requested.
72#[derive(Debug)]
73pub struct Completions {
74    /// Contains the completion items
75    /// after a SELECT token.
76    pub select: Vec<CompletionItem>,
77    /// Contains the completion items for
78    /// after a FROM token.
79    pub from: Vec<CompletionItem>,
80}
81
82/// The [Backend] struct implements the [LanguageServer] trait, and thus must provide implementations for its methods.
83/// Most imporant methods includes:
84/// - `initialize`: sets up the server.
85/// - `did_open`: logs when a file is opened and triggers an `on_change` method.
86/// - `did_save`, `did_close`: log messages indicating file actions.
87/// - `completion`: Provides completion suggestions. WIP.
88/// - `code_lens`: Offers in-editor commands. WIP.
89///
90/// Most of the `did_` methods re-route the request to the private method `on_change`
91/// within the `Backend` struct. This method is triggered whenever there's a change
92/// in the file, and it parses the content using `mz_sql_parser`.
93/// Depending on the parse result, it either sneds the logs the results or any encountered errors.
94#[derive(Debug)]
95pub struct Backend {
96    /// Handles the communication to the client.
97    /// Logs and results must be sent through
98    /// the client at the end of each capability.
99    pub client: Client,
100
101    /// Contains parsing results for each open file.
102    /// Instead of retrieving the last version from the file
103    /// each time a command, like formatting, is executed,
104    /// we use the most recent parsing results stored here.
105    /// Reading from the file would access old content.
106    /// E.g. The user formats or performs an action
107    /// prior to save the file.
108    pub parse_results: Mutex<HashMap<Url, ParseResult>>,
109
110    /// Contains the latest content for each file.
111    pub content: Mutex<HashMap<Url, Rope>>,
112
113    /// Formatting width to use in mz- prettier
114    pub formatting_width: Mutex<usize>,
115
116    /// Schema available in the client
117    /// used for completion suggestions.
118    pub schema: Mutex<Option<Schema>>,
119
120    /// Completion suggestion to return
121    /// to the client when requested.
122    pub completions: Mutex<Completions>,
123}
124
125/// Represents a column from an [ObjectType
126#[derive(Debug, Clone, Deserialize)]
127pub struct SchemaObjectColumn {
128    /// Represents the column's name.
129    pub name: String,
130    /// Represents the column's type.
131    #[serde(rename = "type")]
132    pub typ: String,
133}
134
135#[derive(Debug, Clone, Deserialize)]
136#[serde(rename_all = "camelCase")]
137/// Represents each possible object type admissible by the LSP.
138pub enum ObjectType {
139    /// Represents a materialized view.
140    MaterializedView,
141    /// Represents a view.
142    View,
143    /// Represents a source.
144    Source,
145    /// Represents a table.
146    Table,
147    /// Represents a sink.
148    Sink,
149}
150
151impl std::fmt::Display for ObjectType {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        match self {
154            ObjectType::MaterializedView => write!(f, "Materialized View"),
155            ObjectType::View => write!(f, "View"),
156            ObjectType::Source => write!(f, "Source"),
157            ObjectType::Table => write!(f, "Table"),
158            ObjectType::Sink => write!(f, "Sink"),
159        }
160    }
161}
162
163/// Represents a Materialize object present in the schema,
164/// and its columns.
165///
166/// E.g. a table, view, source, etc.
167#[derive(Debug, Clone, Deserialize)]
168pub struct SchemaObject {
169    /// Represents the object type.
170    #[serde(rename = "type")]
171    pub typ: ObjectType,
172    /// Represents the object name.
173    pub name: String,
174    /// Contains all the columns available in the object.
175    pub columns: Vec<SchemaObjectColumn>,
176}
177
178/// Represents the current schema, database and all
179/// its objects the client is using.
180///
181/// This is later used to return completion items to the client.
182#[derive(Debug, Clone, Deserialize)]
183pub struct Schema {
184    /// Represents the schema name.
185    pub schema: String,
186    /// Represents the database name.
187    pub database: String,
188    /// Contains all the user objects (tables, views, sources, etc.)
189    /// available in the current database/schema.
190    pub objects: Vec<SchemaObject>,
191}
192
193/// Contains customizable options send by the client.
194#[derive(Debug, Deserialize)]
195#[serde(rename_all = "camelCase")]
196pub struct InitializeOptions {
197    /// Represents the width used to format text using [mz_sql_pretty].
198    pub formatting_width: Option<usize>,
199    /// Represents the current schema available in the client.
200    pub schema: Option<Schema>,
201}
202
203#[tower_lsp::async_trait]
204impl LanguageServer for Backend {
205    async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
206        // Load the formatting width and schema option sent by the client.
207        if let Some(value_options) = params.initialization_options {
208            match serde_json::from_value(value_options) {
209                Ok(options) => {
210                    let options: InitializeOptions = options;
211                    if let Some(formatting_width) = options.formatting_width {
212                        let mut formatting_width_guard = self.formatting_width.lock().await;
213                        *formatting_width_guard = formatting_width;
214                    }
215
216                    if let Some(schema) = options.schema {
217                        let mut schema_guard = self.schema.lock().await;
218                        *schema_guard = Some(schema.clone());
219                        let mut completions = self.completions.lock().await;
220                        *completions = self.build_completion_items(schema);
221                    };
222                }
223                Err(err) => {
224                    self.client
225                        .log_message(
226                            MessageType::INFO,
227                            format!("Initialization options are erroneus: {:?}", err.to_string()),
228                        )
229                        .await;
230                }
231            };
232        }
233
234        Ok(InitializeResult {
235            server_info: Some(ServerInfo {
236                name: PKG_NAME.clone(),
237                version: Some(PKG_VERSION.clone()),
238            }),
239            offset_encoding: None,
240            capabilities: ServerCapabilities {
241                document_formatting_provider: Some(tower_lsp::lsp_types::OneOf::Left(true)),
242                text_document_sync: Some(TextDocumentSyncCapability::Kind(
243                    TextDocumentSyncKind::FULL,
244                )),
245                execute_command_provider: Some(ExecuteCommandOptions {
246                    commands: vec!["parse".to_string()],
247                    work_done_progress_options: WorkDoneProgressOptions {
248                        work_done_progress: None,
249                    },
250                }),
251                completion_provider: Some(CompletionOptions {
252                    resolve_provider: Some(false),
253                    trigger_characters: Some(vec![".".to_string()]),
254                    work_done_progress_options: Default::default(),
255                    all_commit_characters: None,
256                    completion_item: None,
257                }),
258                workspace: Some(WorkspaceServerCapabilities {
259                    workspace_folders: Some(WorkspaceFoldersServerCapabilities {
260                        supported: Some(true),
261                        change_notifications: Some(OneOf::Left(true)),
262                    }),
263                    file_operations: None,
264                }),
265                ..ServerCapabilities::default()
266            },
267        })
268    }
269
270    async fn initialized(&self, _: InitializedParams) {
271        self.client
272            .log_message(MessageType::INFO, "initialized!")
273            .await;
274    }
275
276    async fn shutdown(&self) -> Result<()> {
277        Ok(())
278    }
279
280    async fn did_change_workspace_folders(&self, _: DidChangeWorkspaceFoldersParams) {
281        self.client
282            .log_message(MessageType::INFO, "workspace folders changed!")
283            .await;
284    }
285
286    async fn did_change_configuration(&self, _: DidChangeConfigurationParams) {
287        self.client
288            .log_message(MessageType::INFO, "configuration changed!")
289            .await;
290    }
291
292    async fn did_change_watched_files(&self, _: DidChangeWatchedFilesParams) {
293        self.client
294            .log_message(MessageType::INFO, "watched files have changed!")
295            .await;
296    }
297
298    /// Executes a single command and returns the response. Def: [workspace/executeCommand](https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#workspace_executeCommand)
299    ///
300    /// Commands implemented:
301    ///
302    /// * parse: returns multiple valid statements from a single sql code.
303    async fn execute_command(&self, command_params: ExecuteCommandParams) -> Result<Option<Value>> {
304        match command_params.command.as_str() {
305            "parse" => {
306                let json_args = command_params.arguments.get(0);
307
308                if let Some(json_args) = json_args {
309                    let args = serde_json::from_value::<String>(json_args.clone())
310                        .map_err(|_| build_error("Error deserializing parse args as String."))?;
311                    let statements = parse_statements(&args)
312                        .map_err(|_| build_error("Error parsing the statements."))?;
313
314                    // Transform raw statements to splitted statements
315                    // and infere the kind.
316                    // E.g. if it is a select or a create_table statement.
317                    let parse_statements: Vec<ExecuteCommandParseStatement> = statements
318                        .iter()
319                        .map(|x| ExecuteCommandParseStatement {
320                            kind: statement_kind_label_value(x.ast.clone().into()).to_string(),
321                            sql: x.sql.to_string(),
322                        })
323                        .collect();
324
325                    return Ok(Some(json!(ExecuteCommandParseResponse {
326                        statements: parse_statements
327                    })));
328                } else {
329                    return Err(build_error("Missing command args."));
330                }
331            }
332            "optionsUpdate" => {
333                let json_args = command_params.arguments.get(0);
334
335                if let Some(json_args) = json_args {
336                    let args = serde_json::from_value::<InitializeOptions>(json_args.clone())
337                        .map_err(|_| {
338                            build_error("Error deserializing parse args as InitializeOptions.")
339                        })?;
340
341                    if let Some(formatting_width) = args.formatting_width {
342                        let mut formatting_width_guard = self.formatting_width.lock().await;
343                        *formatting_width_guard = formatting_width;
344                    }
345
346                    if let Some(schema) = args.schema {
347                        let mut schema_guard = self.schema.lock().await;
348                        *schema_guard = Some(schema.clone());
349                        let mut completions = self.completions.lock().await;
350                        *completions = self.build_completion_items(schema);
351                    }
352
353                    return Ok(None);
354                } else {
355                    return Err(build_error("Missing command args."));
356                }
357            }
358            _ => {
359                return Err(build_error("Unknown command."));
360            }
361        }
362    }
363
364    async fn did_open(&self, params: DidOpenTextDocumentParams) {
365        self.client
366            .log_message(MessageType::INFO, "file opened!")
367            .await;
368
369        self.parse(TextDocumentItem {
370            uri: params.text_document.uri,
371            text: params.text_document.text,
372            version: params.text_document.version,
373        })
374        .await
375    }
376
377    async fn did_change(&self, params: DidChangeTextDocumentParams) {
378        self.client
379            .log_message(MessageType::INFO, "file changed!")
380            .await;
381
382        self.parse(TextDocumentItem {
383            uri: params.text_document.uri,
384            text: params.content_changes[0].text.clone(),
385            version: params.text_document.version,
386        })
387        .await
388    }
389
390    async fn did_save(&self, _: DidSaveTextDocumentParams) {
391        self.client
392            .log_message(MessageType::INFO, "file saved!")
393            .await;
394    }
395
396    async fn did_close(&self, _: DidCloseTextDocumentParams) {
397        self.client
398            .log_message(MessageType::INFO, "file closed!")
399            .await;
400    }
401
402    async fn code_lens(&self, _params: CodeLensParams) -> Result<Option<Vec<CodeLens>>> {
403        let _lenses: Vec<CodeLens> = vec![CodeLens {
404            range: Range {
405                start: Position::new(0, 0),
406                end: Position::new(0, 0),
407            },
408            command: Some(Command {
409                title: "Run".to_string(),
410                command: "materialize.run".to_string(),
411                arguments: None,
412            }),
413            data: None,
414        }];
415
416        // TODO: Re enable when position is correct.
417        // Ok(Some(lenses))
418        Ok(None)
419    }
420
421    /// Completion implementation.
422    async fn completion(&self, params: CompletionParams) -> Result<Option<CompletionResponse>> {
423        let uri = params.text_document_position.text_document.uri;
424        let position = params.text_document_position.position;
425
426        let content = self.content.lock().await;
427        let content = content.get(&uri);
428
429        if let Some(content) = content {
430            // Get the lex token.
431            let lex_results = lexer::lex(&content.to_string())
432                .map_err(|_| build_error("Error getting lex tokens."))?;
433            let offset = position_to_offset(position, content)
434                .ok_or_else(|| build_error("Error getting completion offset."))?;
435
436            let last_keyword = lex_results
437                .iter()
438                .filter_map(|x| {
439                    if x.offset < offset {
440                        match x.kind {
441                            Token::Keyword(k) => match k {
442                                Keyword::Select => Some(k),
443                                Keyword::From => Some(k),
444                                _ => None,
445                            },
446                            // Skip the rest for now.
447                            _ => None,
448                        }
449                    } else {
450                        None
451                    }
452                })
453                .next_back();
454
455            if let Some(keyword) = last_keyword {
456                return match keyword {
457                    Keyword::Select => {
458                        let completions = self.completions.lock().await;
459                        let select_completions = completions.select.clone();
460                        Ok(Some(CompletionResponse::Array(select_completions)))
461                    }
462                    Keyword::From => {
463                        let completions = self.completions.lock().await;
464                        let from_completions = completions.from.clone();
465                        Ok(Some(CompletionResponse::Array(from_completions)))
466                    }
467                    _ => Ok(None),
468                };
469            } else {
470                return Ok(None);
471            }
472        } else {
473            return Ok(None);
474        }
475    }
476
477    /// Formats the code using [mz_sql_pretty].
478    ///
479    /// Implements the [`textDocument/formatting`](https://microsoft.github.io/language-server-protocol/specification#textDocument_formatting) language feature.
480    async fn formatting(&self, params: DocumentFormattingParams) -> Result<Option<Vec<TextEdit>>> {
481        let locked_map = self.parse_results.lock().await;
482        let width = self.formatting_width.lock().await;
483
484        if let Some(parse_result) = locked_map.get(&params.text_document.uri) {
485            let pretty = parse_result
486                .asts
487                .iter()
488                .map(|ast| {
489                    mz_sql_pretty::to_pretty(
490                        ast,
491                        PrettyConfig {
492                            width: *width,
493                            format_mode: FormatMode::Simple,
494                        },
495                    )
496                })
497                .collect::<Vec<String>>()
498                .join("\n");
499            let rope = &parse_result.rope;
500
501            return Ok(Some(vec![TextEdit {
502                new_text: pretty,
503                range: Range {
504                    // TODO: Remove unwraps.
505                    start: offset_to_position(0, rope).unwrap(),
506                    end: offset_to_position(rope.len_chars(), rope).unwrap(),
507                },
508            }]));
509        } else {
510            return Ok(None);
511        }
512    }
513}
514
515struct TextDocumentItem {
516    uri: Url,
517    text: String,
518    version: i32,
519}
520
521impl Backend {
522    /// Parses the SQL code and publishes diagnosis about it.
523    async fn parse(&self, params: TextDocumentItem) {
524        self.client
525            .log_message(MessageType::INFO, format!("on_change {:?}", params.uri))
526            .await;
527        let rope = ropey::Rope::from_str(&params.text);
528
529        let mut content = self.content.lock().await;
530        let mut parse_results = self.parse_results.lock().await;
531
532        // Parse the text
533        let parse_result = mz_sql_parser::parser::parse_statements(&params.text);
534
535        match parse_result {
536            // The parser will return Ok when everything is well written.
537            Ok(results) => {
538                content.insert(params.uri.clone(), rope.clone());
539
540                // Clear the diagnostics in case there were issues before.
541                self.client
542                    .publish_diagnostics(params.uri.clone(), vec![], Some(params.version))
543                    .await;
544
545                let asts = results.iter().map(|x| x.ast.clone()).collect();
546                let parse_result: ParseResult = ParseResult { asts, rope };
547                parse_results.insert(params.uri, parse_result);
548            }
549
550            // If there is at least one error the parser will return Err.
551            Err(err_parsing) => {
552                let error_position = err_parsing.error.pos;
553                let start = offset_to_position(error_position, &rope).unwrap();
554                let end = start;
555                let range = Range { start, end };
556
557                parse_results.remove(&params.uri);
558
559                // Check for Jinja code (dbt)
560                // If Jinja code is detected, inform that parsing is not available..
561                if self.is_jinja(&err_parsing.error.message, params.text) {
562                    // Do not send any new diagnostics
563                    return;
564                }
565
566                // Only insert content if it is not Jinja code.
567                content.insert(params.uri.clone(), rope.clone());
568
569                let diagnostics = Diagnostic::new_simple(range, err_parsing.error.message);
570
571                self.client
572                    .publish_diagnostics(
573                        params.uri.clone(),
574                        vec![diagnostics],
575                        Some(params.version),
576                    )
577                    .await;
578            }
579        }
580    }
581
582    /// Detects if the code contains Jinja code using RegEx and
583    /// looks for Jinja's delimiters:
584    /// - {% ... %} for Statements
585    /// - {{ ... }} for Expressions to print to the template output
586    /// - {# ... #} for Comments not included in the template output
587    ///
588    /// Reference: <https://jinja.palletsprojects.com/en/3.0.x/templates/#synopsis>
589    ///
590    /// The trade-off is that the regex is simple, but it may detect some code as Jinja
591    /// when it is not actually Jinja. For example: `SELECT '{{ 100 }}';`.
592    /// To handle such cases more successfully, the server will first attempt to parse the
593    /// file, and if it fails, it will then check if it contains Jinja code.
594    fn contains_jinja_code(&self, s: &str) -> bool {
595        let re = Regex::new(r"\{\{.*?\}\}|\{%.*?%\}|\{#.*?#\}").unwrap();
596        re.is_match(s)
597    }
598
599    /// Returns true if Jinja code is detected.
600    fn is_jinja(&self, s: &str, code: String) -> bool {
601        s == "unexpected character in input: {" && self.contains_jinja_code(&code)
602    }
603
604    /// Builds the completion items for the following statements:
605    ///
606    /// * SELECT
607    /// * FROM
608    ///
609    /// Use this function to build the completion items once,
610    /// and avoid having to rebuild on every [LanguageServer::completion] call.
611    fn build_completion_items(&self, schema: Schema) -> Completions {
612        // Build SELECT completion items:
613        let mut select_completions = Vec::new();
614        let mut from_completions = Vec::new();
615
616        schema.objects.iter().for_each(|object| {
617            // Columns
618            object.columns.iter().for_each(|column| {
619                select_completions.push(CompletionItem {
620                    label: column.name.to_string(),
621                    label_details: Some(CompletionItemLabelDetails {
622                        detail: Some(column.typ.to_string()),
623                        description: None,
624                    }),
625                    kind: Some(CompletionItemKind::FIELD),
626                    detail: Some(
627                        format!(
628                            "From {}.{}.{} ({:?})",
629                            schema.database, schema.schema, object.name, object.typ
630                        )
631                        .to_string(),
632                    ),
633                    documentation: None,
634                    deprecated: Some(false),
635                    ..Default::default()
636                });
637            });
638
639            // Objects
640            from_completions.push(CompletionItem {
641                label: object.name.to_string(),
642                label_details: Some(CompletionItemLabelDetails {
643                    detail: Some(object.typ.to_string()),
644                    description: None,
645                }),
646                kind: match object.typ {
647                    ObjectType::View => Some(CompletionItemKind::ENUM_MEMBER),
648                    ObjectType::MaterializedView => Some(CompletionItemKind::ENUM),
649                    ObjectType::Source => Some(CompletionItemKind::CLASS),
650                    ObjectType::Sink => Some(CompletionItemKind::CLASS),
651                    ObjectType::Table => Some(CompletionItemKind::CONSTANT),
652                },
653                detail: Some(
654                    format!(
655                        "Represents {}.{}.{} ({:?})",
656                        schema.database, schema.schema, object.name, object.typ
657                    )
658                    .to_string(),
659                ),
660                documentation: None,
661                deprecated: Some(false),
662                ..Default::default()
663            });
664        });
665
666        Completions {
667            from: from_completions,
668            select: select_completions,
669        }
670    }
671}
672
673/// This function converts a (line, column) position in the text to an offset in the file.
674///
675/// It is the inverse of the `offset_to_position` function.
676fn position_to_offset(position: Position, rope: &Rope) -> Option<usize> {
677    // Convert line and column from u32 back to usize
678    let line: usize = position.line.try_into().ok()?;
679    let column: usize = position.character.try_into().ok()?;
680
681    // Get the offset of the first character of the line
682    let first_char_of_line_offset = rope.try_line_to_char(line).ok()?;
683
684    // Calculate the offset by adding the column number to the first character of the line's offset
685    let offset = first_char_of_line_offset + column;
686
687    Some(offset)
688}
689
690/// This function is a helper function that converts an offset in the file to a (line, column).
691///
692/// It is useful when translating an ofsset returned by [mz_sql_parser::parser::parse_statements]
693/// to an (x,y) position in the text to represent the error in the correct token.
694fn offset_to_position(offset: usize, rope: &Rope) -> Option<Position> {
695    let line = rope.try_char_to_line(offset).ok()?;
696    let first_char_of_line = rope.try_line_to_char(line).ok()?;
697    let column = offset - first_char_of_line;
698
699    // Convert to u32.
700    let line_u32 = line.try_into().ok()?;
701    let column_u32 = column.try_into().ok()?;
702
703    Some(Position::new(line_u32, column_u32))
704}
705
706/// Builds a [tower_lsp::jsonrpc::Error]
707///
708/// Use this function to map normal errors to the one the trait expects
709fn build_error(message: &'static str) -> tower_lsp::jsonrpc::Error {
710    Error {
711        code: ErrorCode::InternalError,
712        message: std::borrow::Cow::Borrowed(message),
713        data: None,
714    }
715}