connection_string/
jdbc.rs

1use std::str::FromStr;
2use std::{collections::HashMap, fmt::Display};
3
4use crate::{bail, ensure};
5
/// JDBC connection string parser for SqlServer
///
/// [Read more](https://docs.microsoft.com/en-us/sql/connect/jdbc/building-the-connection-url?view=sql-server-ver15)
///
/// # Format
///
/// ```txt
/// jdbc:sqlserver://[serverName[\instanceName][:portNumber]][;property=value[;property=value]]
/// ```
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct JdbcString {
    /// The `jdbc:<name>` prefix, e.g. `jdbc:sqlserver`.
    sub_protocol: String,
    /// Host name, or a bracketed IPv6 literal such as `[::1]`.
    server_name: Option<String>,
    /// The part after `\` in `server\instance`.
    instance_name: Option<String>,
    /// The port after `:`, if one was given.
    port: Option<u16>,
    /// `key=value` properties (keys produced by the parser are lower-cased).
    properties: HashMap<String, String>,
}

impl JdbcString {
    /// Access the connection sub-protocol (e.g. `jdbc:sqlserver`).
    pub fn sub_protocol(&self) -> &str {
        &self.sub_protocol
    }

    /// Access the connection server name, if any.
    pub fn server_name(&self) -> Option<&str> {
        self.server_name.as_deref()
    }

    /// Get a reference to the connection's instance name, if any.
    pub fn instance_name(&self) -> Option<&str> {
        self.instance_name.as_deref()
    }

    /// Access the connection's port, if any.
    pub fn port(&self) -> Option<u16> {
        self.port
    }

    /// Access the connection's key-value pairs.
    pub fn properties(&self) -> &HashMap<String, String> {
        &self.properties
    }

    /// Mutably access the connection's key-value pairs.
    ///
    /// Keys inserted here are stored verbatim; only the parser lower-cases keys.
    pub fn properties_mut(&mut self) -> &mut HashMap<String, String> {
        &mut self.properties
    }

    /// Get an iterator over all keys from the connection's key-value pairs,
    /// in unspecified order.
    pub fn keys(&self) -> impl ExactSizeIterator<Item = &str> + '_ {
        self.properties.keys().map(AsRef::as_ref)
    }
}

impl Display for JdbcString {
    /// Render the connection string in JDBC syntax.
    ///
    /// Runs of characters reserved by the JDBC grammar are wrapped in `{…}` so
    /// the output can be parsed back. Property keys and values are trimmed, and
    /// properties are written in sorted key order so the output is
    /// deterministic (the backing `HashMap` has no stable iteration order).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Wrap each run of reserved characters (`: = \ / ; { [ ]`) in `{…}`.
        ///
        /// `}` is deliberately *not* escaped. Outside braces the lexer reads it
        /// as an ordinary character. Inside a brace run it would end the run
        /// early and expose whatever reserved characters follow it: `};` would
        /// become `{};}`, which parses as an empty group plus a bare `;`.
        fn escape(s: &str) -> String {
            let mut output = String::with_capacity(s.len());
            let mut escaping = false;
            for c in s.chars() {
                let reserved = matches!(c, ':' | '=' | '\\' | '/' | ';' | '{' | '[' | ']');
                if reserved != escaping {
                    // Entering or leaving a run of reserved characters.
                    output.push(if reserved { '{' } else { '}' });
                    escaping = reserved;
                }
                output.push(c);
            }
            if escaping {
                output.push('}');
            }
            output
        }

        /// `true` for a bracketed IPv6 literal such as `[::1]` that is made
        /// only of hex digits and `:`. Such a literal can be written bare,
        /// because the parser reads a leading `[` as the start of an IPv6
        /// address. Anything else bracketed falls back to escaping.
        fn is_ipv6_literal(s: &str) -> bool {
            if s.len() < 2 || !s.starts_with('[') || !s.ends_with(']') {
                return false;
            }
            // `[` and `]` are single-byte, so this slice is on char boundaries.
            let inner = &s[1..s.len() - 1];
            inner.chars().all(|c| c == ':' || c.is_ascii_hexdigit())
                && inner.parse::<std::net::Ipv6Addr>().is_ok()
        }

        write!(f, "{}://", self.sub_protocol)?;
        if let Some(server_name) = &self.server_name {
            if is_ipv6_literal(server_name) {
                f.write_str(server_name)?;
            } else {
                f.write_str(&escape(server_name))?;
            }
        }
        if let Some(instance_name) = &self.instance_name {
            write!(f, "\\{}", escape(instance_name))?;
        }
        if let Some(port) = self.port {
            write!(f, ":{}", port)?;
        }

        let mut properties: Vec<(&String, &String)> = self.properties.iter().collect();
        properties.sort();
        for (key, value) in properties {
            write!(f, ";{}={}", escape(key.trim()), escape(value.trim()))?;
        }
        Ok(())
    }
}
105
106// NOTE(yosh): Unfortunately we can't parse using `split(';')` because JDBC
107// strings support escaping. This means that `{;}` is valid and we need to write
108// an actual LR parser.
109impl FromStr for JdbcString {
110    type Err = crate::Error;
111
112    fn from_str(input: &str) -> Result<Self, Self::Err> {
113        let mut lexer = Lexer::tokenize(input)?;
114
115        // ```
116        // jdbc:sqlserver://[serverName[\instanceName][:portNumber]][;property=value[;property=value]]
117        // ^^^^^^^^^^^^^^^^^
118        // ```
119        let err = "Invalid JDBC sub-protocol";
120        cmp_str(&mut lexer, "jdbc", err)?;
121        ensure!(lexer.next().kind() == &TokenKind::Colon, err);
122        let sub_protocol = format!("jdbc:{}", read_ident(&mut lexer, err)?);
123
124        ensure!(lexer.next().kind() == &TokenKind::Colon, err);
125        ensure!(lexer.next().kind() == &TokenKind::FSlash, err);
126        ensure!(lexer.next().kind() == &TokenKind::FSlash, err);
127
128        // ```
129        // jdbc:sqlserver://[serverName[\instanceName][:portNumber]][;property=value[;property=value]]
130        //                  ^^^^^^^^^^^
131        // ```
132        // NOTE: this can also be an IPv6 address.
133        let mut server_name = None;
134        match lexer.peek().kind() {
135            TokenKind::OpenBracket => {
136                let err_msg = "Invalid server name: invalid IPv6 address";
137                server_name = Some(parse_ipv6(&mut lexer, err_msg)?);
138            }
139            TokenKind::Atom(_) | TokenKind::Escaped(_) => {
140                server_name = Some(read_ident(&mut lexer, "Invalid server name")?);
141            }
142            _ => {}
143        }
144
145        // ```
146        // jdbc:sqlserver://[serverName[\instanceName][:portNumber]][;property=value[;property=value]]
147        //                             ^^^^^^^^^^^^^^^
148        // ```
149        let mut instance_name = None;
150        if matches!(lexer.peek().kind(), TokenKind::BSlash) {
151            let _ = lexer.next();
152            instance_name = Some(read_ident(&mut lexer, "Invalid instance name")?);
153        }
154
155        // ```
156        // jdbc:sqlserver://[serverName[\instanceName][:portNumber]][;property=value[;property=value]]
157        //                                            ^^^^^^^^^^^^^
158        // ```
159        let mut port = None;
160        if matches!(lexer.peek().kind(), TokenKind::Colon) {
161            let _ = lexer.next();
162            let err = "Invalid port";
163            let s = read_ident(&mut lexer, err)?;
164            port = Some(s.parse()?);
165        }
166
167        // ```
168        // jdbc:sqlserver://[serverName[\instanceName][:portNumber]][;property=value[;property=value]]
169        //                                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
170        // ```
171        // NOTE: we're choosing to only keep the last value per key rather than support multiple inserts per key.
172        let mut properties = HashMap::new();
173        while let TokenKind::Semi = lexer.peek().kind() {
174            let _ = lexer.next();
175
176            // Handle trailing semis.
177            if let TokenKind::Eof = lexer.peek().kind() {
178                let _ = lexer.next();
179                break;
180            }
181
182            let err = "Invalid property key";
183            let key = read_ident(&mut lexer, err)?.to_lowercase();
184
185            let err = "Property pairs must be joined by a `=`";
186            ensure!(lexer.next().kind() == &TokenKind::Eq, err);
187
188            let err = "Invalid property value";
189            let value = read_ident(&mut lexer, err)?;
190
191            properties.insert(key, value);
192        }
193
194        let token = lexer.next();
195        ensure!(token.kind() == &TokenKind::Eof, "Invalid JDBC token");
196
197        Ok(Self {
198            sub_protocol,
199            server_name,
200            instance_name,
201            port,
202            properties,
203        })
204    }
205}
206
207/// Validate a sequence of `TokenKind::Atom` matches the content of a string.
208fn cmp_str(lexer: &mut Lexer, s: &str, err_msg: &'static str) -> crate::Result<()> {
209    for char in s.chars() {
210        if let Token {
211            kind: TokenKind::Atom(tchar),
212            ..
213        } = lexer.next()
214        {
215            ensure!(char == tchar, err_msg);
216        } else {
217            bail!(err_msg);
218        }
219    }
220    Ok(())
221}
222
223/// Read sequences of `TokenKind::Atom` and `TokenKind::Escaped` into a String.
224fn read_ident(lexer: &mut Lexer, err_msg: &'static str) -> crate::Result<String> {
225    let mut output = String::new();
226    loop {
227        let token = lexer.next();
228        match token.kind() {
229            TokenKind::Escaped(seq) => output.extend(seq),
230            TokenKind::Atom(c) => output.push(*c),
231            _ => {
232                // push the token back in the lexer
233                lexer.push(token);
234                break;
235            }
236        }
237    }
238    ensure!(!output.is_empty(), err_msg);
239    Ok(output)
240}
241
242/// Read a URL encoded IPv6 sequence into a string.
243///
244/// Example: `[2001:db8:85a3:8d3:1319:8a2e:370:7348]`
245///
246/// See also: https://en.wikipedia.org/wiki/IPv6_address#Literal_IPv6_addresses_in_network_resource_identifiers
247fn parse_ipv6(lexer: &mut Lexer, err_msg: &'static str) -> crate::Result<String> {
248    let _ = lexer.next();
249    let mut output = String::from('[');
250
251    loop {
252        match lexer.next().kind() {
253            TokenKind::Colon => output.push(':'),
254            TokenKind::Atom(c) if c.is_ascii_alphanumeric() => output.push(*c),
255            TokenKind::CloseBracket => {
256                output.push(']');
257                break;
258            }
259            _ => bail!(err_msg),
260        }
261    }
262
263    ensure!(!output.is_empty(), err_msg);
264    Ok(output)
265}
266
267#[derive(Debug)]
268struct Lexer {
269    tokens: Vec<Token>,
270}
271
272impl Lexer {
273    /// Parse a string into a list of tokens.
274    pub(crate) fn tokenize(mut input: &str) -> crate::Result<Self> {
275        let mut tokens = vec![];
276        let mut loc = Location::default();
277        while !input.is_empty() {
278            let old_input = input;
279            let mut chars = input.chars();
280            let kind = match chars.next().unwrap() {
281                // c if c.is_ascii_whitespace() => continue,
282                ':' => TokenKind::Colon,
283                '=' => TokenKind::Eq,
284                '\\' => TokenKind::BSlash,
285                '/' => TokenKind::FSlash,
286                ';' => TokenKind::Semi,
287                '[' => TokenKind::OpenBracket,
288                ']' => TokenKind::CloseBracket,
289                '{' => {
290                    let mut buf = Vec::new();
291                    // Read alphanumeric ASCII including whitespace until we find a closing curly.
292                    loop {
293                        match chars.next() {
294                            None => bail!("unclosed escape literal"),
295                            Some('}') => break,
296                            Some(c) if c.is_ascii() => buf.push(c),
297                            Some(c) => bail!("Invalid JDBC token `{}`", c),
298                        }
299                    }
300                    TokenKind::Escaped(buf)
301                }
302                c if c.is_ascii() => TokenKind::Atom(c),
303                c => bail!("Invalid JDBC token `{}`", c),
304            };
305            tokens.push(Token { kind, loc });
306            input = chars.as_str();
307
308            let consumed = old_input.len() - input.len();
309            loc.advance(&old_input[..consumed]);
310        }
311        tokens.reverse();
312        Ok(Self { tokens })
313    }
314
315    /// Get the next token from the queue.
316    #[must_use]
317    pub(crate) fn next(&mut self) -> Token {
318        self.tokens.pop().unwrap_or(Token {
319            kind: TokenKind::Eof,
320            loc: Location::default(),
321        })
322    }
323
324    /// Push a token back onto the queue.
325    pub(crate) fn push(&mut self, token: Token) {
326        self.tokens.push(token);
327    }
328
329    /// Peek at the next token in the queue.
330    #[must_use]
331    pub(crate) fn peek(&mut self) -> Token {
332        self.tokens.last().cloned().unwrap_or(Token {
333            kind: TokenKind::Eof,
334            loc: Location::default(),
335        })
336    }
337}
338
/// Position of a token within the input, counted in characters (not bytes)
/// from the start of the string.
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct Location {
    /// Zero-based character offset.
    pub(crate) column: usize,
}

impl Location {
    /// Move past `text`, counting each `char` as one column.
    fn advance(&mut self, text: &str) {
        let width = text.chars().count();
        self.column += width;
    }
}
350
351/// A pair of `Location` and `TokenKind`.
352#[derive(Debug, Clone)]
353struct Token {
354    #[allow(dead_code)] // for future use...
355    loc: Location,
356    kind: TokenKind,
357}
358
359impl Token {
360    /// What kind of token is this?
361    pub(crate) fn kind(&self) -> &TokenKind {
362        &self.kind
363    }
364}
365
/// Classification of a single lexed unit of a JDBC string.
#[derive(Debug, Clone, PartialEq, Eq)]
enum TokenKind {
    /// `:`
    Colon,
    /// `=`
    Eq,
    /// `;`
    Semi,
    /// `/`
    FSlash,
    /// `\`
    BSlash,
    /// `[`
    OpenBracket,
    /// `]`
    CloseBracket,
    /// The characters between a `{` and the next `}`, taken literally.
    Escaped(Vec<char>),
    /// Any other single ASCII character.
    Atom(char),
    /// Produced by the lexer once every real token has been consumed.
    Eof,
}
382
#[cfg(test)]
mod test {
    use super::JdbcString;

    #[test]
    fn parse_sub_protocol() -> crate::Result<()> {
        let conn: JdbcString = "jdbc:sqlserver://".parse()?;
        assert_eq!(conn.sub_protocol(), "jdbc:sqlserver");
        Ok(())
    }

    #[test]
    fn keys() -> crate::Result<()> {
        let input = r#"jdbc:sqlserver://server:1433;database=prisma-demo;user=SA;password=Pr1sm4_Pr1sm4;trustServerCertificate=true;encrypt=true"#;
        let conn: JdbcString = input.parse()?;
        let keys = conn.keys().collect::<Vec<&str>>();
        assert_eq!(keys.len(), 5);
        assert!(keys.contains(&"database"));
        assert!(keys.contains(&"user"));
        assert!(keys.contains(&"password"));
        assert!(keys.contains(&"trustservercertificate"));
        assert!(keys.contains(&"encrypt"));
        Ok(())
    }

    #[test]
    fn parse_server_name() -> crate::Result<()> {
        let conn: JdbcString = r#"jdbc:sqlserver://server"#.parse()?;
        assert_eq!(conn.sub_protocol(), "jdbc:sqlserver");
        assert_eq!(conn.server_name(), Some("server"));
        Ok(())
    }

    #[test]
    fn parse_instance_name() -> crate::Result<()> {
        let conn: JdbcString = r#"jdbc:sqlserver://server\instance"#.parse()?;
        assert_eq!(conn.sub_protocol(), "jdbc:sqlserver");
        assert_eq!(conn.server_name(), Some("server"));
        assert_eq!(conn.instance_name(), Some("instance"));
        Ok(())
    }

    #[test]
    fn parse_ipv6_url() -> crate::Result<()> {
        let input = r#"jdbc:sqlserver://[::1]:1433;database=prisma-demo;user=SA;password=Pr1sm4_Pr1sm4;trustServerCertificate=true;encrypt=true"#;
        let conn: JdbcString = input.parse()?;
        assert_eq!(conn.server_name(), Some("[::1]"));
        assert_eq!(conn.port(), Some(1433));

        let input = r#"jdbc:sqlserver://[:1433;"#;
        assert!(input.parse::<JdbcString>().is_err());

        let input = r#"jdbc:sqlserver://[0f0f0:f00f==:09:12]:1433;"#;
        assert!(input.parse::<JdbcString>().is_err());

        let input = r#"jdbc:sqlserver://]:1433;"#;
        assert!(input.parse::<JdbcString>().is_err());
        Ok(())
    }

    #[test]
    fn parse_port() -> crate::Result<()> {
        let conn: JdbcString = r#"jdbc:sqlserver://server\instance:80"#.parse()?;
        assert_eq!(conn.sub_protocol(), "jdbc:sqlserver");
        assert_eq!(conn.server_name(), Some("server"));
        assert_eq!(conn.instance_name(), Some("instance"));
        assert_eq!(conn.port(), Some(80));
        Ok(())
    }

    #[test]
    fn parse_properties() -> crate::Result<()> {
        let conn: JdbcString =
            r#"jdbc:sqlserver://server\instance:80;key=value;foo=bar"#.parse()?;
        assert_eq!(conn.sub_protocol(), "jdbc:sqlserver");
        assert_eq!(conn.server_name(), Some("server"));
        assert_eq!(conn.instance_name(), Some("instance"));
        assert_eq!(conn.port(), Some(80));

        let kv = conn.properties();
        assert_eq!(kv.get("foo"), Some(&"bar".to_string()));
        assert_eq!(kv.get("key"), Some(&"value".to_string()));
        Ok(())
    }

    #[test]
    fn escaped_properties() -> crate::Result<()> {
        let conn: JdbcString =
            r#"jdbc:sqlserver://se{r}ver{;}\instance:80;key={va[]}lue"#.parse()?;
        assert_eq!(conn.sub_protocol(), "jdbc:sqlserver");
        assert_eq!(conn.server_name(), Some("server;"));
        assert_eq!(conn.instance_name(), Some("instance"));
        assert_eq!(conn.port(), Some(80));

        let kv = conn.properties();
        assert_eq!(kv.get("key"), Some(&"va[]lue".to_string()));
        Ok(())
    }

    #[test]
    fn sub_protocol_error() -> crate::Result<()> {
        let err = r#"jdbq:sqlserver://"#.parse::<JdbcString>().unwrap_err().to_string();
        assert_eq!(err, "Conversion error: Invalid JDBC sub-protocol");
        Ok(())
    }

    #[test]
    fn whitespace() -> crate::Result<()> {
        let conn: JdbcString =
            r#"jdbc:sqlserver://server\instance:80;key=value;foo=bar;user id=musti naukio"#
                .parse()?;
        assert_eq!(conn.sub_protocol(), "jdbc:sqlserver");
        assert_eq!(conn.server_name(), Some(r#"server"#));
        assert_eq!(conn.instance_name(), Some("instance"));
        assert_eq!(conn.port(), Some(80));

        let kv = conn.properties();
        assert_eq!(kv.get("user id"), Some(&"musti naukio".to_string()));
        Ok(())
    }

    // Test for dashes and dots in the name, and parse names other than oracle
    #[test]
    fn regression_2020_10_06() -> crate::Result<()> {
        let input = "jdbc:sqlserver://my-server.com:5433;foo=bar";
        let _conn: JdbcString = input.parse()?;

        let input = "jdbc:oracle://foo.bar:1234";
        let _conn: JdbcString = input.parse()?;

        Ok(())
    }

    /// While strictly disallowed, we should not fail if we detect a trailing semi.
    #[test]
    fn regression_2020_10_07_handle_trailing_semis() -> crate::Result<()> {
        let input = "jdbc:sqlserver://my-server.com:5433;foo=bar;";
        let _conn: JdbcString = input.parse()?;

        let input = "jdbc:sqlserver://my-server.com:4200;User ID=musti;Password={abc;}}45}";
        let conn: JdbcString = input.parse()?;
        let props = conn.properties();
        assert_eq!(props.get("user id"), Some(&"musti".to_owned()));
        assert_eq!(props.get("password"), Some(&"abc;}45}".to_owned()));
        Ok(())
    }

    #[test]
    fn display_with_escaping() -> crate::Result<()> {
        let input = r#"jdbc:sqlserver://server{;}\instance:80;key=va{[]}lue"#;
        let conn: JdbcString = input.parse()?;

        assert_eq!(format!("{}", conn), input);
        Ok(())
    }

    // Output was being over-escaped and not split with semis, causing all sorts of uri failures.
    #[test]
    fn regression_2020_10_27_dont_escape_underscores_whitespace() -> crate::Result<()> {
        let input = r#"jdbc:sqlserver://test-db-mssql-2017:1433;user=SA;encrypt=DANGER_PLAINTEXT;isolationlevel=READ UNCOMMITTED;schema=NonEmbeddedUpsertDesignSpec;trustservercertificate=true;password=<YourStrong@Passw0rd>"#;
        let conn: JdbcString = input.parse()?;

        // Property order in the rendered string is not tied to the input order,
        // so compare the sorted `;`-separated segments of the original input
        // against those of the rendered output. (This previously compared the
        // output with itself and could never fail.)
        let segments = |s: &str| -> Vec<String> {
            let mut parts: Vec<String> = s.split(';').map(str::to_owned).collect();
            parts.sort();
            parts
        };

        let output = format!("{}", conn);
        assert_eq!(segments(&output), segments(input));
        Ok(())
    }
}