Skip to main content

mz_deploy/lsp/
semantic_tokens.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! LSP semantic tokens handler.
11//!
12//! Implements `textDocument/semanticTokens/full` by lexing the document with
13//! [`mz_sql_lexer::lexer::lex`] and mapping each token to a standard LSP
14//! token type. Comments (discarded by the lexer) are recovered via a
15//! separate pre-scan that is aware of strings and quoted identifiers.
16//!
17//! The output is delta-encoded per LSP 3.16: tokens are sorted by byte
18//! offset, split across line boundaries (LSP tokens are line-local), and
19//! serialized as a flat `[deltaLine, deltaStartChar, length, tokenType, 0]`
20//! sequence.
21//!
22//! Legend indices must match the order declared in the server's
23//! `SemanticTokensLegend` (see [`legend_token_types`]).
24
25use mz_sql_lexer::lexer::{self, PosToken, Token};
26use tower_lsp::lsp_types::{SemanticToken, SemanticTokenType};
27
// Legend indices for the token types this module emits. Each value is the
// position of the matching entry in `legend_token_types()`; the test
// `legend_order_matches_constants` checks that the two stay in sync.

/// Legend index for `SemanticTokenType::KEYWORD`.
const TOKEN_TYPE_KEYWORD: u32 = 0;
/// Legend index for `SemanticTokenType::STRING` (string and hex-string literals).
const TOKEN_TYPE_STRING: u32 = 1;
/// Legend index for `SemanticTokenType::NUMBER`.
const TOKEN_TYPE_NUMBER: u32 = 2;
/// Legend index for `SemanticTokenType::OPERATOR`.
const TOKEN_TYPE_OPERATOR: u32 = 3;
/// Legend index for `SemanticTokenType::VARIABLE` (identifiers, quoted or not).
const TOKEN_TYPE_VARIABLE: u32 = 4;
/// Legend index for `SemanticTokenType::PARAMETER` (`$1`, `$2`, ...).
const TOKEN_TYPE_PARAMETER: u32 = 5;
/// Legend index for `SemanticTokenType::COMMENT`.
const TOKEN_TYPE_COMMENT: u32 = 6;
35
36/// Token types in the order required for legend indices.
37pub(super) fn legend_token_types() -> Vec<SemanticTokenType> {
38    vec![
39        SemanticTokenType::KEYWORD,
40        SemanticTokenType::STRING,
41        SemanticTokenType::NUMBER,
42        SemanticTokenType::OPERATOR,
43        SemanticTokenType::VARIABLE,
44        SemanticTokenType::PARAMETER,
45        SemanticTokenType::COMMENT,
46    ]
47}
48
/// Byte-offset span with an associated semantic token type.
///
/// Offsets index into the document's UTF-8 bytes and describe the half-open
/// range `start..end`. A span may still cross line boundaries; it is split
/// into line-local pieces by `split_across_lines`.
#[derive(Debug, Clone, Copy)]
struct RawSpan {
    /// Byte offset of the first byte of the token (inclusive).
    start: usize,
    /// Byte offset one past the last byte of the token (exclusive).
    end: usize,
    /// Legend index; one of the `TOKEN_TYPE_*` constants.
    token_type: u32,
}
56
/// A line-local semantic token, after multi-line splitting.
///
/// Positions are expressed in LSP units: a zero-based line number and
/// UTF-16 code-unit columns.
#[derive(Debug, Clone, Copy)]
struct LineToken {
    /// Zero-based line index (an index into the `line_starts` table).
    line: u32,
    /// Column of the token start, in UTF-16 code units from the line start.
    start_char: u32,
    /// Token length in UTF-16 code units, confined to this one line.
    length: u32,
    /// Legend index; one of the `TOKEN_TYPE_*` constants.
    token_type: u32,
}
65
66/// Computes the semantic tokens for a SQL document.
67///
68/// Produces a delta-encoded `Vec<SemanticToken>` suitable for returning in
69/// a `SemanticTokensResult::Tokens`. Returns an empty vec on empty input.
70/// On lexer error, emits the tokens collected up to the error site plus
71/// comments from the full pre-scan; never panics.
72pub(super) fn compute_semantic_tokens(text: &str) -> Vec<SemanticToken> {
73    let mut spans = Vec::new();
74    collect_comments(text, &mut spans);
75    if let Ok(tokens) = lexer::lex(text) {
76        for tok in &tokens {
77            if let Some(span) = lex_token_span(tok, text) {
78                spans.push(span);
79            }
80        }
81    }
82    spans.sort_by_key(|s| s.start);
83
84    let line_tokens = split_across_lines(text, &spans);
85    encode_deltas(&line_tokens)
86}
87
88/// Pre-scan raw text for `--` line comments and `/* */` block comments.
89/// String bodies and quoted-identifier bodies are skipped so that comment
90/// markers inside them are not misidentified.
91fn collect_comments(text: &str, out: &mut Vec<RawSpan>) {
92    let bytes = text.as_bytes();
93    let mut i = 0;
94    while i < bytes.len() {
95        match bytes[i] {
96            b'\'' => i = skip_single_quoted(bytes, i),
97            b'"' => i = skip_double_quoted(bytes, i),
98            b'-' if bytes.get(i + 1) == Some(&b'-') => {
99                let start = i;
100                while i < bytes.len() && bytes[i] != b'\n' {
101                    i += 1;
102                }
103                out.push(RawSpan {
104                    start,
105                    end: i,
106                    token_type: TOKEN_TYPE_COMMENT,
107                });
108            }
109            b'/' if bytes.get(i + 1) == Some(&b'*') => {
110                let start = i;
111                i += 2;
112                let mut depth = 1usize;
113                while i + 1 < bytes.len() && depth > 0 {
114                    if bytes[i] == b'/' && bytes[i + 1] == b'*' {
115                        depth += 1;
116                        i += 2;
117                    } else if bytes[i] == b'*' && bytes[i + 1] == b'/' {
118                        depth -= 1;
119                        i += 2;
120                    } else {
121                        i += 1;
122                    }
123                }
124                if depth > 0 {
125                    i = bytes.len();
126                }
127                out.push(RawSpan {
128                    start,
129                    end: i,
130                    token_type: TOKEN_TYPE_COMMENT,
131                });
132            }
133            _ => i += 1,
134        }
135    }
136}
137
/// Skip a single-quoted string body whose opening `'` is at `start`.
/// Returns the index just past the closing quote, or `bytes.len()` if the
/// string is unterminated.
///
/// A doubled quote (`''`) is always an escaped quote. Backslash escapes
/// (`\'`) are honored only in extended strings, i.e. when the quote is
/// directly preceded by a standalone `E`/`e` prefix (`E'...'`). In
/// standard-conforming strings, a backslash is an ordinary character, so
/// `'C:\'` ends at its second quote.
fn skip_single_quoted(bytes: &[u8], start: usize) -> usize {
    /// Bytes that can continue an unquoted identifier.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_' || b == b'$' || b >= 0x80
    }

    // The `E` only counts as a prefix when it is not the tail of a longer
    // word such as `some'...'`.
    let prefix = start.checked_sub(1).and_then(|p| bytes.get(p)).copied();
    let before_prefix = start.checked_sub(2).and_then(|p| bytes.get(p)).copied();
    let extended = matches!(prefix, Some(b'e') | Some(b'E'))
        && !before_prefix.map_or(false, is_ident_byte);

    let mut i = start + 1;
    while i < bytes.len() {
        match bytes[i] {
            b'\'' if bytes.get(i + 1) == Some(&b'\'') => i += 2,
            b'\'' => return i + 1,
            b'\\' if extended && i + 1 < bytes.len() => i += 2,
            _ => i += 1,
        }
    }
    bytes.len()
}
152
/// Skip a quoted identifier whose opening `"` is at `start`.
///
/// A doubled `""` inside the identifier is an escaped quote. Returns the
/// index just past the closing quote, or `bytes.len()` when no closing
/// quote exists.
fn skip_double_quoted(bytes: &[u8], start: usize) -> usize {
    let mut pos = start + 1;
    loop {
        match (bytes.get(pos), bytes.get(pos + 1)) {
            (None, _) => return bytes.len(),
            (Some(b'"'), Some(b'"')) => pos += 2,
            (Some(b'"'), _) => return pos + 1,
            _ => pos += 1,
        }
    }
}
165
166/// Map a lexer token to its byte span and semantic type.
167///
168/// Returns `None` for tokens that should not be highlighted (punctuation).
169fn lex_token_span(tok: &PosToken, text: &str) -> Option<RawSpan> {
170    let start = tok.offset;
171    let bytes = text.as_bytes();
172    let (end, token_type) = match &tok.kind {
173        Token::Keyword(kw) => (start + kw.as_str().len(), TOKEN_TYPE_KEYWORD),
174        Token::Op(op) => (start + op.len(), TOKEN_TYPE_OPERATOR),
175        Token::Number(n) => (start + n.len(), TOKEN_TYPE_NUMBER),
176        Token::Star | Token::Eq | Token::Colon => (start + 1, TOKEN_TYPE_OPERATOR),
177        Token::DoubleColon | Token::Arrow => (start + 2, TOKEN_TYPE_OPERATOR),
178        Token::Ident(_) => (start + scan_ident_len(bytes, start), TOKEN_TYPE_VARIABLE),
179        Token::String(_) => (
180            start + scan_string_token_len(bytes, start),
181            TOKEN_TYPE_STRING,
182        ),
183        Token::HexString(_) => (
184            start + scan_hex_string_token_len(bytes, start),
185            TOKEN_TYPE_STRING,
186        ),
187        Token::Parameter(_) => (
188            start + scan_parameter_len(bytes, start),
189            TOKEN_TYPE_PARAMETER,
190        ),
191        Token::LParen
192        | Token::RParen
193        | Token::LBracket
194        | Token::RBracket
195        | Token::Dot
196        | Token::Comma
197        | Token::Semicolon => return None,
198    };
199    Some(RawSpan {
200        start,
201        end,
202        token_type,
203    })
204}
205
206fn scan_ident_len(bytes: &[u8], start: usize) -> usize {
207    if bytes.get(start) == Some(&b'"') {
208        skip_double_quoted(bytes, start) - start
209    } else {
210        let mut i = start;
211        while i < bytes.len() {
212            let c = bytes[i];
213            if c.is_ascii_alphanumeric() || c == b'_' || c == b'$' || c >= 0x80 {
214                i += 1;
215            } else {
216                break;
217            }
218        }
219        i - start
220    }
221}
222
223/// Length of a string token. May be a normal `'...'` or extended `E'...'` /
224/// `e'...'` form (the E prefix is part of the token offset).
225fn scan_string_token_len(bytes: &[u8], start: usize) -> usize {
226    let quote_pos = if bytes.get(start) == Some(&b'\'') {
227        start
228    } else {
229        // E-prefix extended string, or dollar-quoted $$...$$.
230        // Handle dollar-quoted by seeking to the next `$`.
231        if bytes.get(start) == Some(&b'$') {
232            return scan_dollar_quoted_len(bytes, start);
233        }
234        start + 1
235    };
236    skip_single_quoted(bytes, quote_pos) - start
237}
238
239/// Length of a hex string token: `x'...'` or `X'...'`.
240fn scan_hex_string_token_len(bytes: &[u8], start: usize) -> usize {
241    let quote_pos = start + 1;
242    skip_single_quoted(bytes, quote_pos) - start
243}
244
/// Byte length of a dollar-quoted string beginning at `start`.
///
/// The opening delimiter is everything from `start` through the next `$`
/// (so `$$` for an empty tag, `$tag$` otherwise). The string ends just after
/// the next occurrence of that same delimiter. If either the opening
/// delimiter or the closing one is missing, the length runs to end of input.
fn scan_dollar_quoted_len(bytes: &[u8], start: usize) -> usize {
    let unterminated = bytes.len() - start;
    let close = match bytes
        .get(start + 1..)
        .and_then(|rest| rest.iter().position(|&b| b == b'$'))
    {
        Some(offset) => start + 1 + offset,
        None => return unterminated,
    };
    let delimiter = &bytes[start..=close];
    let body_start = close + 1;
    match bytes[body_start..]
        .windows(delimiter.len())
        .position(|window| window == delimiter)
    {
        Some(offset) => body_start + offset + delimiter.len() - start,
        None => unterminated,
    }
}
266
/// Byte length of a positional parameter such as `$1` or `$42`: the leading
/// `$` plus every ASCII digit that immediately follows it.
fn scan_parameter_len(bytes: &[u8], start: usize) -> usize {
    let digits = bytes.get(start + 1..).map_or(0, |rest| {
        rest.iter().take_while(|b| b.is_ascii_digit()).count()
    });
    1 + digits
}
274
275/// Split each raw span across line boundaries and compute UTF-16 column
276/// offsets. Produces line-local tokens, still in byte-order.
277fn split_across_lines(text: &str, spans: &[RawSpan]) -> Vec<LineToken> {
278    let line_starts = line_starts(text);
279    let mut out = Vec::with_capacity(spans.len());
280    for span in spans {
281        if span.end <= span.start {
282            continue;
283        }
284        let start_line = line_for_offset(&line_starts, span.start);
285        let end_line = line_for_offset(&line_starts, span.end.saturating_sub(1));
286        if start_line == end_line {
287            let line_start = line_starts[start_line];
288            let start_col = utf16_len(&text[line_start..span.start]);
289            let length = utf16_len(&text[span.start..span.end]);
290            if length > 0 {
291                out.push(LineToken {
292                    line: saturating_u32(start_line),
293                    start_char: saturating_u32(start_col),
294                    length: saturating_u32(length),
295                    token_type: span.token_type,
296                });
297            }
298        } else {
299            // Multi-line: emit one token per line segment.
300            for line in start_line..=end_line {
301                let line_start = line_starts[line];
302                let line_end = line_starts.get(line + 1).copied().unwrap_or(text.len());
303                let seg_start = span.start.max(line_start);
304                // Trim trailing newline from the segment so we don't emit
305                // a token that spans into the next line's column 0.
306                let seg_end_raw = span.end.min(line_end);
307                let seg_end = trim_trailing_newline(text, line_start, seg_end_raw);
308                if seg_end <= seg_start {
309                    continue;
310                }
311                let start_col = utf16_len(&text[line_start..seg_start]);
312                let length = utf16_len(&text[seg_start..seg_end]);
313                if length > 0 {
314                    out.push(LineToken {
315                        line: saturating_u32(line),
316                        start_char: saturating_u32(start_col),
317                        length: saturating_u32(length),
318                        token_type: span.token_type,
319                    });
320                }
321            }
322        }
323    }
324    out.sort_by(|a, b| a.line.cmp(&b.line).then(a.start_char.cmp(&b.start_char)));
325    out
326}
327
/// Pull `end` back past a trailing `\n` or `\r\n` so that a segment of the
/// line beginning at `line_start` excludes its line terminator.
///
/// Never moves `end` below `line_start`. If the byte before `end` is not a
/// `\n`, `end` is returned unchanged.
fn trim_trailing_newline(text: &str, line_start: usize, end: usize) -> usize {
    let bytes = text.as_bytes();
    // Is the byte just before `pos` inside this line and equal to `want`?
    let preceded_by =
        |pos: usize, want: u8| pos > line_start && bytes.get(pos - 1) == Some(&want);
    if !preceded_by(end, b'\n') {
        return end;
    }
    let without_lf = end - 1;
    if preceded_by(without_lf, b'\r') {
        without_lf - 1
    } else {
        without_lf
    }
}
341
/// Byte offset at which each line begins.
///
/// Line 0 always starts at offset 0. Every `\n` begins a new line at the
/// byte after it, so the result is strictly increasing and never empty.
fn line_starts(text: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(text.match_indices('\n').map(|(newline, _)| newline + 1))
        .collect()
}
352
/// Index of the line containing byte `offset`, given the strictly increasing
/// table produced by `line_starts`. This is the last line whose start is
/// `<= offset`, found by binary search.
fn line_for_offset(line_starts: &[usize], offset: usize) -> usize {
    line_starts.partition_point(|&line_start| line_start <= offset) - 1
}
360
/// Length of `s` in UTF-16 code units, the column unit used by LSP.
///
/// For pure-ASCII input this equals the byte length, so the common case
/// avoids re-encoding.
fn utf16_len(s: &str) -> usize {
    if !s.is_ascii() {
        return s.encode_utf16().count();
    }
    s.len()
}
369
/// Narrow a document position or length (`usize`) to the `u32` used by the
/// LSP semantic-token wire format.
///
/// Values that fit are returned unchanged; larger values clamp to
/// `u32::MAX`. LSP positions are `u32` by specification, so a document large
/// enough to clamp could not be addressed by the client anyway.
fn saturating_u32(v: usize) -> u32 {
    match u32::try_from(v) {
        Ok(narrowed) => narrowed,
        Err(_) => u32::MAX,
    }
}
377
378/// Delta-encode line-local tokens per LSP 3.16.
379fn encode_deltas(tokens: &[LineToken]) -> Vec<SemanticToken> {
380    let mut out = Vec::with_capacity(tokens.len());
381    let mut prev_line: u32 = 0;
382    let mut prev_char: u32 = 0;
383    for t in tokens {
384        let delta_line = t.line - prev_line;
385        let delta_start = if delta_line == 0 {
386            t.start_char - prev_char
387        } else {
388            t.start_char
389        };
390        out.push(SemanticToken {
391            delta_line,
392            delta_start,
393            length: t.length,
394            token_type: t.token_type,
395            token_modifiers_bitset: 0,
396        });
397        prev_line = t.line;
398        prev_char = t.start_char;
399    }
400    out
401}
402
#[cfg(test)]
mod tests {
    //! End-to-end tests for `compute_semantic_tokens`, asserting on decoded
    //! absolute positions rather than raw deltas.

    use super::*;
    use mz_ore::cast::CastFrom;

    /// Decode a delta-encoded token list into absolute
    /// `(line, start_char, length, token_type)` tuples for easier assertions.
    /// This mirrors the encoding in `encode_deltas`: the column is relative
    /// only when the line delta is zero.
    fn decode(tokens: &[SemanticToken]) -> Vec<(u32, u32, u32, u32)> {
        let mut out = Vec::new();
        let mut line = 0u32;
        let mut ch = 0u32;
        for t in tokens {
            line += t.delta_line;
            ch = if t.delta_line == 0 {
                ch + t.delta_start
            } else {
                t.delta_start
            };
            out.push((line, ch, t.length, t.token_type));
        }
        out
    }

    #[mz_ore::test]
    fn empty_input() {
        assert!(compute_semantic_tokens("").is_empty());
    }

    #[mz_ore::test]
    fn basic_select() {
        let sql = "SELECT 1 FROM t";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),   // SELECT
                (0, 7, 1, TOKEN_TYPE_NUMBER),    // 1
                (0, 9, 4, TOKEN_TYPE_KEYWORD),   // FROM
                (0, 14, 1, TOKEN_TYPE_VARIABLE), // t
            ]
        );
    }

    #[mz_ore::test]
    fn create_mv() {
        let sql = "CREATE MATERIALIZED VIEW foo AS SELECT 1";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),   // CREATE
                (0, 7, 12, TOKEN_TYPE_KEYWORD),  // MATERIALIZED
                (0, 20, 4, TOKEN_TYPE_KEYWORD),  // VIEW
                (0, 25, 3, TOKEN_TYPE_VARIABLE), // foo
                (0, 29, 2, TOKEN_TYPE_KEYWORD),  // AS
                (0, 32, 6, TOKEN_TYPE_KEYWORD),  // SELECT
                (0, 39, 1, TOKEN_TYPE_NUMBER),   // 1
            ]
        );
    }

    // The line comment stops before `\n`, so the next token is on line 1.
    #[mz_ore::test]
    fn line_comment_mid_statement() {
        let sql = "SELECT -- hi\n1";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD), // SELECT
                (0, 7, 5, TOKEN_TYPE_COMMENT), // -- hi
                (1, 0, 1, TOKEN_TYPE_NUMBER),  // 1
            ]
        );
    }

    #[mz_ore::test]
    fn block_comment_single_line() {
        let sql = "SELECT /* x */ 1";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD), // SELECT
                (0, 7, 7, TOKEN_TYPE_COMMENT), // /* x */
                (0, 15, 1, TOKEN_TYPE_NUMBER), // 1
            ]
        );
    }

    // A multi-line comment is split into one token per line, each excluding
    // its line terminator.
    #[mz_ore::test]
    fn block_comment_multiline() {
        let sql = "/*\na\nb\n*/";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 2, TOKEN_TYPE_COMMENT), // /*
                (1, 0, 1, TOKEN_TYPE_COMMENT), // a
                (2, 0, 1, TOKEN_TYPE_COMMENT), // b
                (3, 0, 2, TOKEN_TYPE_COMMENT), // */
            ]
        );
    }

    #[mz_ore::test]
    fn string_literal() {
        let sql = "SELECT 'hello'";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD), // SELECT
                (0, 7, 7, TOKEN_TYPE_STRING),  // 'hello'
            ]
        );
    }

    // The comment pre-scan must skip string bodies.
    #[mz_ore::test]
    fn comment_markers_inside_string_not_detected() {
        let sql = "SELECT '--not a comment' FROM t";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),   // SELECT
                (0, 7, 17, TOKEN_TYPE_STRING),   // '--not a comment'
                (0, 25, 4, TOKEN_TYPE_KEYWORD),  // FROM
                (0, 30, 1, TOKEN_TYPE_VARIABLE), // t
            ]
        );
    }

    // A quoted identifier's span includes both quotes.
    #[mz_ore::test]
    fn quoted_identifier() {
        let sql = r#"SELECT "My Col" FROM t"#;
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),   // SELECT
                (0, 7, 8, TOKEN_TYPE_VARIABLE),  // "My Col"
                (0, 16, 4, TOKEN_TYPE_KEYWORD),  // FROM
                (0, 21, 1, TOKEN_TYPE_VARIABLE), // t
            ]
        );
    }

    #[mz_ore::test]
    fn operators() {
        let sql = "SELECT 1 + 2 * 3";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),   // SELECT
                (0, 7, 1, TOKEN_TYPE_NUMBER),    // 1
                (0, 9, 1, TOKEN_TYPE_OPERATOR),  // +
                (0, 11, 1, TOKEN_TYPE_NUMBER),   // 2
                (0, 13, 1, TOKEN_TYPE_OPERATOR), // *
                (0, 15, 1, TOKEN_TYPE_NUMBER),   // 3
            ]
        );
    }

    // Parameter spans cover the `$` plus all following digits.
    #[mz_ore::test]
    fn parameter() {
        let sql = "SELECT $1 + $42";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),    // SELECT
                (0, 7, 2, TOKEN_TYPE_PARAMETER),  // $1
                (0, 10, 1, TOKEN_TYPE_OPERATOR),  // +
                (0, 12, 3, TOKEN_TYPE_PARAMETER), // $42
            ]
        );
    }

    #[mz_ore::test]
    fn punctuation_is_skipped() {
        let sql = "SELECT (a, b.c);";
        let decoded = decode(&compute_semantic_tokens(sql));
        // Parens, comma, dot, semicolon should not appear.
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),   // SELECT
                (0, 8, 1, TOKEN_TYPE_VARIABLE),  // a
                (0, 11, 1, TOKEN_TYPE_VARIABLE), // b
                (0, 13, 1, TOKEN_TYPE_VARIABLE), // c
            ]
        );
    }

    // Columns and lengths are in UTF-16 code units, not bytes: each katakana
    // character below is 3 UTF-8 bytes but 1 UTF-16 unit.
    #[mz_ore::test]
    fn non_ascii_identifier() {
        // Japanese katakana "テーブル" (meaning "table"). Each char is one UTF-16 unit.
        let sql = "SELECT テーブル";
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(
            decoded,
            vec![
                (0, 0, 6, TOKEN_TYPE_KEYWORD),  // SELECT
                (0, 7, 4, TOKEN_TYPE_VARIABLE), // テーブル (4 chars = 4 UTF-16 units)
            ]
        );
    }

    #[mz_ore::test]
    fn lex_error_does_not_panic() {
        // `@@@` is not a valid SQL start — lex() errors. We should still
        // return some comment data or an empty list without panicking.
        let sql = "-- leading comment\n@@@";
        let tokens = compute_semantic_tokens(sql);
        // At least the comment should survive.
        let decoded = decode(&tokens);
        assert!(
            decoded
                .iter()
                .any(|(_, _, _, ty)| *ty == TOKEN_TYPE_COMMENT),
            "expected the comment to be recovered even on lex error: {decoded:?}"
        );
    }

    // Guards against the `TOKEN_TYPE_*` constants drifting out of sync with
    // the legend order advertised to the client.
    #[mz_ore::test]
    fn legend_order_matches_constants() {
        let legend = legend_token_types();
        assert_eq!(
            legend[usize::cast_from(TOKEN_TYPE_KEYWORD)],
            SemanticTokenType::KEYWORD
        );
        assert_eq!(
            legend[usize::cast_from(TOKEN_TYPE_STRING)],
            SemanticTokenType::STRING
        );
        assert_eq!(
            legend[usize::cast_from(TOKEN_TYPE_NUMBER)],
            SemanticTokenType::NUMBER
        );
        assert_eq!(
            legend[usize::cast_from(TOKEN_TYPE_OPERATOR)],
            SemanticTokenType::OPERATOR
        );
        assert_eq!(
            legend[usize::cast_from(TOKEN_TYPE_VARIABLE)],
            SemanticTokenType::VARIABLE
        );
        assert_eq!(
            legend[usize::cast_from(TOKEN_TYPE_PARAMETER)],
            SemanticTokenType::PARAMETER
        );
        assert_eq!(
            legend[usize::cast_from(TOKEN_TYPE_COMMENT)],
            SemanticTokenType::COMMENT
        );
    }
}