1#[cfg(feature = "runtime")]
4use crate::connect::connect;
5use crate::connect_raw::connect_raw;
6#[cfg(not(target_arch = "wasm32"))]
7use crate::keepalive::KeepaliveConfig;
8#[cfg(feature = "runtime")]
9use crate::tls::MakeTlsConnect;
10use crate::tls::TlsConnect;
11#[cfg(feature = "runtime")]
12use crate::Socket;
13use crate::{Client, Connection, Error};
14use std::borrow::Cow;
15#[cfg(unix)]
16use std::ffi::OsStr;
17use std::net::IpAddr;
18use std::ops::Deref;
19#[cfg(unix)]
20use std::os::unix::ffi::OsStrExt;
21#[cfg(unix)]
22use std::path::Path;
23#[cfg(unix)]
24use std::path::PathBuf;
25use std::str;
26use std::str::FromStr;
27use std::time::Duration;
28use std::{error, fmt, iter, mem};
29use tokio::io::{AsyncRead, AsyncWrite};
30
/// Session attributes requested for a connection.
///
/// Chosen with the `target_session_attrs` connection-string option; a fresh
/// `Config` uses [`TargetSessionAttrs::Any`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// No particular session attributes are requested (`target_session_attrs=any`).
    Any,
    /// A read-write session is requested (`target_session_attrs=read-write`).
    ReadWrite,
    /// A read-only session is requested (`target_session_attrs=read-only`).
    ReadOnly,
}
42
/// TLS negotiation mode.
///
/// Chosen with the `sslmode` connection-string option; a fresh `Config`
/// uses [`SslMode::Prefer`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum SslMode {
    /// Selected by `sslmode=disable`.
    Disable,
    /// Selected by `sslmode=prefer` (the default).
    Prefer,
    /// Selected by `sslmode=require`.
    Require,
    /// Selected by `sslmode=verify-ca`.
    VerifyCa,
    /// Selected by `sslmode=verify-full`.
    VerifyFull,
}
59
/// Channel binding setting.
///
/// Chosen with the `channel_binding` connection-string option; a fresh
/// `Config` uses [`ChannelBinding::Prefer`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Selected by `channel_binding=disable`.
    Disable,
    /// Selected by `channel_binding=prefer` (the default).
    Prefer,
    /// Selected by `channel_binding=require`.
    Require,
}
71
/// Host load-balancing setting.
///
/// Chosen with the `load_balance_hosts` connection-string option; a fresh
/// `Config` uses [`LoadBalanceHosts::Disable`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum LoadBalanceHosts {
    /// Selected by `load_balance_hosts=disable` (the default).
    Disable,
    /// Selected by `load_balance_hosts=random`.
    Random,
}
81
/// Replication mode requested for the connection.
///
/// Chosen with the `replication` connection-string option; a fresh `Config`
/// has no replication mode set.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum ReplicationMode {
    /// Selected by `replication=true`.
    Physical,
    /// Selected by `replication=database`.
    Logical,
}
96
/// A single entry in a configuration's host list.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Host {
    /// A host name or address stored as text.
    Tcp(String),
    /// A filesystem path; on Unix, host strings beginning with `/` are stored
    /// this way (presumably locating a Unix-domain socket).
    #[cfg(unix)]
    Unix(PathBuf),
}
108
109#[derive(Clone, PartialEq, Eq)]
223pub struct Config {
224 pub(crate) user: Option<String>,
225 pub(crate) password: Option<Vec<u8>>,
226 pub(crate) dbname: Option<String>,
227 pub(crate) options: Option<String>,
228 pub(crate) application_name: Option<String>,
229 pub(crate) ssl_cert: Option<Vec<u8>>,
230 pub(crate) ssl_key: Option<Vec<u8>>,
231 pub(crate) ssl_mode: SslMode,
232 pub(crate) ssl_root_cert: Option<Vec<u8>>,
233 pub(crate) host: Vec<Host>,
234 pub(crate) hostaddr: Vec<IpAddr>,
235 pub(crate) port: Vec<u16>,
236 pub(crate) connect_timeout: Option<Duration>,
237 pub(crate) tcp_user_timeout: Option<Duration>,
238 pub(crate) keepalives: bool,
239 #[cfg(not(target_arch = "wasm32"))]
240 pub(crate) keepalive_config: KeepaliveConfig,
241 pub(crate) target_session_attrs: TargetSessionAttrs,
242 pub(crate) channel_binding: ChannelBinding,
243 pub(crate) load_balance_hosts: LoadBalanceHosts,
244 pub(crate) replication_mode: Option<ReplicationMode>,
245}
246
247impl Default for Config {
248 fn default() -> Config {
249 Config::new()
250 }
251}
252
253impl Config {
254 pub fn new() -> Config {
256 Config {
257 user: None,
258 password: None,
259 dbname: None,
260 options: None,
261 application_name: None,
262 ssl_cert: None,
263 ssl_key: None,
264 ssl_mode: SslMode::Prefer,
265 ssl_root_cert: None,
266 host: vec![],
267 hostaddr: vec![],
268 port: vec![],
269 connect_timeout: None,
270 tcp_user_timeout: None,
271 keepalives: true,
272 #[cfg(not(target_arch = "wasm32"))]
273 keepalive_config: KeepaliveConfig {
274 idle: Duration::from_secs(2 * 60 * 60),
275 interval: None,
276 retries: None,
277 },
278 target_session_attrs: TargetSessionAttrs::Any,
279 channel_binding: ChannelBinding::Prefer,
280 load_balance_hosts: LoadBalanceHosts::Disable,
281 replication_mode: None,
282 }
283 }
284
285 pub fn user(&mut self, user: impl Into<String>) -> &mut Config {
289 self.user = Some(user.into());
290 self
291 }
292
293 pub fn get_user(&self) -> Option<&str> {
296 self.user.as_deref()
297 }
298
299 pub fn password<T>(&mut self, password: T) -> &mut Config
301 where
302 T: AsRef<[u8]>,
303 {
304 self.password = Some(password.as_ref().to_vec());
305 self
306 }
307
308 pub fn get_password(&self) -> Option<&[u8]> {
311 self.password.as_deref()
312 }
313
314 pub fn dbname(&mut self, dbname: impl Into<String>) -> &mut Config {
318 self.dbname = Some(dbname.into());
319 self
320 }
321
322 pub fn get_dbname(&self) -> Option<&str> {
325 self.dbname.as_deref()
326 }
327
328 pub fn options(&mut self, options: impl Into<String>) -> &mut Config {
330 self.options = Some(options.into());
331 self
332 }
333
334 pub fn get_options(&self) -> Option<&str> {
337 self.options.as_deref()
338 }
339
340 pub fn application_name(&mut self, application_name: impl Into<String>) -> &mut Config {
342 self.application_name = Some(application_name.into());
343 self
344 }
345
346 pub fn get_application_name(&self) -> Option<&str> {
349 self.application_name.as_deref()
350 }
351
352 pub fn ssl_cert(&mut self, ssl_cert: &[u8]) -> &mut Config {
356 self.ssl_cert = Some(ssl_cert.into());
357 self
358 }
359
360 pub fn get_ssl_cert(&self) -> Option<&[u8]> {
362 self.ssl_cert.as_deref()
363 }
364
365 pub fn ssl_key(&mut self, ssl_key: &[u8]) -> &mut Config {
369 self.ssl_key = Some(ssl_key.into());
370 self
371 }
372
373 pub fn get_ssl_key(&self) -> Option<&[u8]> {
375 self.ssl_key.as_deref()
376 }
377
378 pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
382 self.ssl_mode = ssl_mode;
383 self
384 }
385
386 pub fn get_ssl_mode(&self) -> SslMode {
388 self.ssl_mode
389 }
390
391 pub fn ssl_root_cert(&mut self, ssl_root_cert: &[u8]) -> &mut Config {
395 self.ssl_root_cert = Some(ssl_root_cert.into());
396 self
397 }
398
399 pub fn get_ssl_root_cert(&self) -> Option<&[u8]> {
401 self.ssl_root_cert.as_deref()
402 }
403
404 pub fn host(&mut self, host: impl Into<String>) -> &mut Config {
410 let host = host.into();
411
412 #[cfg(unix)]
413 {
414 if host.starts_with('/') {
415 return self.host_path(host);
416 }
417 }
418
419 self.host.push(Host::Tcp(host));
420 self
421 }
422
423 pub fn get_hosts(&self) -> &[Host] {
425 &self.host
426 }
427
428 pub fn get_hostaddrs(&self) -> &[IpAddr] {
430 self.hostaddr.deref()
431 }
432
433 #[cfg(unix)]
437 pub fn host_path<T>(&mut self, host: T) -> &mut Config
438 where
439 T: AsRef<Path>,
440 {
441 self.host.push(Host::Unix(host.as_ref().to_path_buf()));
442 self
443 }
444
445 pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
450 self.hostaddr.push(hostaddr);
451 self
452 }
453
454 pub fn port(&mut self, port: u16) -> &mut Config {
460 self.port.push(port);
461 self
462 }
463
464 pub fn get_ports(&self) -> &[u16] {
466 &self.port
467 }
468
469 pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
474 self.connect_timeout = Some(connect_timeout);
475 self
476 }
477
478 pub fn get_connect_timeout(&self) -> Option<&Duration> {
481 self.connect_timeout.as_ref()
482 }
483
484 pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
490 self.tcp_user_timeout = Some(tcp_user_timeout);
491 self
492 }
493
494 pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
497 self.tcp_user_timeout.as_ref()
498 }
499
500 pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
504 self.keepalives = keepalives;
505 self
506 }
507
508 pub fn get_keepalives(&self) -> bool {
510 self.keepalives
511 }
512
513 #[cfg(not(target_arch = "wasm32"))]
517 pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
518 self.keepalive_config.idle = keepalives_idle;
519 self
520 }
521
522 #[cfg(not(target_arch = "wasm32"))]
525 pub fn get_keepalives_idle(&self) -> Duration {
526 self.keepalive_config.idle
527 }
528
529 #[cfg(not(target_arch = "wasm32"))]
534 pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
535 self.keepalive_config.interval = Some(keepalives_interval);
536 self
537 }
538
539 #[cfg(not(target_arch = "wasm32"))]
541 pub fn get_keepalives_interval(&self) -> Option<Duration> {
542 self.keepalive_config.interval
543 }
544
545 #[cfg(not(target_arch = "wasm32"))]
549 pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
550 self.keepalive_config.retries = Some(keepalives_retries);
551 self
552 }
553
554 #[cfg(not(target_arch = "wasm32"))]
556 pub fn get_keepalives_retries(&self) -> Option<u32> {
557 self.keepalive_config.retries
558 }
559
560 pub fn target_session_attrs(
565 &mut self,
566 target_session_attrs: TargetSessionAttrs,
567 ) -> &mut Config {
568 self.target_session_attrs = target_session_attrs;
569 self
570 }
571
572 pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
574 self.target_session_attrs
575 }
576
577 pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
581 self.channel_binding = channel_binding;
582 self
583 }
584
585 pub fn get_channel_binding(&self) -> ChannelBinding {
587 self.channel_binding
588 }
589
590 pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
594 self.load_balance_hosts = load_balance_hosts;
595 self
596 }
597
598 pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
600 self.load_balance_hosts
601 }
602
603 pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config {
610 self.replication_mode = Some(replication_mode);
611 self
612 }
613
614 pub fn get_replication_mode(&self) -> Option<ReplicationMode> {
616 self.replication_mode
617 }
618
619 fn param(&mut self, key: &str, value: &str) -> Result<(), Error> {
620 match key {
621 "user" => {
622 self.user(value);
623 }
624 "password" => {
625 self.password(value);
626 }
627 "dbname" => {
628 self.dbname(value);
629 }
630 "options" => {
631 self.options(value);
632 }
633 "application_name" => {
634 self.application_name(value);
635 }
636 "sslcert" => match std::fs::read(value) {
637 Ok(contents) => {
638 self.ssl_cert(&contents);
639 }
640 Err(_) => {
641 return Err(Error::config_parse(Box::new(InvalidValue("sslcert"))));
642 }
643 },
644 "sslcert_inline" => {
645 self.ssl_cert(value.as_bytes());
646 }
647 "sslkey" => match std::fs::read(value) {
648 Ok(contents) => {
649 self.ssl_key(&contents);
650 }
651 Err(_) => {
652 return Err(Error::config_parse(Box::new(InvalidValue("sslkey"))));
653 }
654 },
655 "sslkey_inline" => {
656 self.ssl_key(value.as_bytes());
657 }
658 "sslmode" => {
659 let mode = match value {
660 "disable" => SslMode::Disable,
661 "prefer" => SslMode::Prefer,
662 "require" => SslMode::Require,
663 "verify-ca" => SslMode::VerifyCa,
664 "verify-full" => SslMode::VerifyFull,
665 _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))),
666 };
667 self.ssl_mode(mode);
668 }
669 "sslrootcert" => match std::fs::read(value) {
670 Ok(contents) => {
671 self.ssl_root_cert(&contents);
672 }
673 Err(_) => {
674 return Err(Error::config_parse(Box::new(InvalidValue("sslrootcert"))));
675 }
676 },
677 "sslrootcert_inline" => {
678 self.ssl_root_cert(value.as_bytes());
679 }
680 "host" => {
681 for host in value.split(',') {
682 self.host(host);
683 }
684 }
685 "hostaddr" => {
686 for hostaddr in value.split(',') {
687 let addr = hostaddr
688 .parse()
689 .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?;
690 self.hostaddr(addr);
691 }
692 }
693 "port" => {
694 for port in value.split(',') {
695 let port = if port.is_empty() {
696 5432
697 } else {
698 port.parse()
699 .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))?
700 };
701 self.port(port);
702 }
703 }
704 "connect_timeout" => {
705 let timeout = value
706 .parse::<i64>()
707 .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?;
708 if timeout > 0 {
709 self.connect_timeout(Duration::from_secs(timeout as u64));
710 }
711 }
712 "tcp_user_timeout" => {
713 let timeout = value
714 .parse::<i64>()
715 .map_err(|_| Error::config_parse(Box::new(InvalidValue("tcp_user_timeout"))))?;
716 if timeout > 0 {
717 self.tcp_user_timeout(Duration::from_secs(timeout as u64));
718 }
719 }
720 #[cfg(not(target_arch = "wasm32"))]
721 "keepalives" => {
722 let keepalives = value
723 .parse::<u64>()
724 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?;
725 self.keepalives(keepalives != 0);
726 }
727 #[cfg(not(target_arch = "wasm32"))]
728 "keepalives_idle" => {
729 let keepalives_idle = value
730 .parse::<i64>()
731 .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?;
732 if keepalives_idle > 0 {
733 self.keepalives_idle(Duration::from_secs(keepalives_idle as u64));
734 }
735 }
736 #[cfg(not(target_arch = "wasm32"))]
737 "keepalives_interval" => {
738 let keepalives_interval = value.parse::<i64>().map_err(|_| {
739 Error::config_parse(Box::new(InvalidValue("keepalives_interval")))
740 })?;
741 if keepalives_interval > 0 {
742 self.keepalives_interval(Duration::from_secs(keepalives_interval as u64));
743 }
744 }
745 #[cfg(not(target_arch = "wasm32"))]
746 "keepalives_retries" => {
747 let keepalives_retries = value.parse::<u32>().map_err(|_| {
748 Error::config_parse(Box::new(InvalidValue("keepalives_retries")))
749 })?;
750 self.keepalives_retries(keepalives_retries);
751 }
752 "target_session_attrs" => {
753 let target_session_attrs = match value {
754 "any" => TargetSessionAttrs::Any,
755 "read-write" => TargetSessionAttrs::ReadWrite,
756 "read-only" => TargetSessionAttrs::ReadOnly,
757 _ => {
758 return Err(Error::config_parse(Box::new(InvalidValue(
759 "target_session_attrs",
760 ))));
761 }
762 };
763 self.target_session_attrs(target_session_attrs);
764 }
765 "channel_binding" => {
766 let channel_binding = match value {
767 "disable" => ChannelBinding::Disable,
768 "prefer" => ChannelBinding::Prefer,
769 "require" => ChannelBinding::Require,
770 _ => {
771 return Err(Error::config_parse(Box::new(InvalidValue(
772 "channel_binding",
773 ))))
774 }
775 };
776 self.channel_binding(channel_binding);
777 }
778 "load_balance_hosts" => {
779 let load_balance_hosts = match value {
780 "disable" => LoadBalanceHosts::Disable,
781 "random" => LoadBalanceHosts::Random,
782 _ => {
783 return Err(Error::config_parse(Box::new(InvalidValue(
784 "load_balance_hosts",
785 ))))
786 }
787 };
788 self.load_balance_hosts(load_balance_hosts);
789 }
790 "replication" => {
791 let mode = match value {
792 "off" => None,
793 "true" => Some(ReplicationMode::Physical),
794 "database" => Some(ReplicationMode::Logical),
795 _ => return Err(Error::config_parse(Box::new(InvalidValue("replication")))),
796 };
797 if let Some(mode) = mode {
798 self.replication_mode(mode);
799 }
800 }
801 key => {
802 return Err(Error::config_parse(Box::new(UnknownOption(
803 key.to_string(),
804 ))));
805 }
806 }
807
808 Ok(())
809 }
810
811 #[cfg(feature = "runtime")]
815 pub async fn connect<T>(&self, tls: T) -> Result<(Client, Connection<Socket, T::Stream>), Error>
816 where
817 T: MakeTlsConnect<Socket>,
818 {
819 connect(tls, self).await
820 }
821
822 pub async fn connect_raw<S, T>(
826 &self,
827 stream: S,
828 tls: T,
829 ) -> Result<(Client, Connection<S, T::Stream>), Error>
830 where
831 S: AsyncRead + AsyncWrite + Unpin,
832 T: TlsConnect<S>,
833 {
834 connect_raw(stream, tls, true, self).await
835 }
836}
837
838impl FromStr for Config {
839 type Err = Error;
840
841 fn from_str(s: &str) -> Result<Config, Error> {
842 match UrlParser::parse(s)? {
843 Some(config) => Ok(config),
844 None => Parser::parse(s),
845 }
846 }
847}
848
849impl fmt::Debug for Config {
851 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
852 struct Redaction {}
853 impl fmt::Debug for Redaction {
854 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
855 write!(f, "_")
856 }
857 }
858
859 let mut config_dbg = &mut f.debug_struct("Config");
860 config_dbg = config_dbg
861 .field("user", &self.user)
862 .field("password", &self.password.as_ref().map(|_| Redaction {}))
863 .field("dbname", &self.dbname)
864 .field("options", &self.options)
865 .field("application_name", &self.application_name)
866 .field("ssl_cert", &self.ssl_cert)
867 .field("ssl_key", &self.ssl_key)
868 .field("ssl_mode", &self.ssl_mode)
869 .field("ssl_root_cert", &self.ssl_root_cert)
870 .field("host", &self.host)
871 .field("hostaddr", &self.hostaddr)
872 .field("port", &self.port)
873 .field("connect_timeout", &self.connect_timeout)
874 .field("tcp_user_timeout", &self.tcp_user_timeout)
875 .field("keepalives", &self.keepalives);
876
877 #[cfg(not(target_arch = "wasm32"))]
878 {
879 config_dbg = config_dbg
880 .field("keepalives_idle", &self.keepalive_config.idle)
881 .field("keepalives_interval", &self.keepalive_config.interval)
882 .field("keepalives_retries", &self.keepalive_config.retries);
883 }
884
885 config_dbg
886 .field("target_session_attrs", &self.target_session_attrs)
887 .field("channel_binding", &self.channel_binding)
888 .field("replication", &self.replication_mode)
889 .finish()
890 }
891}
892
/// Error for a connection-string option name that is not recognized.
///
/// Holds the offending option name.
#[derive(Debug)]
struct UnknownOption(String);

impl fmt::Display for UnknownOption {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownOption(name) = self;
        write!(f, "unknown option `{}`", name)
    }
}

impl error::Error for UnknownOption {}
903
/// Error for a connection-string option whose value could not be accepted.
///
/// Holds the name of the option whose value was rejected.
#[derive(Debug)]
struct InvalidValue(&'static str);

impl fmt::Display for InvalidValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidValue(option) = *self;
        write!(f, "invalid value for option `{}`", option)
    }
}

impl error::Error for InvalidValue {}
914
915struct Parser<'a> {
916 s: &'a str,
917 it: iter::Peekable<str::CharIndices<'a>>,
918}
919
920impl<'a> Parser<'a> {
921 fn parse(s: &'a str) -> Result<Config, Error> {
922 let mut parser = Parser {
923 s,
924 it: s.char_indices().peekable(),
925 };
926
927 let mut config = Config::new();
928
929 while let Some((key, value)) = parser.parameter()? {
930 config.param(key, &value)?;
931 }
932
933 Ok(config)
934 }
935
936 fn skip_ws(&mut self) {
937 self.take_while(char::is_whitespace);
938 }
939
940 fn take_while<F>(&mut self, f: F) -> &'a str
941 where
942 F: Fn(char) -> bool,
943 {
944 let start = match self.it.peek() {
945 Some(&(i, _)) => i,
946 None => return "",
947 };
948
949 loop {
950 match self.it.peek() {
951 Some(&(_, c)) if f(c) => {
952 self.it.next();
953 }
954 Some(&(i, _)) => return &self.s[start..i],
955 None => return &self.s[start..],
956 }
957 }
958 }
959
960 fn eat(&mut self, target: char) -> Result<(), Error> {
961 match self.it.next() {
962 Some((_, c)) if c == target => Ok(()),
963 Some((i, c)) => {
964 let m = format!(
965 "unexpected character at byte {}: expected `{}` but got `{}`",
966 i, target, c
967 );
968 Err(Error::config_parse(m.into()))
969 }
970 None => Err(Error::config_parse("unexpected EOF".into())),
971 }
972 }
973
974 fn eat_if(&mut self, target: char) -> bool {
975 match self.it.peek() {
976 Some(&(_, c)) if c == target => {
977 self.it.next();
978 true
979 }
980 _ => false,
981 }
982 }
983
984 fn keyword(&mut self) -> Option<&'a str> {
985 let s = self.take_while(|c| match c {
986 c if c.is_whitespace() => false,
987 '=' => false,
988 _ => true,
989 });
990
991 if s.is_empty() {
992 None
993 } else {
994 Some(s)
995 }
996 }
997
998 fn value(&mut self) -> Result<String, Error> {
999 let value = if self.eat_if('\'') {
1000 let value = self.quoted_value()?;
1001 self.eat('\'')?;
1002 value
1003 } else {
1004 self.simple_value()?
1005 };
1006
1007 Ok(value)
1008 }
1009
1010 fn simple_value(&mut self) -> Result<String, Error> {
1011 let mut value = String::new();
1012
1013 while let Some(&(_, c)) = self.it.peek() {
1014 if c.is_whitespace() {
1015 break;
1016 }
1017
1018 self.it.next();
1019 if c == '\\' {
1020 if let Some((_, c2)) = self.it.next() {
1021 value.push(c2);
1022 }
1023 } else {
1024 value.push(c);
1025 }
1026 }
1027
1028 if value.is_empty() {
1029 return Err(Error::config_parse("unexpected EOF".into()));
1030 }
1031
1032 Ok(value)
1033 }
1034
1035 fn quoted_value(&mut self) -> Result<String, Error> {
1036 let mut value = String::new();
1037
1038 while let Some(&(_, c)) = self.it.peek() {
1039 if c == '\'' {
1040 return Ok(value);
1041 }
1042
1043 self.it.next();
1044 if c == '\\' {
1045 if let Some((_, c2)) = self.it.next() {
1046 value.push(c2);
1047 }
1048 } else {
1049 value.push(c);
1050 }
1051 }
1052
1053 Err(Error::config_parse(
1054 "unterminated quoted connection parameter value".into(),
1055 ))
1056 }
1057
1058 fn parameter(&mut self) -> Result<Option<(&'a str, String)>, Error> {
1059 self.skip_ws();
1060 let keyword = match self.keyword() {
1061 Some(keyword) => keyword,
1062 None => return Ok(None),
1063 };
1064 self.skip_ws();
1065 self.eat('=')?;
1066 self.skip_ws();
1067 let value = self.value()?;
1068
1069 Ok(Some((keyword, value)))
1070 }
1071}
1072
1073struct UrlParser<'a> {
1075 s: &'a str,
1076 config: Config,
1077}
1078
1079impl<'a> UrlParser<'a> {
1080 fn parse(s: &'a str) -> Result<Option<Config>, Error> {
1081 let s = match Self::remove_url_prefix(s) {
1082 Some(s) => s,
1083 None => return Ok(None),
1084 };
1085
1086 let mut parser = UrlParser {
1087 s,
1088 config: Config::new(),
1089 };
1090
1091 parser.parse_credentials()?;
1092 parser.parse_host()?;
1093 parser.parse_path()?;
1094 parser.parse_params()?;
1095
1096 Ok(Some(parser.config))
1097 }
1098
1099 fn remove_url_prefix(s: &str) -> Option<&str> {
1100 for prefix in &["postgres://", "postgresql://"] {
1101 if let Some(stripped) = s.strip_prefix(prefix) {
1102 return Some(stripped);
1103 }
1104 }
1105
1106 None
1107 }
1108
1109 fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
1110 match self.s.find(end) {
1111 Some(pos) => {
1112 let (head, tail) = self.s.split_at(pos);
1113 self.s = tail;
1114 Some(head)
1115 }
1116 None => None,
1117 }
1118 }
1119
1120 fn take_all(&mut self) -> &'a str {
1121 mem::take(&mut self.s)
1122 }
1123
1124 fn eat_byte(&mut self) {
1125 self.s = &self.s[1..];
1126 }
1127
1128 fn parse_credentials(&mut self) -> Result<(), Error> {
1129 let creds = match self.take_until(&['@']) {
1130 Some(creds) => creds,
1131 None => return Ok(()),
1132 };
1133 self.eat_byte();
1134
1135 let mut it = creds.splitn(2, ':');
1136 let user = self.decode(it.next().unwrap())?;
1137 self.config.user(user);
1138
1139 if let Some(password) = it.next() {
1140 let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
1141 self.config.password(password);
1142 }
1143
1144 Ok(())
1145 }
1146
1147 fn parse_host(&mut self) -> Result<(), Error> {
1148 let host = match self.take_until(&['/', '?']) {
1149 Some(host) => host,
1150 None => self.take_all(),
1151 };
1152
1153 if host.is_empty() {
1154 return Ok(());
1155 }
1156
1157 for chunk in host.split(',') {
1158 let (host, port) = if chunk.starts_with('[') {
1159 let idx = match chunk.find(']') {
1160 Some(idx) => idx,
1161 None => return Err(Error::config_parse(InvalidValue("host").into())),
1162 };
1163
1164 let host = &chunk[1..idx];
1165 let remaining = &chunk[idx + 1..];
1166 let port = if let Some(port) = remaining.strip_prefix(':') {
1167 Some(port)
1168 } else if remaining.is_empty() {
1169 None
1170 } else {
1171 return Err(Error::config_parse(InvalidValue("host").into()));
1172 };
1173
1174 (host, port)
1175 } else {
1176 let mut it = chunk.splitn(2, ':');
1177 (it.next().unwrap(), it.next())
1178 };
1179
1180 self.host_param(host)?;
1181 let port = self.decode(port.unwrap_or("5432"))?;
1182 self.config.param("port", &port)?;
1183 }
1184
1185 Ok(())
1186 }
1187
1188 fn parse_path(&mut self) -> Result<(), Error> {
1189 if !self.s.starts_with('/') {
1190 return Ok(());
1191 }
1192 self.eat_byte();
1193
1194 let dbname = match self.take_until(&['?']) {
1195 Some(dbname) => dbname,
1196 None => self.take_all(),
1197 };
1198
1199 if !dbname.is_empty() {
1200 self.config.dbname(self.decode(dbname)?);
1201 }
1202
1203 Ok(())
1204 }
1205
1206 fn parse_params(&mut self) -> Result<(), Error> {
1207 if !self.s.starts_with('?') {
1208 return Ok(());
1209 }
1210 self.eat_byte();
1211
1212 while !self.s.is_empty() {
1213 let key = match self.take_until(&['=']) {
1214 Some(key) => self.decode(key)?,
1215 None => return Err(Error::config_parse("unterminated parameter".into())),
1216 };
1217 self.eat_byte();
1218
1219 let value = match self.take_until(&['&']) {
1220 Some(value) => {
1221 self.eat_byte();
1222 value
1223 }
1224 None => self.take_all(),
1225 };
1226
1227 if key == "host" {
1228 self.host_param(value)?;
1229 } else {
1230 let value = self.decode(value)?;
1231 self.config.param(&key, &value)?;
1232 }
1233 }
1234
1235 Ok(())
1236 }
1237
1238 #[cfg(unix)]
1239 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1240 let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
1241 if decoded.first() == Some(&b'/') {
1242 self.config.host_path(OsStr::from_bytes(&decoded));
1243 } else {
1244 let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?;
1245 self.config.host(decoded);
1246 }
1247
1248 Ok(())
1249 }
1250
1251 #[cfg(not(unix))]
1252 fn host_param(&mut self, s: &str) -> Result<(), Error> {
1253 let s = self.decode(s)?;
1254 self.config.param("host", &s)
1255 }
1256
1257 fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, Error> {
1258 percent_encoding::percent_decode(s.as_bytes())
1259 .decode_utf8()
1260 .map_err(|e| Error::config_parse(e.into()))
1261 }
1262}
1263
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use crate::{config::Host, Config};

    /// Key/value parsing of user, dbname and comma-separated host/hostaddr lists, plus port.
    #[test]
    fn test_simple_parsing() {
        let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );

        assert_eq!(
            [
                "127.0.0.1".parse::<IpAddr>().unwrap(),
                "127.0.0.2".parse::<IpAddr>().unwrap()
            ],
            config.get_hostaddrs(),
        );

        assert_eq!([26257u16], config.get_ports());
    }

    /// Quoted values may contain whitespace and backslash-escaped quotes.
    #[test]
    fn test_quoted_value_parsing() {
        let config = "user='it\\'s me' dbname=db".parse::<Config>().unwrap();
        assert_eq!(Some("it's me"), config.get_user());
        assert_eq!(Some("db"), config.get_dbname());
    }

    /// URL parsing: credentials, a bracketed host with a port, the default port, a path and a query.
    #[test]
    fn test_url_parsing() {
        let config = "postgresql://user:p%40ss@[::1]:5433,host2/db?application_name=app"
            .parse::<Config>()
            .unwrap();
        assert_eq!(Some("user"), config.get_user());
        assert_eq!(Some(&b"p@ss"[..]), config.get_password());
        assert_eq!(Some("db"), config.get_dbname());
        assert_eq!(Some("app"), config.get_application_name());
        assert_eq!(
            [Host::Tcp("::1".to_string()), Host::Tcp("host2".to_string())],
            config.get_hosts(),
        );
        assert_eq!([5433u16, 5432], config.get_ports());
    }

    /// A malformed `hostaddr` must be rejected.
    #[test]
    fn test_invalid_hostaddr_parsing() {
        let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257";
        assert!(s.parse::<Config>().is_err());
    }
}