1#![allow(clippy::doc_overindented_list_items)]
4
5#[cfg(feature = "runtime")]
6use crate::connect::connect;
7use crate::connect_raw::connect_raw;
8#[cfg(not(target_arch = "wasm32"))]
9use crate::keepalive::KeepaliveConfig;
10#[cfg(feature = "runtime")]
11use crate::tls::MakeTlsConnect;
12use crate::tls::TlsConnect;
13#[cfg(feature = "runtime")]
14use crate::Socket;
15use crate::{Client, Connection, Error};
16use std::borrow::Cow;
17#[cfg(unix)]
18use std::ffi::OsStr;
19use std::net::IpAddr;
20use std::ops::Deref;
21#[cfg(unix)]
22use std::os::unix::ffi::OsStrExt;
23#[cfg(unix)]
24use std::path::Path;
25#[cfg(unix)]
26use std::path::PathBuf;
27use std::str;
28use std::str::FromStr;
29use std::time::Duration;
30use std::{error, fmt, iter, mem};
31use tokio::io::{AsyncRead, AsyncWrite};
32
/// Properties a session must have, selected by the `target_session_attrs` option.
///
/// `Config::new` starts with `Any`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// Connection-string value `any` (the default).
    Any,
    /// Connection-string value `read-write`.
    ReadWrite,
    /// Connection-string value `read-only`.
    ReadOnly,
}
44
/// TLS policy, selected by the `sslmode` option.
///
/// `Config::new` starts with `Prefer`. Serde support is available behind the
/// `serde` feature.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum SslMode {
    /// Connection-string value `disable`.
    Disable,
    /// Connection-string value `prefer` (the default).
    Prefer,
    /// Connection-string value `require`.
    Require,
    /// Connection-string value `verify-ca`.
    VerifyCa,
    /// Connection-string value `verify-full`.
    VerifyFull,
}
61
/// How TLS is negotiated, selected by the `sslnegotiation` option.
///
/// Both `Default` and `Config::new` use `Postgres`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum SslNegotiation {
    /// Connection-string value `postgres` (the default).
    #[default]
    Postgres,
    /// Connection-string value `direct`.
    Direct,
}
75
/// Channel binding policy, selected by the `channel_binding` option.
///
/// `Config::new` starts with `Prefer`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Connection-string value `disable`.
    Disable,
    /// Connection-string value `prefer` (the default).
    Prefer,
    /// Connection-string value `require`.
    Require,
}
87
/// Host selection policy, selected by the `load_balance_hosts` option.
///
/// `Config::new` starts with `Disable`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum LoadBalanceHosts {
    /// Connection-string value `disable` (the default).
    Disable,
    /// Connection-string value `random`.
    Random,
}
97
/// Replication mode requested through the `replication` option.
///
/// `Config` stores this as an `Option`; `None` (the default) means no replication
/// mode was requested.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ReplicationMode {
    /// Selected by `replication=true`.
    Physical,
    /// Selected by `replication=database`.
    Logical,
}
112
/// One entry of the configured host list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Host {
    /// A host given as a string (anything `Config::host` receives that is not a
    /// `/`-prefixed path on Unix).
    Tcp(String),
    /// A filesystem path, added by `Config::host_path` or by a `/`-prefixed
    /// `Config::host` value. Presumably a Unix-domain socket location — only
    /// compiled on Unix.
    #[cfg(unix)]
    Unix(PathBuf),
}
124
/// Connection configuration.
///
/// Build one with the setter methods, or parse it via `FromStr` from either a
/// `postgres://` / `postgresql://` URL or a whitespace-separated `key=value` string.
#[derive(Clone, PartialEq, Eq)]
pub struct Config {
    pub(crate) user: Option<String>,
    /// Raw password bytes; not required to be UTF-8.
    pub(crate) password: Option<Vec<u8>>,
    pub(crate) dbname: Option<String>,
    pub(crate) options: Option<String>,
    pub(crate) application_name: Option<String>,
    /// Client certificate contents (the bytes themselves, not a path).
    pub(crate) ssl_cert: Option<Vec<u8>>,
    /// Client private key contents (the bytes themselves, not a path).
    pub(crate) ssl_key: Option<Vec<u8>>,
    pub(crate) ssl_mode: SslMode,
    /// Root certificate contents (the bytes themselves, not a path).
    pub(crate) ssl_root_cert: Option<Vec<u8>>,
    pub(crate) ssl_negotiation: SslNegotiation,
    /// Hosts, in the order they were added.
    pub(crate) host: Vec<Host>,
    /// IP addresses, in the order they were added.
    pub(crate) hostaddr: Vec<IpAddr>,
    /// Ports, in the order they were added.
    pub(crate) port: Vec<u16>,
    pub(crate) connect_timeout: Option<Duration>,
    pub(crate) tcp_user_timeout: Option<Duration>,
    pub(crate) keepalives: bool,
    /// Keepalive idle/interval/retry settings; only present on non-wasm32 targets.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) keepalive_config: KeepaliveConfig,
    pub(crate) target_session_attrs: TargetSessionAttrs,
    pub(crate) channel_binding: ChannelBinding,
    pub(crate) load_balance_hosts: LoadBalanceHosts,
    /// `None` unless a replication mode was requested.
    pub(crate) replication_mode: Option<ReplicationMode>,
}
272
273impl Default for Config {
274 fn default() -> Config {
275 Config::new()
276 }
277}
278
279impl Config {
280 pub fn new() -> Config {
282 Config {
283 user: None,
284 password: None,
285 dbname: None,
286 options: None,
287 application_name: None,
288 ssl_cert: None,
289 ssl_key: None,
290 ssl_mode: SslMode::Prefer,
291 ssl_root_cert: None,
292 ssl_negotiation: SslNegotiation::Postgres,
293 host: vec![],
294 hostaddr: vec![],
295 port: vec![],
296 connect_timeout: None,
297 tcp_user_timeout: None,
298 keepalives: true,
299 #[cfg(not(target_arch = "wasm32"))]
300 keepalive_config: KeepaliveConfig {
301 idle: Duration::from_secs(2 * 60 * 60),
302 interval: None,
303 retries: None,
304 },
305 target_session_attrs: TargetSessionAttrs::Any,
306 channel_binding: ChannelBinding::Prefer,
307 load_balance_hosts: LoadBalanceHosts::Disable,
308 replication_mode: None,
309 }
310 }
311
312 pub fn user(&mut self, user: impl Into<String>) -> &mut Config {
316 self.user = Some(user.into());
317 self
318 }
319
320 pub fn get_user(&self) -> Option<&str> {
323 self.user.as_deref()
324 }
325
326 pub fn password<T>(&mut self, password: T) -> &mut Config
328 where
329 T: AsRef<[u8]>,
330 {
331 self.password = Some(password.as_ref().to_vec());
332 self
333 }
334
335 pub fn get_password(&self) -> Option<&[u8]> {
338 self.password.as_deref()
339 }
340
341 pub fn dbname(&mut self, dbname: impl Into<String>) -> &mut Config {
345 self.dbname = Some(dbname.into());
346 self
347 }
348
349 pub fn get_dbname(&self) -> Option<&str> {
352 self.dbname.as_deref()
353 }
354
355 pub fn options(&mut self, options: impl Into<String>) -> &mut Config {
357 self.options = Some(options.into());
358 self
359 }
360
361 pub fn get_options(&self) -> Option<&str> {
364 self.options.as_deref()
365 }
366
367 pub fn application_name(&mut self, application_name: impl Into<String>) -> &mut Config {
369 self.application_name = Some(application_name.into());
370 self
371 }
372
373 pub fn get_application_name(&self) -> Option<&str> {
376 self.application_name.as_deref()
377 }
378
379 pub fn ssl_cert(&mut self, ssl_cert: &[u8]) -> &mut Config {
383 self.ssl_cert = Some(ssl_cert.into());
384 self
385 }
386
387 pub fn get_ssl_cert(&self) -> Option<&[u8]> {
389 self.ssl_cert.as_deref()
390 }
391
392 pub fn ssl_key(&mut self, ssl_key: &[u8]) -> &mut Config {
396 self.ssl_key = Some(ssl_key.into());
397 self
398 }
399
400 pub fn get_ssl_key(&self) -> Option<&[u8]> {
402 self.ssl_key.as_deref()
403 }
404
405 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
409 self.ssl_mode = ssl_mode;
410 self
411 }
412
413 pub fn get_ssl_mode(&self) -> SslMode {
415 self.ssl_mode
416 }
417
418 pub fn ssl_root_cert(&mut self, ssl_root_cert: &[u8]) -> &mut Config {
422 self.ssl_root_cert = Some(ssl_root_cert.into());
423 self
424 }
425
426 pub fn get_ssl_root_cert(&self) -> Option<&[u8]> {
428 self.ssl_root_cert.as_deref()
429 }
430
431 pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
435 self.ssl_negotiation = ssl_negotiation;
436 self
437 }
438
439 pub fn get_ssl_negotiation(&self) -> SslNegotiation {
441 self.ssl_negotiation
442 }
443
444 pub fn host(&mut self, host: impl Into<String>) -> &mut Config {
450 let host = host.into();
451
452 #[cfg(unix)]
453 {
454 if host.starts_with('/') {
455 return self.host_path(host);
456 }
457 }
458
459 self.host.push(Host::Tcp(host));
460 self
461 }
462
463 pub fn get_hosts(&self) -> &[Host] {
465 &self.host
466 }
467
468 pub fn get_hostaddrs(&self) -> &[IpAddr] {
470 self.hostaddr.deref()
471 }
472
473 #[cfg(unix)]
477 pub fn host_path<T>(&mut self, host: T) -> &mut Config
478 where
479 T: AsRef<Path>,
480 {
481 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
482 self
483 }
484
485 pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
490 self.hostaddr.push(hostaddr);
491 self
492 }
493
494 pub fn port(&mut self, port: u16) -> &mut Config {
500 self.port.push(port);
501 self
502 }
503
504 pub fn get_ports(&self) -> &[u16] {
506 &self.port
507 }
508
509 pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
514 self.connect_timeout = Some(connect_timeout);
515 self
516 }
517
518 pub fn get_connect_timeout(&self) -> Option<&Duration> {
521 self.connect_timeout.as_ref()
522 }
523
524 pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
530 self.tcp_user_timeout = Some(tcp_user_timeout);
531 self
532 }
533
534 pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
537 self.tcp_user_timeout.as_ref()
538 }
539
540 pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
544 self.keepalives = keepalives;
545 self
546 }
547
548 pub fn get_keepalives(&self) -> bool {
550 self.keepalives
551 }
552
553 #[cfg(not(target_arch = "wasm32"))]
557 pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
558 self.keepalive_config.idle = keepalives_idle;
559 self
560 }
561
562 #[cfg(not(target_arch = "wasm32"))]
565 pub fn get_keepalives_idle(&self) -> Duration {
566 self.keepalive_config.idle
567 }
568
569 #[cfg(not(target_arch = "wasm32"))]
574 pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
575 self.keepalive_config.interval = Some(keepalives_interval);
576 self
577 }
578
579 #[cfg(not(target_arch = "wasm32"))]
581 pub fn get_keepalives_interval(&self) -> Option<Duration> {
582 self.keepalive_config.interval
583 }
584
585 #[cfg(not(target_arch = "wasm32"))]
589 pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
590 self.keepalive_config.retries = Some(keepalives_retries);
591 self
592 }
593
594 #[cfg(not(target_arch = "wasm32"))]
596 pub fn get_keepalives_retries(&self) -> Option<u32> {
597 self.keepalive_config.retries
598 }
599
600 pub fn target_session_attrs(
605 &mut self,
606 target_session_attrs: TargetSessionAttrs,
607 ) -> &mut Config {
608 self.target_session_attrs = target_session_attrs;
609 self
610 }
611
612 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
614 self.target_session_attrs
615 }
616
617 pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
621 self.channel_binding = channel_binding;
622 self
623 }
624
625 pub fn get_channel_binding(&self) -> ChannelBinding {
627 self.channel_binding
628 }
629
630 pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
634 self.load_balance_hosts = load_balance_hosts;
635 self
636 }
637
638 pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
640 self.load_balance_hosts
641 }
642
643 pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
650 self.replication_mode = Some(replication_mode);
651 self
652 }
653
654 pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
656 self.replication_mode
657 }
658
659 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
660 match key {
661 "user" => {
662 self.user(value);
663 }
664 "password" => {
665 self.password(value);
666 }
667 "dbname" => {
668 self.dbname(value);
669 }
670 "options" => {
671 self.options(value);
672 }
673 "application_name" => {
674 self.application_name(value);
675 }
676 "sslcert" => match std::fs::read(value) {
677 Ok(contents) => {
678 self.ssl_cert(&contents);
679 }
680 Err(_) => {
681 return Err(Error::config_parse(Box::new(InvalidValue("sslcert"))));
682 }
683 },
684 "sslcert_inline" => {
685 self.ssl_cert(value.as_bytes());
686 }
687 "sslkey" => match std::fs::read(value) {
688 Ok(contents) => {
689 self.ssl_key(&contents);
690 }
691 Err(_) => {
692 return Err(Error::config_parse(Box::new(InvalidValue("sslkey"))));
693 }
694 },
695 "sslkey_inline" => {
696 self.ssl_key(value.as_bytes());
697 }
698 "sslmode" => {
699 let mode = match value {
700 "disable" => SslMode::Disable,
701 "prefer" => SslMode::Prefer,
702 "require" => SslMode::Require,
703 "verify-ca" => SslMode::VerifyCa,
704 "verify-full" => SslMode::VerifyFull,
705 _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
706 };
707 self.ssl_mode(mode);
708 }
709 "sslrootcert" => match std::fs::read(value) {
710 Ok(contents) => {
711 self.ssl_root_cert(&contents);
712 }
713 Err(_) => {
714 return Err(Error::config_parse(Box::new(InvalidValue("sslrootcert"))));
715 }
716 },
717 "sslrootcert_inline" => {
718 self.ssl_root_cert(value.as_bytes());
719 }
720 "sslnegotiation" => {
721 let mode = match value {
722 "postgres" => SslNegotiation::Postgres,
723 "direct" => SslNegotiation::Direct,
724 _ => {
725 return Err(Error::config_parse(Box::new(InvalidValue(
726 "sslnegotiation",
727 ))))
728 }
729 };
730 self.ssl_negotiation(mode);
731 }
732 "host" => {
733 for host in value.split(',') {
734 self.host(host);
735 }
736 }
737 "hostaddr" => {
738 for hostaddr in value.split(',') {
739 let addr = hostaddr
740 .parse()
741 .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?;
742 self.hostaddr(addr);
743 }
744 }
745 "port" => {
746 for port in value.split(',') {
747 let port = if port.is_empty() {
748 5432
749 } else {
750 port.parse()
751 .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
752 };
753 self.port(port);
754 }
755 }
756 "connect_timeout" => {
757 let timeout = value
758 .parse::<i64>()
759 .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
760 if timeout > 0 {
761 self.connect_timeout(Duration::from_secs(timeout as u64));
762 }
763 }
764 "tcp_user_timeout" => {
765 let timeout = value
766 .parse::<i64>()
767 .map_err(|_| Error::config_parse(Box::new(InvalidValue("tcp_user_timeout"))))?;
768 if timeout > 0 {
769 self.tcp_user_timeout(Duration::from_secs(timeout as u64));
770 }
771 }
772 #[cfg(not(target_arch = "wasm32"))]
773 "keepalives" => {
774 let keepalives = value
775 .parse::<u64>()
776 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
777 self.keepalives(keepalives != 0);
778 }
779 #[cfg(not(target_arch = "wasm32"))]
780 "keepalives_idle" => {
781 let keepalives_idle = value
782 .parse::<i64>()
783 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
784 if keepalives_idle > 0 {
785 self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
786 }
787 }
788 #[cfg(not(target_arch = "wasm32"))]
789 "keepalives_interval" => {
790 let keepalives_interval = value.parse::<i64>().map_err(|_| {
791 Error::config_parse(Box::new(InvalidValue("keepalives_interval")))
792 })?;
793 if keepalives_interval > 0 {
794 self.keepalives_interval(Duration::from_secs(keepalives_interval as u64));
795 }
796 }
797 #[cfg(not(target_arch = "wasm32"))]
798 "keepalives_retries" => {
799 let keepalives_retries = value.parse::<u32>().map_err(|_| {
800 Error::config_parse(Box::new(InvalidValue("keepalives_retries")))
801 })?;
802 self.keepalives_retries(keepalives_retries);
803 }
804 "target_session_attrs" => {
805 let target_session_attrs = match value {
806 "any" => TargetSessionAttrs::Any,
807 "read-write" => TargetSessionAttrs::ReadWrite,
808 "read-only" => TargetSessionAttrs::ReadOnly,
809 _ => {
810 return Err(Error::config_parse(Box::new(InvalidValue(
811 "target_session_attrs",
812 ))));
813 }
814 };
815 self.target_session_attrs(target_session_attrs);
816 }
817 "channel_binding" => {
818 let channel_binding = match value {
819 "disable" => ChannelBinding::Disable,
820 "prefer" => ChannelBinding::Prefer,
821 "require" => ChannelBinding::Require,
822 _ => {
823 return Err(Error::config_parse(Box::new(InvalidValue(
824 "channel_binding",
825 ))))
826 }
827 };
828 self.channel_binding(channel_binding);
829 }
830 "load_balance_hosts" => {
831 let load_balance_hosts = match value {
832 "disable" => LoadBalanceHosts::Disable,
833 "random" => LoadBalanceHosts::Random,
834 _ => {
835 return Err(Error::config_parse(Box::new(InvalidValue(
836 "load_balance_hosts",
837 ))))
838 }
839 };
840 self.load_balance_hosts(load_balance_hosts);
841 }
842 "replication" => {
843 let mode = match value {
844 "off" => None,
845 "true" => Some(ReplicationMode::Physical),
846 "database" => Some(ReplicationMode::Logical),
847 _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))),
848 };
849 if let Some(mode) = mode {
850 self.replication_mode(mode);
851 }
852 }
853 key => {
854 return Err(Error::config_parse(Box::new(UnknownOption(
855 key.to_string(),
856 ))));
857 }
858 }
859
860 Ok(())
861 }
862
863 #[cfg(feature = "runtime")]
867 pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
868 where
869 T: MakeTlsConnect<Socket>,
870 {
871 connect(tls, self).await
872 }
873
874 pub async fn connect_raw<S, T>(
878 &self,
879 stream: S,
880 tls: T,
881 ) -> Result<(Client, Connection<S, T::Stream>), Error>
882 where
883 S: AsyncRead + AsyncWrite + Unpin,
884 T: TlsConnect<S>,
885 {
886 connect_raw(stream, tls, true, self).await
887 }
888}
889
890impl FromStr for Config {
891 type Err = Error;
892
893 fn from_str(s: &str) -> Result<Config, Error> {
894 match UrlParser::parse(s)? {
895 Some(config) => Ok(config),
896 None => Parser::parse(s),
897 }
898 }
899}
900
901impl fmt::Debug for Config {
903 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
904 struct Redaction {}
905 impl fmt::Debug for Redaction {
906 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
907 write!(f, "_")
908 }
909 }
910
911 let mut config_dbg = &mut f.debug_struct("Config");
912 config_dbg = config_dbg
913 .field("user", &self.user)
914 .field("password", &self.password.as_ref().map(|_| Redaction {}))
915 .field("dbname", &self.dbname)
916 .field("options", &self.options)
917 .field("application_name", &self.application_name)
918 .field("ssl_cert", &self.ssl_cert)
919 .field("ssl_key", &self.ssl_key)
920 .field("ssl_mode", &self.ssl_mode)
921 .field("ssl_root_cert", &self.ssl_root_cert)
922 .field("host", &self.host)
923 .field("hostaddr", &self.hostaddr)
924 .field("port", &self.port)
925 .field("connect_timeout", &self.connect_timeout)
926 .field("tcp_user_timeout", &self.tcp_user_timeout)
927 .field("keepalives", &self.keepalives);
928
929 #[cfg(not(target_arch = "wasm32"))]
930 {
931 config_dbg = config_dbg
932 .field("keepalives_idle", &self.keepalive_config.idle)
933 .field("keepalives_interval", &self.keepalive_config.interval)
934 .field("keepalives_retries", &self.keepalive_config.retries);
935 }
936
937 config_dbg
938 .field("target_session_attrs", &self.target_session_attrs)
939 .field("channel_binding", &self.channel_binding)
940 .field("replication", &self.replication_mode)
941 .field("load_balance_hosts", &self.load_balance_hosts)
942 .finish()
943 }
944}
945
/// Error for a connection-string option name that is not recognized; carries the
/// offending name.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownOption(name) = self;
        write!(f, "unknown option `{}`", name)
    }
}

impl error::Error for UnknownOption {}
956
/// Error for a recognized option whose value could not be accepted; carries the
/// option name.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = *self;
        write!(f, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
967
/// Cursor over a whitespace-separated `key=value` connection string.
struct Parser<'a> {
    /// The whole input; keyword slices are borrowed from it.
    s: &'a str,
    /// Peekable `(byte offset, char)` iterator over `s`; the offsets are used for
    /// slicing and error messages.
    it: iter::Peekable<str::CharIndices<'a>>,
}
972
973impl<'a> Parser<'a> {
974 fn parse(s: &'a str) -> Result<Config, Error> {
975 let mut parser = Parser {
976 s,
977 it: s.char_indices().peekable(),
978 };
979
980 let mut config = Config::new();
981
982 while let Some((key, value)) = parser.parameter()? {
983 config.param(key, &value)?;
984 }
985
986 Ok(config)
987 }
988
989 fn skip_ws(&mut self) {
990 self.take_while(char::is_whitespace);
991 }
992
993 fn take_while<F>(&mut self, f: F) -> &'a str
994 where
995 F: Fn(char) -> bool,
996 {
997 let start = match self.it.peek() {
998 Some(&(i, _)) => i,
999 None => return "",
1000 };
1001
1002 loop {
1003 match self.it.peek() {
1004 Some(&(_, c)) if f(c) => {
1005 self.it.next();
1006 }
1007 Some(&(i, _)) => return &self.s[start..i],
1008 None => return &self.s[start..],
1009 }
1010 }
1011 }
1012
1013 fn eat(&mut self, target: char) -> Result<(), Error> {
1014 match self.it.next() {
1015 Some((_, c)) if c == target => Ok(()),
1016 Some((i, c)) => {
1017 let m =
1018 format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
1019 Err(Error::config_parse(m.into()))
1020 }
1021 None => Err(Error::config_parse("unexpected EOF".into())),
1022 }
1023 }
1024
1025 fn eat_if(&mut self, target: char) -> bool {
1026 match self.it.peek() {
1027 Some(&(_, c)) if c == target => {
1028 self.it.next();
1029 true
1030 }
1031 _ => false,
1032 }
1033 }
1034
1035 fn keyword(&mut self) -> Option<&'a str> {
1036 let s = self.take_while(|c| match c {
1037 c if c.is_whitespace() => false,
1038 '=' => false,
1039 _ => true,
1040 });
1041
1042 if s.is_empty() {
1043 None
1044 } else {
1045 Some(s)
1046 }
1047 }
1048
1049 fn value(&mut self) -> Result<String, Error> {
1050 let value = if self.eat_if('\'') {
1051 let value = self.quoted_value()?;
1052 self.eat('\'')?;
1053 value
1054 } else {
1055 self.simple_value()?
1056 };
1057
1058 Ok(value)
1059 }
1060
1061 fn simple_value(&mut self) -> Result<String, Error> {
1062 let mut value = String::new();
1063
1064 while let Some(&(_, c)) = self.it.peek() {
1065 if c.is_whitespace() {
1066 break;
1067 }
1068
1069 self.it.next();
1070 if c == '\\' {
1071 if let Some((_, c2)) = self.it.next() {
1072 value.push(c2);
1073 }
1074 } else {
1075 value.push(c);
1076 }
1077 }
1078
1079 if value.is_empty() {
1080 return Err(Error::config_parse("unexpected EOF".into()));
1081 }
1082
1083 Ok(value)
1084 }
1085
1086 fn quoted_value(&mut self) -> Result<String, Error> {
1087 let mut value = String::new();
1088
1089 while let Some(&(_, c)) = self.it.peek() {
1090 if c == '\'' {
1091 return Ok(value);
1092 }
1093
1094 self.it.next();
1095 if c == '\\' {
1096 if let Some((_, c2)) = self.it.next() {
1097 value.push(c2);
1098 }
1099 } else {
1100 value.push(c);
1101 }
1102 }
1103
1104 Err(Error::config_parse(
1105 "unterminated quoted connection parameter value".into(),
1106 ))
1107 }
1108
1109 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
1110 self.skip_ws();
1111 let keyword = match self.keyword() {
1112 Some(keyword) => keyword,
1113 None => return Ok(None),
1114 };
1115 self.skip_ws();
1116 self.eat('=')?;
1117 self.skip_ws();
1118 let value = self.value()?;
1119
1120 Ok(Some((keyword, value)))
1121 }
1122}
1123
/// Incremental parser for `postgres://` / `postgresql://` connection URLs.
struct UrlParser<'a> {
    /// The not-yet-consumed remainder of the URL (scheme prefix already removed).
    s: &'a str,
    /// Configuration accumulated from the components parsed so far.
    config: Config,
}
1129
1130impl<'a> UrlParser<'a> {
1131 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
1132 let s = match Self::remove_url_prefix(s) {
1133 Some(s) => s,
1134 None => return Ok(None),
1135 };
1136
1137 let mut parser = UrlParser {
1138 s,
1139 config: Config::new(),
1140 };
1141
1142 parser.parse_credentials()?;
1143 parser.parse_host()?;
1144 parser.parse_path()?;
1145 parser.parse_params()?;
1146
1147 Ok(Some(parser.config))
1148 }
1149
1150 fn remove_url_prefix(s: &str) -> Option<&str> {
1151 for prefix in &["postgres://", "postgresql://"] {
1152 if let Some(stripped) = s.strip_prefix(prefix) {
1153 return Some(stripped);
1154 }
1155 }
1156
1157 None
1158 }
1159
1160 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
1161 match self.s.find(end) {
1162 Some(pos) => {
1163 let (head, tail) = self.s.split_at(pos);
1164 self.s = tail;
1165 Some(head)
1166 }
1167 None => None,
1168 }
1169 }
1170
1171 fn take_all(&mut self) -> &'a str {
1172 mem::take(&mut self.s)
1173 }
1174
1175 fn eat_byte(&mut self) {
1176 self.s = &self.s[1..];
1177 }
1178
1179 fn parse_credentials(&mut self) -> Result<(), Error> {
1180 let creds = match self.take_until(&['@']) {
1181 Some(creds) => creds,
1182 None => return Ok(()),
1183 };
1184 self.eat_byte();
1185
1186 let mut it = creds.splitn(2, ':');
1187 let user = self.decode(it.next().unwrap())?;
1188 self.config.user(user);
1189
1190 if let Some(password) = it.next() {
1191 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
1192 self.config.password(password);
1193 }
1194
1195 Ok(())
1196 }
1197
1198 fn parse_host(&mut self) -> Result<(), Error> {
1199 let host = match self.take_until(&['/', '?']) {
1200 Some(host) => host,
1201 None => self.take_all(),
1202 };
1203
1204 if host.is_empty() {
1205 return Ok(());
1206 }
1207
1208 for chunk in host.split(',') {
1209 let (host, port) = if chunk.starts_with('[') {
1210 let idx = match chunk.find(']') {
1211 Some(idx) => idx,
1212 None => return Err(Error::config_parse(InvalidValue("host").into())),
1213 };
1214
1215 let host = &chunk[1..idx];
1216 let remaining = &chunk[idx + 1..];
1217 let port = if let Some(port) = remaining.strip_prefix(':') {
1218 Some(port)
1219 } else if remaining.is_empty() {
1220 None
1221 } else {
1222 return Err(Error::config_parse(InvalidValue("host").into()));
1223 };
1224
1225 (host, port)
1226 } else {
1227 let mut it = chunk.splitn(2, ':');
1228 (it.next().unwrap(), it.next())
1229 };
1230
1231 self.host_param(host)?;
1232 let port = self.decode(port.unwrap_or("5432"))?;
1233 self.config.param("port", &port)?;
1234 }
1235
1236 Ok(())
1237 }
1238
1239 fn parse_path(&mut self) -> Result<(), Error> {
1240 if !self.s.starts_with('/') {
1241 return Ok(());
1242 }
1243 self.eat_byte();
1244
1245 let dbname = match self.take_until(&['?']) {
1246 Some(dbname) => dbname,
1247 None => self.take_all(),
1248 };
1249
1250 if !dbname.is_empty() {
1251 self.config.dbname(self.decode(dbname)?);
1252 }
1253
1254 Ok(())
1255 }
1256
1257 fn parse_params(&mut self) -> Result<(), Error> {
1258 if !self.s.starts_with('?') {
1259 return Ok(());
1260 }
1261 self.eat_byte();
1262
1263 while !self.s.is_empty() {
1264 let key = match self.take_until(&['=']) {
1265 Some(key) => self.decode(key)?,
1266 None => return Err(Error::config_parse("unterminated parameter".into())),
1267 };
1268 self.eat_byte();
1269
1270 let value = match self.take_until(&['&']) {
1271 Some(value) => {
1272 self.eat_byte();
1273 value
1274 }
1275 None => self.take_all(),
1276 };
1277
1278 if key == "host" {
1279 self.host_param(value)?;
1280 } else {
1281 let value = self.decode(value)?;
1282 self.config.param(&key, &value)?;
1283 }
1284 }
1285
1286 Ok(())
1287 }
1288
1289 #[cfg(unix)]
1290 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1291 let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
1292 if decoded.first() == Some(&b'/') {
1293 self.config.host_path(OsStr::from_bytes(&decoded));
1294 } else {
1295 let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
1296 self.config.host(decoded);
1297 }
1298
1299 Ok(())
1300 }
1301
1302 #[cfg(not(unix))]
1303 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1304 let s = self.decode(s)?;
1305 self.config.param("host", &s)
1306 }
1307
1308 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
1309 percent_encoding::percent_decode(s.as_bytes())
1310 .decode_utf8()
1311 .map_err(|e| Error::config_parse(e.into()))
1312 }
1313}
1314
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use crate::{config::Host, Config};

    /// Key/value parsing of user, dbname, host and hostaddr lists, and the port.
    #[test]
    fn test_simple_parsing() {
        let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );

        assert_eq!(
            [
                "127.0.0.1".parse::<IpAddr>().unwrap(),
                "127.0.0.2".parse::<IpAddr>().unwrap()
            ],
            config.get_hostaddrs(),
        );

        assert_eq!([26257], config.get_ports());
    }

    /// A malformed `hostaddr` must be rejected rather than ignored.
    #[test]
    fn test_invalid_hostaddr_parsing() {
        let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257";
        assert!(s.parse::<Config>().is_err());
    }

    /// URL parsing of credentials, host:port, database path and query options.
    #[test]
    fn test_url_parsing() {
        let config = "postgresql://pass_user:secret@host1:5433/postgres?application_name=app"
            .parse::<Config>()
            .unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some(&b"secret"[..]), config.get_password());
        assert_eq!([Host::Tcp("host1".to_string())], config.get_hosts());
        assert_eq!([5433], config.get_ports());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(Some("app"), config.get_application_name());
    }
}
1351}