1use mz_sql_lexer::lexer::{self, PosToken, Token};
26use tower_lsp::lsp_types::{SemanticToken, SemanticTokenType};
27
// Wire indices of the semantic token types. Each value is the position of the
// matching `SemanticTokenType` in `legend_token_types`, so the two lists must be
// kept in the same order. Deriving each index from its predecessor keeps them
// contiguous.
const TOKEN_TYPE_KEYWORD: u32 = 0;
const TOKEN_TYPE_STRING: u32 = TOKEN_TYPE_KEYWORD + 1;
const TOKEN_TYPE_NUMBER: u32 = TOKEN_TYPE_STRING + 1;
const TOKEN_TYPE_OPERATOR: u32 = TOKEN_TYPE_NUMBER + 1;
const TOKEN_TYPE_VARIABLE: u32 = TOKEN_TYPE_OPERATOR + 1;
const TOKEN_TYPE_PARAMETER: u32 = TOKEN_TYPE_VARIABLE + 1;
const TOKEN_TYPE_COMMENT: u32 = TOKEN_TYPE_PARAMETER + 1;
35
36pub(super) fn legend_token_types() -> Vec<SemanticTokenType> {
38 vec![
39 SemanticTokenType::KEYWORD,
40 SemanticTokenType::STRING,
41 SemanticTokenType::NUMBER,
42 SemanticTokenType::OPERATOR,
43 SemanticTokenType::VARIABLE,
44 SemanticTokenType::PARAMETER,
45 SemanticTokenType::COMMENT,
46 ]
47}
48
/// A half-open byte range `[start, end)` of the source text together with the
/// legend index it is highlighted as. A span may cross line boundaries (block
/// comments do); it is split into per-line pieces before encoding.
#[derive(Debug, Clone, Copy)]
struct RawSpan {
    /// Byte offset of the first byte covered.
    start: usize,
    /// Byte offset one past the last byte covered.
    end: usize,
    /// One of the `TOKEN_TYPE_*` legend indices.
    token_type: u32,
}
56
/// A token confined to a single line, positioned in LSP coordinates: a
/// zero-based line number, plus a start column and length measured in UTF-16
/// code units.
#[derive(Debug, Clone, Copy)]
struct LineToken {
    /// Zero-based line index.
    line: u32,
    /// Column of the first code unit, in UTF-16 code units from line start.
    start_char: u32,
    /// Token length in UTF-16 code units.
    length: u32,
    /// One of the `TOKEN_TYPE_*` legend indices.
    token_type: u32,
}
65
66pub(super) fn compute_semantic_tokens(text: &str) -> Vec<SemanticToken> {
73 let mut spans = Vec::new();
74 collect_comments(text, &mut spans);
75 if let Ok(tokens) = lexer::lex(text) {
76 for tok in &tokens {
77 if let Some(span) = lex_token_span(tok, text) {
78 spans.push(span);
79 }
80 }
81 }
82 spans.sort_by_key(|s| s.start);
83
84 let line_tokens = split_across_lines(text, &spans);
85 encode_deltas(&line_tokens)
86}
87
88fn collect_comments(text: &str, out: &mut Vec<RawSpan>) {
92 let bytes = text.as_bytes();
93 let mut i = 0;
94 while i < bytes.len() {
95 match bytes[i] {
96 b'\'' => i = skip_single_quoted(bytes, i),
97 b'"' => i = skip_double_quoted(bytes, i),
98 b'-' if bytes.get(i + 1) == Some(&b'-') => {
99 let start = i;
100 while i < bytes.len() && bytes[i] != b'\n' {
101 i += 1;
102 }
103 out.push(RawSpan {
104 start,
105 end: i,
106 token_type: TOKEN_TYPE_COMMENT,
107 });
108 }
109 b'/' if bytes.get(i + 1) == Some(&b'*') => {
110 let start = i;
111 i += 2;
112 let mut depth = 1usize;
113 while i + 1 < bytes.len() && depth > 0 {
114 if bytes[i] == b'/' && bytes[i + 1] == b'*' {
115 depth += 1;
116 i += 2;
117 } else if bytes[i] == b'*' && bytes[i + 1] == b'/' {
118 depth -= 1;
119 i += 2;
120 } else {
121 i += 1;
122 }
123 }
124 if depth > 0 {
125 i = bytes.len();
126 }
127 out.push(RawSpan {
128 start,
129 end: i,
130 token_type: TOKEN_TYPE_COMMENT,
131 });
132 }
133 _ => i += 1,
134 }
135 }
136}
137
/// Returns the byte offset just past the single-quoted literal whose opening
/// quote is at `start`, or `bytes.len()` if the literal is unterminated.
///
/// A doubled quote (`''`) is always an escaped quote. Backslash escapes are
/// honored only in extended string literals (`E'...'` / `e'...'`). With
/// standard-conforming strings a backslash in a plain `'...'` literal is an
/// ordinary character, so `'C:\'` ends at its second quote. Treating it as an
/// escape there would run the literal on past its real end.
fn skip_single_quoted(bytes: &[u8], start: usize) -> usize {
    let backslash_escapes = is_extended_string_prefix(bytes, start);
    let mut i = start + 1;
    while i < bytes.len() {
        match bytes[i] {
            b'\'' if bytes.get(i + 1) == Some(&b'\'') => i += 2,
            b'\'' => return i + 1,
            b'\\' if backslash_escapes && i + 1 < bytes.len() => i += 2,
            _ => i += 1,
        }
    }
    bytes.len()
}

/// Reports whether the quote at `quote` is directly preceded by a standalone
/// `E`/`e` prefix, i.e. opens an extended (backslash-escaped) string literal.
///
/// The prefix counts only at the start of a token. An `e` that ends a longer
/// identifier, as in `type'x'`, does not qualify.
fn is_extended_string_prefix(bytes: &[u8], quote: usize) -> bool {
    if quote == 0 || !matches!(bytes.get(quote - 1), Some(&b'e') | Some(&b'E')) {
        return false;
    }
    match quote.checked_sub(2).and_then(|p| bytes.get(p)) {
        Some(&c) => !(c.is_ascii_alphanumeric() || c == b'_' || c == b'$' || c >= 0x80),
        None => true,
    }
}
152
/// Returns the byte offset just past the double-quoted identifier whose opening
/// quote is at `start`, or `bytes.len()` if it is unterminated. A doubled
/// quote (`""`) inside the identifier is an escaped quote.
fn skip_double_quoted(bytes: &[u8], start: usize) -> usize {
    let mut pos = start + 1;
    loop {
        match (bytes.get(pos), bytes.get(pos + 1)) {
            (None, _) => return bytes.len(),
            (Some(&b'"'), Some(&b'"')) => pos += 2,
            (Some(&b'"'), _) => return pos + 1,
            (Some(_), _) => pos += 1,
        }
    }
}
165
166fn lex_token_span(tok: &PosToken, text: &str) -> Option<RawSpan> {
170 let start = tok.offset;
171 let bytes = text.as_bytes();
172 let (end, token_type) = match &tok.kind {
173 Token::Keyword(kw) => (start + kw.as_str().len(), TOKEN_TYPE_KEYWORD),
174 Token::Op(op) => (start + op.len(), TOKEN_TYPE_OPERATOR),
175 Token::Number(n) => (start + n.len(), TOKEN_TYPE_NUMBER),
176 Token::Star | Token::Eq | Token::Colon => (start + 1, TOKEN_TYPE_OPERATOR),
177 Token::DoubleColon | Token::Arrow => (start + 2, TOKEN_TYPE_OPERATOR),
178 Token::Ident(_) => (start + scan_ident_len(bytes, start), TOKEN_TYPE_VARIABLE),
179 Token::String(_) => (
180 start + scan_string_token_len(bytes, start),
181 TOKEN_TYPE_STRING,
182 ),
183 Token::HexString(_) => (
184 start + scan_hex_string_token_len(bytes, start),
185 TOKEN_TYPE_STRING,
186 ),
187 Token::Parameter(_) => (
188 start + scan_parameter_len(bytes, start),
189 TOKEN_TYPE_PARAMETER,
190 ),
191 Token::LParen
192 | Token::RParen
193 | Token::LBracket
194 | Token::RBracket
195 | Token::Dot
196 | Token::Comma
197 | Token::Semicolon => return None,
198 };
199 Some(RawSpan {
200 start,
201 end,
202 token_type,
203 })
204}
205
206fn scan_ident_len(bytes: &[u8], start: usize) -> usize {
207 if bytes.get(start) == Some(&b'"') {
208 skip_double_quoted(bytes, start) - start
209 } else {
210 let mut i = start;
211 while i < bytes.len() {
212 let c = bytes[i];
213 if c.is_ascii_alphanumeric() || c == b'_' || c == b'$' || c >= 0x80 {
214 i += 1;
215 } else {
216 break;
217 }
218 }
219 i - start
220 }
221}
222
223fn scan_string_token_len(bytes: &[u8], start: usize) -> usize {
226 let quote_pos = if bytes.get(start) == Some(&b'\'') {
227 start
228 } else {
229 if bytes.get(start) == Some(&b'$') {
232 return scan_dollar_quoted_len(bytes, start);
233 }
234 start + 1
235 };
236 skip_single_quoted(bytes, quote_pos) - start
237}
238
239fn scan_hex_string_token_len(bytes: &[u8], start: usize) -> usize {
241 let quote_pos = start + 1;
242 skip_single_quoted(bytes, quote_pos) - start
243}
244
/// Returns the byte length of the dollar-quoted literal starting at `start`.
///
/// The opening tag runs from `start` through the next `$` (e.g. `$$` or
/// `$fn$`). The literal ends just after the next occurrence of that same tag.
/// If the tag or the body is unterminated, the literal extends to the end of
/// `bytes`.
fn scan_dollar_quoted_len(bytes: &[u8], start: usize) -> usize {
    let total = bytes.len();
    let tag_end = match bytes
        .get(start + 1..)
        .and_then(|rest| rest.iter().position(|&b| b == b'$'))
    {
        Some(off) => start + 1 + off,
        None => return total - start,
    };
    let tag = &bytes[start..=tag_end];
    let body_start = tag_end + 1;
    match bytes[body_start..]
        .windows(tag.len())
        .position(|window| window == tag)
    {
        Some(off) => body_start + off + tag.len() - start,
        None => total - start,
    }
}
266
/// Returns the byte length of a `$N` parameter at `start`: the `$` plus the run
/// of ASCII digits that follows it.
fn scan_parameter_len(bytes: &[u8], start: usize) -> usize {
    let digits = bytes
        .get(start + 1..)
        .map_or(0, |rest| rest.iter().take_while(|b| b.is_ascii_digit()).count());
    1 + digits
}
274
275fn split_across_lines(text: &str, spans: &[RawSpan]) -> Vec<LineToken> {
278 let line_starts = line_starts(text);
279 let mut out = Vec::with_capacity(spans.len());
280 for span in spans {
281 if span.end <= span.start {
282 continue;
283 }
284 let start_line = line_for_offset(&line_starts, span.start);
285 let end_line = line_for_offset(&line_starts, span.end.saturating_sub(1));
286 if start_line == end_line {
287 let line_start = line_starts[start_line];
288 let start_col = utf16_len(&text[line_start..span.start]);
289 let length = utf16_len(&text[span.start..span.end]);
290 if length > 0 {
291 out.push(LineToken {
292 line: saturating_u32(start_line),
293 start_char: saturating_u32(start_col),
294 length: saturating_u32(length),
295 token_type: span.token_type,
296 });
297 }
298 } else {
299 for line in start_line..=end_line {
301 let line_start = line_starts[line];
302 let line_end = line_starts.get(line + 1).copied().unwrap_or(text.len());
303 let seg_start = span.start.max(line_start);
304 let seg_end_raw = span.end.min(line_end);
307 let seg_end = trim_trailing_newline(text, line_start, seg_end_raw);
308 if seg_end <= seg_start {
309 continue;
310 }
311 let start_col = utf16_len(&text[line_start..seg_start]);
312 let length = utf16_len(&text[seg_start..seg_end]);
313 if length > 0 {
314 out.push(LineToken {
315 line: saturating_u32(line),
316 start_char: saturating_u32(start_col),
317 length: saturating_u32(length),
318 token_type: span.token_type,
319 });
320 }
321 }
322 }
323 }
324 out.sort_by(|a, b| a.line.cmp(&b.line).then(a.start_char.cmp(&b.start_char)));
325 out
326}
327
/// Moves `end` back over a trailing `\n` (and a `\r` before it) without going
/// below `line_start`, so that line terminators are excluded from a segment.
/// Returns `end` unchanged when it is not preceded by `\n`.
fn trim_trailing_newline(text: &str, line_start: usize, end: usize) -> usize {
    let bytes = text.as_bytes();
    let preceded_by =
        |pos: usize, byte: u8| pos > line_start && bytes.get(pos - 1) == Some(&byte);
    if !preceded_by(end, b'\n') {
        return end;
    }
    let before_lf = end - 1;
    if preceded_by(before_lf, b'\r') {
        before_lf - 1
    } else {
        before_lf
    }
}
341
/// Returns the byte offset at which each line of `text` begins: always `0`,
/// plus the offset just after every `\n`.
fn line_starts(text: &str) -> Vec<usize> {
    let after_newlines = text
        .bytes()
        .enumerate()
        .filter(|&(_, b)| b == b'\n')
        .map(|(pos, _)| pos + 1);
    std::iter::once(0).chain(after_newlines).collect()
}
352
/// Returns the zero-based index of the line containing byte `offset`, given
/// the strictly increasing line-start offsets from `line_starts` (whose first
/// entry is `0`).
fn line_for_offset(line_starts: &[usize], offset: usize) -> usize {
    // The number of lines starting at or before `offset`, minus one.
    line_starts.partition_point(|&start| start <= offset) - 1
}
360
/// Returns the length of `s` in UTF-16 code units, the unit LSP uses for
/// columns by default.
fn utf16_len(s: &str) -> usize {
    if !s.is_ascii() {
        return s.encode_utf16().count();
    }
    // Fast path: every ASCII byte is exactly one UTF-16 code unit.
    s.len()
}
369
/// Narrows `v` to `u32`, clamping to `u32::MAX` instead of truncating when it
/// does not fit.
fn saturating_u32(v: usize) -> u32 {
    match u32::try_from(v) {
        Ok(n) => n,
        Err(_) => u32::MAX,
    }
}
377
378fn encode_deltas(tokens: &[LineToken]) -> Vec<SemanticToken> {
380 let mut out = Vec::with_capacity(tokens.len());
381 let mut prev_line: u32 = 0;
382 let mut prev_char: u32 = 0;
383 for t in tokens {
384 let delta_line = t.line - prev_line;
385 let delta_start = if delta_line == 0 {
386 t.start_char - prev_char
387 } else {
388 t.start_char
389 };
390 out.push(SemanticToken {
391 delta_line,
392 delta_start,
393 length: t.length,
394 token_type: t.token_type,
395 token_modifiers_bitset: 0,
396 });
397 prev_line = t.line;
398 prev_char = t.start_char;
399 }
400 out
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use mz_ore::cast::CastFrom;

    /// A decoded token: `(line, start_char, length, token_type)` in absolute
    /// coordinates.
    type Absolute = (u32, u32, u32, u32);

    /// Undoes the LSP delta encoding so expectations can be written with
    /// absolute positions.
    fn decode(tokens: &[SemanticToken]) -> Vec<Absolute> {
        let mut line = 0u32;
        let mut ch = 0u32;
        tokens
            .iter()
            .map(|t| {
                if t.delta_line == 0 {
                    ch += t.delta_start;
                } else {
                    line += t.delta_line;
                    ch = t.delta_start;
                }
                (line, ch, t.length, t.token_type)
            })
            .collect()
    }

    /// Asserts that `sql` produces exactly the `expected` tokens.
    fn assert_tokens(sql: &str, expected: &[Absolute]) {
        let decoded = decode(&compute_semantic_tokens(sql));
        assert_eq!(decoded, expected, "semantic tokens for {sql:?}");
    }

    #[mz_ore::test]
    fn empty_input() {
        assert!(compute_semantic_tokens("").is_empty());
    }

    #[mz_ore::test]
    fn basic_select() {
        assert_tokens(
            "SELECT 1 FROM t",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 1, TOKEN_TYPE_NUMBER),
                (0, 9, 4, TOKEN_TYPE_KEYWORD),
                (0, 14, 1, TOKEN_TYPE_VARIABLE),
            ],
        );
    }

    #[mz_ore::test]
    fn create_mv() {
        assert_tokens(
            "CREATE MATERIALIZED VIEW foo AS SELECT 1",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 12, TOKEN_TYPE_KEYWORD),
                (0, 20, 4, TOKEN_TYPE_KEYWORD),
                (0, 25, 3, TOKEN_TYPE_VARIABLE),
                (0, 29, 2, TOKEN_TYPE_KEYWORD),
                (0, 32, 6, TOKEN_TYPE_KEYWORD),
                (0, 39, 1, TOKEN_TYPE_NUMBER),
            ],
        );
    }

    #[mz_ore::test]
    fn line_comment_mid_statement() {
        assert_tokens(
            "SELECT -- hi\n1",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 5, TOKEN_TYPE_COMMENT),
                (1, 0, 1, TOKEN_TYPE_NUMBER),
            ],
        );
    }

    #[mz_ore::test]
    fn block_comment_single_line() {
        assert_tokens(
            "SELECT /* x */ 1",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 7, TOKEN_TYPE_COMMENT),
                (0, 15, 1, TOKEN_TYPE_NUMBER),
            ],
        );
    }

    #[mz_ore::test]
    fn block_comment_multiline() {
        // A multi-line comment becomes one token per line, minus the newlines.
        assert_tokens(
            "/*\na\nb\n*/",
            &[
                (0, 0, 2, TOKEN_TYPE_COMMENT),
                (1, 0, 1, TOKEN_TYPE_COMMENT),
                (2, 0, 1, TOKEN_TYPE_COMMENT),
                (3, 0, 2, TOKEN_TYPE_COMMENT),
            ],
        );
    }

    #[mz_ore::test]
    fn string_literal() {
        assert_tokens(
            "SELECT 'hello'",
            &[(0, 0, 6, TOKEN_TYPE_KEYWORD), (0, 7, 7, TOKEN_TYPE_STRING)],
        );
    }

    #[mz_ore::test]
    fn comment_markers_inside_string_not_detected() {
        assert_tokens(
            "SELECT '--not a comment' FROM t",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 17, TOKEN_TYPE_STRING),
                (0, 25, 4, TOKEN_TYPE_KEYWORD),
                (0, 30, 1, TOKEN_TYPE_VARIABLE),
            ],
        );
    }

    #[mz_ore::test]
    fn quoted_identifier() {
        assert_tokens(
            r#"SELECT "My Col" FROM t"#,
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 8, TOKEN_TYPE_VARIABLE),
                (0, 16, 4, TOKEN_TYPE_KEYWORD),
                (0, 21, 1, TOKEN_TYPE_VARIABLE),
            ],
        );
    }

    #[mz_ore::test]
    fn operators() {
        assert_tokens(
            "SELECT 1 + 2 * 3",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 1, TOKEN_TYPE_NUMBER),
                (0, 9, 1, TOKEN_TYPE_OPERATOR),
                (0, 11, 1, TOKEN_TYPE_NUMBER),
                (0, 13, 1, TOKEN_TYPE_OPERATOR),
                (0, 15, 1, TOKEN_TYPE_NUMBER),
            ],
        );
    }

    #[mz_ore::test]
    fn parameter() {
        assert_tokens(
            "SELECT $1 + $42",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 7, 2, TOKEN_TYPE_PARAMETER),
                (0, 10, 1, TOKEN_TYPE_OPERATOR),
                (0, 12, 3, TOKEN_TYPE_PARAMETER),
            ],
        );
    }

    #[mz_ore::test]
    fn punctuation_is_skipped() {
        // Parentheses, commas, dots, and semicolons produce no tokens.
        assert_tokens(
            "SELECT (a, b.c);",
            &[
                (0, 0, 6, TOKEN_TYPE_KEYWORD),
                (0, 8, 1, TOKEN_TYPE_VARIABLE),
                (0, 11, 1, TOKEN_TYPE_VARIABLE),
                (0, 13, 1, TOKEN_TYPE_VARIABLE),
            ],
        );
    }

    #[mz_ore::test]
    fn non_ascii_identifier() {
        // Columns and lengths are UTF-16 code units, not bytes.
        assert_tokens(
            "SELECT テーブル",
            &[(0, 0, 6, TOKEN_TYPE_KEYWORD), (0, 7, 4, TOKEN_TYPE_VARIABLE)],
        );
    }

    #[mz_ore::test]
    fn lex_error_does_not_panic() {
        // `@@@` is presumably rejected by the lexer. The comment before it must
        // still be highlighted, since comments come from a separate scan.
        let decoded = decode(&compute_semantic_tokens("-- leading comment\n@@@"));
        assert!(
            decoded
                .iter()
                .any(|&(_, _, _, ty)| ty == TOKEN_TYPE_COMMENT),
            "expected the comment to be recovered even on lex error: {decoded:?}"
        );
    }

    #[mz_ore::test]
    fn legend_order_matches_constants() {
        let legend = legend_token_types();
        let expected = [
            (TOKEN_TYPE_KEYWORD, SemanticTokenType::KEYWORD),
            (TOKEN_TYPE_STRING, SemanticTokenType::STRING),
            (TOKEN_TYPE_NUMBER, SemanticTokenType::NUMBER),
            (TOKEN_TYPE_OPERATOR, SemanticTokenType::OPERATOR),
            (TOKEN_TYPE_VARIABLE, SemanticTokenType::VARIABLE),
            (TOKEN_TYPE_PARAMETER, SemanticTokenType::PARAMETER),
            (TOKEN_TYPE_COMMENT, SemanticTokenType::COMMENT),
        ];
        for (index, token_type) in expected.iter() {
            assert_eq!(&legend[usize::cast_from(*index)], token_type);
        }
    }
}