mz_sql_lexer/lexer.rs

// Copyright 2018 sqlparser-rs contributors. All rights reserved.
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// This file is derived from the sqlparser-rs project, available at
// https://github.com/andygrove/sqlparser-rs. It was incorporated
// directly into Materialize on December 21, 2019.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file at the
// root of this repository, or online at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! SQL lexer.
//!
//! This module lexes SQL according to the rules described in the ["Lexical
//! Structure"] section of the PostgreSQL documentation. The description is
//! intentionally not replicated here. Please refer to that chapter as you
//! read the code in this module.
//!
//! Where the PostgreSQL documentation is unclear, refer to their flex source
//! instead, located in the [backend/parser/scan.l] file in the PostgreSQL
//! Git repository.
//!
//! ["Lexical Structure"]: https://www.postgresql.org/docs/current/sql-syntax-lexical.html
//! [backend/parser/scan.l]: https://github.com/postgres/postgres/blob/90851d1d26f54ccb4d7b1bc49449138113d6ec83/src/backend/parser/scan.l

extern crate alloc;

use std::error::Error;
use std::{char, fmt};

use mz_ore::lex::LexBuf;
use mz_ore::str::{MaxLenString, StrExt};
use serde::{Deserialize, Serialize};

use crate::keywords::Keyword;

/// Maximum allowed identifier length in bytes.
pub const MAX_IDENTIFIER_LENGTH: usize = 255;

/// Newtype that limits the length of identifiers.
pub type IdentString = MaxLenString<MAX_IDENTIFIER_LENGTH>;

/// An error that occurred while lexing a SQL query.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct LexerError {
    /// The error message.
    pub message: String,
    /// The byte position with which the error is associated.
    pub pos: usize,
}

impl fmt::Display for LexerError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.message)
    }
}

impl Error for LexerError {}

impl LexerError {
    /// Constructs an error with the provided message at the provided position.
    pub(crate) fn new<S>(pos: usize, message: S) -> LexerError
    where
        S: Into<String>,
    {
        LexerError {
            pos,
            message: message.into(),
        }
    }
}

/// A lexed SQL token.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    Keyword(Keyword),
    Ident(IdentString),
    String(String),
    HexString(String),
    Number(String),
    Parameter(usize),
    Op(String),
    Star,
    Eq,
    LParen,
    RParen,
    LBracket,
    RBracket,
    Dot,
    Comma,
    Colon,
    DoubleColon,
    Semicolon,
    Arrow,
}

impl fmt::Display for Token {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Token::Keyword(kw) => f.write_str(kw.as_str()),
            Token::Ident(id) => write!(f, "identifier {}", id.quoted()),
            Token::String(s) => write!(f, "string literal {}", s.quoted()),
            Token::HexString(s) => write!(f, "hex string literal {}", s.quoted()),
            Token::Number(n) => write!(f, "number \"{}\"", n),
            Token::Parameter(n) => write!(f, "parameter \"${}\"", n),
            Token::Op(op) => write!(f, "operator {}", op.quoted()),
            Token::Star => f.write_str("star"),
            Token::Eq => f.write_str("equals sign"),
            Token::LParen => f.write_str("left parenthesis"),
            Token::RParen => f.write_str("right parenthesis"),
            Token::LBracket => f.write_str("left square bracket"),
            Token::RBracket => f.write_str("right square bracket"),
            Token::Dot => f.write_str("dot"),
            Token::Comma => f.write_str("comma"),
            Token::Colon => f.write_str("colon"),
            Token::DoubleColon => f.write_str("double colon"),
            Token::Semicolon => f.write_str("semicolon"),
            Token::Arrow => f.write_str("arrow"),
        }
    }
}

/// A [`Token`] annotated with its position in the input string.
pub struct PosToken {
    /// The token.
    pub kind: Token,
    /// The byte offset of the token in the input string.
    pub offset: usize,
}

macro_rules! bail {
    ($pos:expr, $($fmt:expr),*) => {
        return Err(LexerError::new($pos, format!($($fmt),*)))
    }
}

/// Lexes a SQL query.
///
/// Returns a list of tokens alongside their corresponding byte offset in the
/// input string. Returns an error if the SQL query is lexically invalid.
///
/// See the module documentation for more information about the lexical
/// structure of SQL.
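///
/// # Example
///
/// A rough sketch of typical usage (marked `ignore`; `Keyword::Select` is
/// assumed to be the keyword table's variant for `SELECT`):
///
/// ```ignore
/// let tokens = lex("SELECT 1")?;
/// assert_eq!(tokens[0].kind, Token::Keyword(Keyword::Select));
/// assert_eq!(tokens[1].kind, Token::Number("1".into()));
/// ```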
pub fn lex(query: &str) -> Result<Vec<PosToken>, LexerError> {
    let buf = &mut LexBuf::new(query);
    let mut tokens = vec![];
    while let Some(ch) = buf.next() {
        let pos = buf.pos() - ch.len_utf8();
        let token = match ch {
            _ if ch.is_ascii_whitespace() => continue,
            '-' if buf.consume('-') => {
                lex_line_comment(buf);
                continue;
            }
            '/' if buf.consume('*') => {
                lex_multiline_comment(buf)?;
                continue;
            }
            '\'' => Token::String(lex_string(buf)?),
            'x' | 'X' if buf.consume('\'') => Token::HexString(lex_string(buf)?),
            'e' | 'E' if buf.consume('\'') => lex_extended_string(buf)?,
            'A'..='Z' | 'a'..='z' | '_' | '\u{80}'..=char::MAX => lex_ident(buf)?,
            '"' => lex_quoted_ident(buf)?,
            '0'..='9' => lex_number(buf)?,
            '.' if matches!(buf.peek(), Some('0'..='9')) => lex_number(buf)?,
            '$' if matches!(buf.peek(), Some('0'..='9')) => lex_parameter(buf)?,
            '$' => lex_dollar_string(buf)?,
            '(' => Token::LParen,
            ')' => Token::RParen,
            ',' => Token::Comma,
            '.' => Token::Dot,
            ':' if buf.consume(':') => Token::DoubleColon,
            ':' => Token::Colon,
            ';' => Token::Semicolon,
            '[' => Token::LBracket,
            ']' => Token::RBracket,
            #[rustfmt::skip]
            '+'|'-'|'*'|'/'|'<'|'>'|'='|'~'|'!'|'@'|'#'|'%'|'^'|'&'|'|'|'`'|'?' => lex_op(buf),
            _ => bail!(pos, "unexpected character in input: {}", ch),
        };
        tokens.push(PosToken {
            kind: token,
            offset: pos,
        })
    }

    #[cfg(debug_assertions)]
    for token in &tokens {
        assert!(query.is_char_boundary(token.offset));
    }

    Ok(tokens)
}

fn lex_line_comment(buf: &mut LexBuf) {
    buf.take_while(|ch| ch != '\n');
}

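/// Lexes a bracketed comment, assuming the opening `/*` has already been
/// consumed. Note that, unlike in C, bracketed comments nest, so
/// `/* outer /* inner */ still a comment */` is a single comment.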
fn lex_multiline_comment(buf: &mut LexBuf) -> Result<(), LexerError> {
    let pos = buf.pos() - 2;
    let mut nesting = 0;
    while let Some(ch) = buf.next() {
        match ch {
            '*' if buf.consume('/') => {
                if nesting == 0 {
                    return Ok(());
                } else {
                    nesting -= 1;
                }
            }
            '/' if buf.consume('*') => nesting += 1,
            _ => (),
        }
    }
    bail!(pos, "unterminated multiline comment")
}

fn lex_ident(buf: &mut LexBuf) -> Result<Token, LexerError> {
    buf.prev();
    let pos = buf.pos();
    let word = buf.take_while(
        |ch| matches!(ch, 'A'..='Z' | 'a'..='z' | '0'..='9' | '$' | '_' | '\u{80}'..=char::MAX),
    );
    match word.parse() {
        Ok(kw) => Ok(Token::Keyword(kw)),
        Err(_) => {
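            // Unquoted identifiers are case-folded to lowercase, matching
            // PostgreSQL's treatment of unquoted identifiers.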
            let Ok(small) = IdentString::new(word.to_lowercase()) else {
                bail!(
                    pos,
                    "identifier length exceeds {MAX_IDENTIFIER_LENGTH} bytes"
                )
            };
            Ok(Token::Ident(small))
        }
    }
}

fn lex_quoted_ident(buf: &mut LexBuf) -> Result<Token, LexerError> {
    let mut s = String::new();
    let pos = buf.pos() - 1;
    loop {
        match buf.next() {
            Some('"') if buf.consume('"') => s.push('"'),
            Some('"') => break,
            Some('\0') => bail!(pos, "null character in quoted identifier"),
            Some(c) => s.push(c),
            None => bail!(pos, "unterminated quoted identifier"),
        }
    }
    let Ok(small) = IdentString::new(s) else {
        bail!(
            pos,
            "identifier length exceeds {MAX_IDENTIFIER_LENGTH} bytes"
        )
    };
    Ok(Token::Ident(small))
}

fn lex_string(buf: &mut LexBuf) -> Result<String, LexerError> {
    let mut s = String::new();
    loop {
        let pos = buf.pos() - 1;
        loop {
            match buf.next() {
                Some('\'') if buf.consume('\'') => s.push('\''),
                Some('\'') => break,
                Some(c) => s.push(c),
                None => bail!(pos, "unterminated quoted string"),
            }
        }
        if !lex_to_adjacent_string(buf) {
            return Ok(s);
        }
    }
}

fn lex_extended_string(buf: &mut LexBuf) -> Result<Token, LexerError> {
    fn lex_unicode_escape(buf: &mut LexBuf, n: usize) -> Result<char, LexerError> {
        let pos = buf.pos() - 2;
        buf.next_n(n)
            .and_then(|s| u32::from_str_radix(s, 16).ok())
            .and_then(|codepoint| char::try_from(codepoint).ok())
            .ok_or_else(|| LexerError::new(pos, "invalid unicode escape"))
    }

    // We do not support octal (\o) or hexadecimal (\x) escapes, since it is
    // possible to construct invalid UTF-8 with these escapes. We could check
    // for and reject invalid UTF-8, of course, but it is too annoying to be
    // worth doing right now. We still lex the escapes to produce nice error
    // messages.
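    //
    // For example, `e'\x41'` is rejected even though it would decode to
    // valid UTF-8, while `e'\u0041'` lexes as the string `A`.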

    fn lex_octal_escape(buf: &mut LexBuf) -> LexerError {
        let pos = buf.pos() - 2;
        buf.take_while(|ch| matches!(ch, '0'..='7'));
        LexerError::new(pos, "octal escapes are not supported")
    }

    fn lex_hexadecimal_escape(buf: &mut LexBuf) -> LexerError {
        let pos = buf.pos() - 2;
        buf.take_while(|ch| matches!(ch, '0'..='9' | 'A'..='F' | 'a'..='f'));
        LexerError::new(pos, "hexadecimal escapes are not supported")
    }

    let mut s = String::new();
    loop {
        let pos = buf.pos() - 1;
        loop {
            match buf.next() {
                Some('\'') if buf.consume('\'') => s.push('\''),
                Some('\'') => break,
                Some('\\') => match buf.next() {
                    Some('b') => s.push('\x08'),
                    Some('f') => s.push('\x0c'),
                    Some('n') => s.push('\n'),
                    Some('r') => s.push('\r'),
                    Some('t') => s.push('\t'),
                    Some('u') => s.push(lex_unicode_escape(buf, 4)?),
                    Some('U') => s.push(lex_unicode_escape(buf, 8)?),
                    Some('0'..='7') => return Err(lex_octal_escape(buf)),
                    Some('x') => return Err(lex_hexadecimal_escape(buf)),
                    Some(c) => s.push(c),
                    None => bail!(pos, "unterminated quoted string"),
                },
                Some(c) => s.push(c),
                None => bail!(pos, "unterminated quoted string"),
            }
        }
        if !lex_to_adjacent_string(buf) {
            return Ok(Token::String(s));
        }
    }
}

fn lex_to_adjacent_string(buf: &mut LexBuf) -> bool {
    // Adjacent string literals that are separated by whitespace are
    // concatenated if and only if that whitespace contains at least one newline
    // character. This bizarre rule matches PostgreSQL and the SQL standard.
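    //
    // For example:
    //
    //     'foo'
    //     'bar'
    //
    // lexes as the single string literal `foobar`, while `'foo' 'bar'` on
    // one line lexes as two separate string tokens.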
    let whitespace = buf.take_while(|ch| ch.is_ascii_whitespace());
    whitespace.contains(&['\n', '\r'][..]) && buf.consume('\'')
}

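/// Lexes the body of a dollar-quoted string, e.g. `$$foo$$` or
/// `$tag$foo$tag$`. The opening `$` has already been consumed.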
fn lex_dollar_string(buf: &mut LexBuf) -> Result<Token, LexerError> {
    let pos = buf.pos() - 1;
    let tag = format!("${}$", buf.take_while(|ch| ch != '$'));
    let _ = buf.next();
    if let Some(s) = buf.take_to_delimiter(&tag) {
        Ok(Token::String(s.into()))
    } else {
        Err(LexerError::new(pos, "unterminated dollar-quoted string"))
    }
}

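/// Lexes a parameter placeholder such as `$1`. The `$` has already been
/// consumed; only the digits remain in the buffer.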
fn lex_parameter(buf: &mut LexBuf) -> Result<Token, LexerError> {
    let pos = buf.pos() - 1;
    let n = buf
        .take_while(|ch| matches!(ch, '0'..='9'))
        .parse()
        .map_err(|_| LexerError::new(pos, "invalid parameter number"))?;
    Ok(Token::Parameter(n))
}

fn lex_number(buf: &mut LexBuf) -> Result<Token, LexerError> {
    buf.prev();
    let mut s = buf.take_while(|ch| matches!(ch, '0'..='9')).to_owned();

    // Optional decimal component.
    if buf.consume('.') {
        s.push('.');
        s.push_str(buf.take_while(|ch| matches!(ch, '0'..='9')));
    }

    // Optional exponent.
    if buf.consume('e') || buf.consume('E') {
        s.push('E');
        let require_exp = if buf.consume('-') {
            s.push('-');
            true
        } else {
            buf.consume('+')
        };
        let exp = buf.take_while(|ch| matches!(ch, '0'..='9'));
        if require_exp && exp.is_empty() {
            return Err(LexerError::new(buf.pos() - 1, "missing required exponent"));
        } else if exp.is_empty() {
            // Put back consumed E.
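            // For example, `10e` lexes as the number `10` followed by the
            // identifier `e`.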
            buf.prev();
            s.pop();
        } else {
            s.push_str(exp);
        }
    }

    Ok(Token::Number(s))
}

fn lex_op(buf: &mut LexBuf) -> Token {
    buf.prev();

    // Materialize special case: `=>` is lexed as an arrow token, rather than
    // an operator.
    if buf.consume_str("=>") {
        return Token::Arrow;
    }

    let mut s = String::new();

    // In PostgreSQL, operators might be composed of any of the characters in
    // the set below...
    while let Some(ch) = buf.next() {
        match ch {
            // ...except the sequences `--` and `/*` start comments, even within
            // what would otherwise be an operator...
            '-' if buf.peek() == Some('-') => {
                buf.prev();
                break;
            }
            '/' if buf.peek() == Some('*') => {
                buf.prev();
                break;
            }
            #[rustfmt::skip]
            '+'|'-'|'*'|'/'|'<'|'>'|'='|'~'|'!'|'@'|'#'|'%'|'^'|'&'|'|'|'`'|'?' => s.push(ch),
            _ => {
                buf.prev();
                break;
            }
        }
    }

    // ...and a multi-character operator that ends with `-` or `+` must also
    // contain at least one nonstandard operator character. This is so that e.g.
    // `1+-2` is lexed as `1 + (-2)` as required by the SQL standard, but `1@+2`
    // is lexed as `1 @+ 2`, as `@+` is meant to be a user-definable operator.
    if s.len() > 1
        && s.ends_with(&['-', '+'][..])
        && !s.contains(&['~', '!', '@', '#', '%', '^', '&', '|', '`', '?'][..])
    {
        while s.len() > 1 && s.ends_with(&['-', '+'][..]) {
            buf.prev();
            s.pop();
        }
    }

    match s.as_str() {
        // `*` and `=` are not just expression operators in SQL, so give them
        // dedicated tokens to simplify the parser.
        "*" => Token::Star,
        "=" => Token::Eq,
        // Normalize the two forms of the not-equals operator.
        "!=" => Token::Op("<>".into()),
        // Emit all other operators as is.
        _ => Token::Op(s),
    }
}
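
// A minimal sanity test sketching the lexer's behavior. `Keyword::Select` is
// assumed to be the keyword table's variant for `SELECT`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lex_simple_query() {
        // Line comments are skipped; only the keyword and number remain.
        let tokens = lex("SELECT 1 -- trailing comment").unwrap();
        let kinds: Vec<_> = tokens.into_iter().map(|t| t.kind).collect();
        assert_eq!(
            kinds,
            vec![Token::Keyword(Keyword::Select), Token::Number("1".into())]
        );
    }
}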