1use ::serde::Deserialize;
20use mz_ore::collections::HashMap;
21use mz_sql_lexer::keywords::Keyword;
22use mz_sql_lexer::lexer::{self, Token};
23use mz_sql_parser::ast::display::FormatMode;
24use mz_sql_parser::ast::{Raw, Statement, statement_kind_label_value};
25use mz_sql_parser::parser::parse_statements;
26use mz_sql_pretty::PrettyConfig;
27use regex::Regex;
28use ropey::Rope;
29use serde::Serialize;
30use serde_json::{Value, json};
31use tokio::sync::Mutex;
32use tower_lsp::jsonrpc::{Error, ErrorCode, Result};
33use tower_lsp::lsp_types::*;
34use tower_lsp::{Client, LanguageServer};
35
36use crate::{PKG_NAME, PKG_VERSION};
37
/// Default formatting width.
///
/// Presumably the initial value of `Backend::formatting_width` before a client
/// overrides it; the construction site is outside this section of the file.
pub const DEFAULT_FORMATTING_WIDTH: usize = 100;
40
41#[derive(Debug)]
44pub struct ParseResult {
45 pub asts: Vec<Statement<Raw>>,
48 pub rope: Rope,
50}
51
52#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
53pub struct ExecuteCommandParseStatement {
56 pub sql: String,
58 pub kind: String,
61}
62
63#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
64pub struct ExecuteCommandParseResponse {
66 pub statements: Vec<ExecuteCommandParseStatement>,
68}
69
70#[derive(Debug)]
73pub struct Completions {
74 pub select: Vec<CompletionItem>,
77 pub from: Vec<CompletionItem>,
80}
81
82#[derive(Debug)]
95pub struct Backend {
96 pub client: Client,
100
101 pub parse_results: Mutex<HashMap<Url, ParseResult>>,
109
110 pub content: Mutex<HashMap<Url, Rope>>,
112
113 pub formatting_width: Mutex<usize>,
115
116 pub schema: Mutex<Option<Schema>>,
119
120 pub completions: Mutex<Completions>,
123}
124
125#[derive(Debug, Clone, Deserialize)]
127pub struct SchemaObjectColumn {
128 pub name: String,
130 #[serde(rename = "type")]
132 pub typ: String,
133}
134
135#[derive(Debug, Clone, Deserialize)]
136#[serde(rename_all = "camelCase")]
137pub enum ObjectType {
139 MaterializedView,
141 View,
143 Source,
145 Table,
147 Sink,
149}
150
151impl std::fmt::Display for ObjectType {
152 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153 match self {
154 ObjectType::MaterializedView => write!(f, "Materialized View"),
155 ObjectType::View => write!(f, "View"),
156 ObjectType::Source => write!(f, "Source"),
157 ObjectType::Table => write!(f, "Table"),
158 ObjectType::Sink => write!(f, "Sink"),
159 }
160 }
161}
162
163#[derive(Debug, Clone, Deserialize)]
168pub struct SchemaObject {
169 #[serde(rename = "type")]
171 pub typ: ObjectType,
172 pub name: String,
174 pub columns: Vec<SchemaObjectColumn>,
176}
177
178#[derive(Debug, Clone, Deserialize)]
183pub struct Schema {
184 pub schema: String,
186 pub database: String,
188 pub objects: Vec<SchemaObject>,
191}
192
193#[derive(Debug, Deserialize)]
195#[serde(rename_all = "camelCase")]
196pub struct InitializeOptions {
197 pub formatting_width: Option<usize>,
199 pub schema: Option<Schema>,
201}
202
203#[tower_lsp::async_trait]
204impl LanguageServer for Backend {
205 async fn initialize(&self, params: InitializeParams) -> Result<InitializeResult> {
206 if let Some(value_options) = params.initialization_options {
208 match serde_json::from_value(value_options) {
209 Ok(options) => {
210 let options: InitializeOptions = options;
211 if let Some(formatting_width) = options.formatting_width {
212 let mut formatting_width_guard = self.formatting_width.lock().await;
213 *formatting_width_guard = formatting_width;
214 }
215
216 if let Some(schema) = options.schema {
217 let mut schema_guard = self.schema.lock().await;
218 *schema_guard = Some(schema.clone());
219 let mut completions = self.completions.lock().await;
220 *completions = self.build_completion_items(schema);
221 };
222 }
223 Err(err) => {
224 self.client
225 .log_message(
226 MessageType::INFO,
227 format!("Initialization options are erroneus: {:?}", err.to_string()),
228 )
229 .await;
230 }
231 };
232 }
233
234 Ok(InitializeResult {
235 server_info: Some(ServerInfo {
236 name: PKG_NAME.clone(),
237 version: Some(PKG_VERSION.clone()),
238 }),
239 offset_encoding: None,
240 capabilities: ServerCapabilities {
241 document_formatting_provider: Some(tower_lsp::lsp_types::OneOf::Left(true)),
242 text_document_sync: Some(TextDocumentSyncCapability::Kind(
243 TextDocumentSyncKind::FULL,
244 )),
245 execute_command_provider: Some(ExecuteCommandOptions {
246 commands: vec!["parse".to_string()],
247 work_done_progress_options: WorkDoneProgressOptions {
248 work_done_progress: None,
249 },
250 }),
251 completion_provider: Some(CompletionOptions {
252 resolve_provider: Some(false),
253 trigger_characters: Some(vec![".".to_string()]),
254 work_done_progress_options: Default::default(),
255 all_commit_characters: None,
256 completion_item: None,
257 }),
258 workspace: Some(WorkspaceServerCapabilities {
259 workspace_folders: Some(WorkspaceFoldersServerCapabilities {
260 supported: Some(true),
261 change_notifications: Some(OneOf::Left(true)),
262 }),
263 file_operations: None,
264 }),
265 ..ServerCapabilities::default()
266 },
267 })
268 }
269
270 async fn initialized(&self, _: InitializedParams) {
271 self.client
272 .log_message(MessageType::INFO, "initialized!")
273 .await;
274 }
275
276 async fn shutdown(&self) -> Result<()> {
277 Ok(())
278 }
279
280 async fn did_change_workspace_folders(&self, _: DidChangeWorkspaceFoldersParams) {
281 self.client
282 .log_message(MessageType::INFO, "workspace folders changed!")
283 .await;
284 }
285
286 async fn did_change_configuration(&self, _: DidChangeConfigurationParams) {
287 self.client
288 .log_message(MessageType::INFO, "configuration changed!")
289 .await;
290 }
291
292 async fn did_change_watched_files(&self, _: DidChangeWatchedFilesParams) {
293 self.client
294 .log_message(MessageType::INFO, "watched files have changed!")
295 .await;
296 }
297
298 async fn execute_command(&self, command_params: ExecuteCommandParams) -> Result<Option<Value>> {
304 match command_params.command.as_str() {
305 "parse" => {
306 let json_args = command_params.arguments.get(0);
307
308 if let Some(json_args) = json_args {
309 let args = serde_json::from_value::<String>(json_args.clone())
310 .map_err(|_| build_error("Error deserializing parse args as String."))?;
311 let statements = parse_statements(&args)
312 .map_err(|_| build_error("Error parsing the statements."))?;
313
314 let parse_statements: Vec<ExecuteCommandParseStatement> = statements
318 .iter()
319 .map(|x| ExecuteCommandParseStatement {
320 kind: statement_kind_label_value(x.ast.clone().into()).to_string(),
321 sql: x.sql.to_string(),
322 })
323 .collect();
324
325 return Ok(Some(json!(ExecuteCommandParseResponse {
326 statements: parse_statements
327 })));
328 } else {
329 return Err(build_error("Missing command args."));
330 }
331 }
332 "optionsUpdate" => {
333 let json_args = command_params.arguments.get(0);
334
335 if let Some(json_args) = json_args {
336 let args = serde_json::from_value::<InitializeOptions>(json_args.clone())
337 .map_err(|_| {
338 build_error("Error deserializing parse args as InitializeOptions.")
339 })?;
340
341 if let Some(formatting_width) = args.formatting_width {
342 let mut formatting_width_guard = self.formatting_width.lock().await;
343 *formatting_width_guard = formatting_width;
344 }
345
346 if let Some(schema) = args.schema {
347 let mut schema_guard = self.schema.lock().await;
348 *schema_guard = Some(schema.clone());
349 let mut completions = self.completions.lock().await;
350 *completions = self.build_completion_items(schema);
351 }
352
353 return Ok(None);
354 } else {
355 return Err(build_error("Missing command args."));
356 }
357 }
358 _ => {
359 return Err(build_error("Unknown command."));
360 }
361 }
362 }
363
364 async fn did_open(&self, params: DidOpenTextDocumentParams) {
365 self.client
366 .log_message(MessageType::INFO, "file opened!")
367 .await;
368
369 self.parse(TextDocumentItem {
370 uri: params.text_document.uri,
371 text: params.text_document.text,
372 version: params.text_document.version,
373 })
374 .await
375 }
376
377 async fn did_change(&self, params: DidChangeTextDocumentParams) {
378 self.client
379 .log_message(MessageType::INFO, "file changed!")
380 .await;
381
382 self.parse(TextDocumentItem {
383 uri: params.text_document.uri,
384 text: params.content_changes[0].text.clone(),
385 version: params.text_document.version,
386 })
387 .await
388 }
389
390 async fn did_save(&self, _: DidSaveTextDocumentParams) {
391 self.client
392 .log_message(MessageType::INFO, "file saved!")
393 .await;
394 }
395
396 async fn did_close(&self, _: DidCloseTextDocumentParams) {
397 self.client
398 .log_message(MessageType::INFO, "file closed!")
399 .await;
400 }
401
402 async fn code_lens(&self, _params: CodeLensParams) -> Result<Option<Vec<CodeLens>>> {
403 let _lenses: Vec<CodeLens> = vec![CodeLens {
404 range: Range {
405 start: Position::new(0, 0),
406 end: Position::new(0, 0),
407 },
408 command: Some(Command {
409 title: "Run".to_string(),
410 command: "materialize.run".to_string(),
411 arguments: None,
412 }),
413 data: None,
414 }];
415
416 Ok(None)
419 }
420
421 async fn completion(&self, params: CompletionParams) -> Result<Option<CompletionResponse>> {
423 let uri = params.text_document_position.text_document.uri;
424 let position = params.text_document_position.position;
425
426 let content = self.content.lock().await;
427 let content = content.get(&uri);
428
429 if let Some(content) = content {
430 let lex_results = lexer::lex(&content.to_string())
432 .map_err(|_| build_error("Error getting lex tokens."))?;
433 let offset = position_to_offset(position, content)
434 .ok_or_else(|| build_error("Error getting completion offset."))?;
435
436 let last_keyword = lex_results
437 .iter()
438 .filter_map(|x| {
439 if x.offset < offset {
440 match x.kind {
441 Token::Keyword(k) => match k {
442 Keyword::Select => Some(k),
443 Keyword::From => Some(k),
444 _ => None,
445 },
446 _ => None,
448 }
449 } else {
450 None
451 }
452 })
453 .next_back();
454
455 if let Some(keyword) = last_keyword {
456 return match keyword {
457 Keyword::Select => {
458 let completions = self.completions.lock().await;
459 let select_completions = completions.select.clone();
460 Ok(Some(CompletionResponse::Array(select_completions)))
461 }
462 Keyword::From => {
463 let completions = self.completions.lock().await;
464 let from_completions = completions.from.clone();
465 Ok(Some(CompletionResponse::Array(from_completions)))
466 }
467 _ => Ok(None),
468 };
469 } else {
470 return Ok(None);
471 }
472 } else {
473 return Ok(None);
474 }
475 }
476
477 async fn formatting(&self, params: DocumentFormattingParams) -> Result<Option<Vec<TextEdit>>> {
481 let locked_map = self.parse_results.lock().await;
482 let width = self.formatting_width.lock().await;
483
484 if let Some(parse_result) = locked_map.get(¶ms.text_document.uri) {
485 let pretty = parse_result
486 .asts
487 .iter()
488 .map(|ast| {
489 mz_sql_pretty::to_pretty(
490 ast,
491 PrettyConfig {
492 width: *width,
493 format_mode: FormatMode::Simple,
494 },
495 )
496 })
497 .collect::<Vec<String>>()
498 .join("\n");
499 let rope = &parse_result.rope;
500
501 return Ok(Some(vec![TextEdit {
502 new_text: pretty,
503 range: Range {
504 start: offset_to_position(0, rope).unwrap(),
506 end: offset_to_position(rope.len_chars(), rope).unwrap(),
507 },
508 }]));
509 } else {
510 return Ok(None);
511 }
512 }
513}
514
515struct TextDocumentItem {
516 uri: Url,
517 text: String,
518 version: i32,
519}
520
521impl Backend {
522 async fn parse(&self, params: TextDocumentItem) {
524 self.client
525 .log_message(MessageType::INFO, format!("on_change {:?}", params.uri))
526 .await;
527 let rope = ropey::Rope::from_str(¶ms.text);
528
529 let mut content = self.content.lock().await;
530 let mut parse_results = self.parse_results.lock().await;
531
532 let parse_result = mz_sql_parser::parser::parse_statements(¶ms.text);
534
535 match parse_result {
536 Ok(results) => {
538 content.insert(params.uri.clone(), rope.clone());
539
540 self.client
542 .publish_diagnostics(params.uri.clone(), vec![], Some(params.version))
543 .await;
544
545 let asts = results.iter().map(|x| x.ast.clone()).collect();
546 let parse_result: ParseResult = ParseResult { asts, rope };
547 parse_results.insert(params.uri, parse_result);
548 }
549
550 Err(err_parsing) => {
552 let error_position = err_parsing.error.pos;
553 let start = offset_to_position(error_position, &rope).unwrap();
554 let end = start;
555 let range = Range { start, end };
556
557 parse_results.remove(¶ms.uri);
558
559 if self.is_jinja(&err_parsing.error.message, params.text) {
562 return;
564 }
565
566 content.insert(params.uri.clone(), rope.clone());
568
569 let diagnostics = Diagnostic::new_simple(range, err_parsing.error.message);
570
571 self.client
572 .publish_diagnostics(
573 params.uri.clone(),
574 vec![diagnostics],
575 Some(params.version),
576 )
577 .await;
578 }
579 }
580 }
581
582 fn contains_jinja_code(&self, s: &str) -> bool {
595 let re = Regex::new(r"\{\{.*?\}\}|\{%.*?%\}|\{#.*?#\}").unwrap();
596 re.is_match(s)
597 }
598
599 fn is_jinja(&self, s: &str, code: String) -> bool {
601 s == "unexpected character in input: {" && self.contains_jinja_code(&code)
602 }
603
604 fn build_completion_items(&self, schema: Schema) -> Completions {
612 let mut select_completions = Vec::new();
614 let mut from_completions = Vec::new();
615
616 schema.objects.iter().for_each(|object| {
617 object.columns.iter().for_each(|column| {
619 select_completions.push(CompletionItem {
620 label: column.name.to_string(),
621 label_details: Some(CompletionItemLabelDetails {
622 detail: Some(column.typ.to_string()),
623 description: None,
624 }),
625 kind: Some(CompletionItemKind::FIELD),
626 detail: Some(
627 format!(
628 "From {}.{}.{} ({:?})",
629 schema.database, schema.schema, object.name, object.typ
630 )
631 .to_string(),
632 ),
633 documentation: None,
634 deprecated: Some(false),
635 ..Default::default()
636 });
637 });
638
639 from_completions.push(CompletionItem {
641 label: object.name.to_string(),
642 label_details: Some(CompletionItemLabelDetails {
643 detail: Some(object.typ.to_string()),
644 description: None,
645 }),
646 kind: match object.typ {
647 ObjectType::View => Some(CompletionItemKind::ENUM_MEMBER),
648 ObjectType::MaterializedView => Some(CompletionItemKind::ENUM),
649 ObjectType::Source => Some(CompletionItemKind::CLASS),
650 ObjectType::Sink => Some(CompletionItemKind::CLASS),
651 ObjectType::Table => Some(CompletionItemKind::CONSTANT),
652 },
653 detail: Some(
654 format!(
655 "Represents {}.{}.{} ({:?})",
656 schema.database, schema.schema, object.name, object.typ
657 )
658 .to_string(),
659 ),
660 documentation: None,
661 deprecated: Some(false),
662 ..Default::default()
663 });
664 });
665
666 Completions {
667 from: from_completions,
668 select: select_completions,
669 }
670 }
671}
672
673fn position_to_offset(position: Position, rope: &Rope) -> Option<usize> {
677 let line: usize = position.line.try_into().ok()?;
679 let column: usize = position.character.try_into().ok()?;
680
681 let first_char_of_line_offset = rope.try_line_to_char(line).ok()?;
683
684 let offset = first_char_of_line_offset + column;
686
687 Some(offset)
688}
689
690fn offset_to_position(offset: usize, rope: &Rope) -> Option<Position> {
695 let line = rope.try_char_to_line(offset).ok()?;
696 let first_char_of_line = rope.try_line_to_char(line).ok()?;
697 let column = offset - first_char_of_line;
698
699 let line_u32 = line.try_into().ok()?;
701 let column_u32 = column.try_into().ok()?;
702
703 Some(Position::new(line_u32, column_u32))
704}
705
706fn build_error(message: &'static str) -> tower_lsp::jsonrpc::Error {
710 Error {
711 code: ErrorCode::InternalError,
712 message: std::borrow::Cow::Borrowed(message),
713 data: None,
714 }
715}