1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::str::FromStr;
use tokio_postgres::Client;

use mz_ssh_util::tunnel_manager::SshTunnelManager;

use crate::{simple_query_opt, Config, PostgresError};

/// A PostgreSQL `wal_level` setting.
///
/// Variants are listed in ascending order and the derived `PartialOrd`/`Ord`
/// follow declaration order, so `Logical` compares greatest of all levels.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum WalLevel {
    /// The `minimal` setting.
    Minimal,
    /// The `replica` setting.
    Replica,
    /// The `logical` setting.
    Logical,
}

impl std::str::FromStr for WalLevel {
    type Err = anyhow::Error;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "minimal" => Ok(Self::Minimal),
            "replica" => Ok(Self::Replica),
            "logical" => Ok(Self::Logical),
            o => Err(anyhow::anyhow!("unknown wal_level {}", o)),
        }
    }
}

impl std::fmt::Display for WalLevel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            WalLevel::Minimal => "minimal",
            WalLevel::Replica => "replica",
            WalLevel::Logical => "logical",
        };

        f.write_str(s)
    }
}

#[mz_ore::test]
fn test_wal_level_max() {
    // `WalLevel::Logical` must compare >= every level, so taking the max of it
    // with any level has to yield `Logical` again.
    let every_level = [WalLevel::Minimal, WalLevel::Replica, WalLevel::Logical];
    for level in every_level {
        let greatest = WalLevel::Logical.max(level);
        assert_eq!(WalLevel::Logical, greatest);
    }
}

/// Opens a client via `config` and reports the server's `wal_level` setting.
///
/// # Errors
///
/// Returns an error if connecting or running `SHOW wal_level` fails, or if the
/// reported value is not a level [`WalLevel`] recognizes.
pub async fn get_wal_level(
    ssh_tunnel_manager: &SshTunnelManager,
    config: &Config,
) -> Result<WalLevel, PostgresError> {
    let client = config
        .connect("wal_level_check", ssh_tunnel_manager)
        .await?;
    let row = client.query_one("SHOW wal_level", &[]).await?;
    let reported: String = row.get("wal_level");
    let level = WalLevel::from_str(reported.as_str())?;
    Ok(level)
}

/// Opens a client via `config` and returns the server's `max_wal_senders`
/// setting, cast to `int8`.
///
/// # Errors
///
/// Returns an error if connecting or running the query fails.
pub async fn get_max_wal_senders(
    ssh_tunnel_manager: &SshTunnelManager,
    config: &Config,
) -> Result<i64, PostgresError> {
    const MAX_WAL_SENDERS_QUERY: &str =
        "SELECT CAST(current_setting('max_wal_senders') AS int8) AS max_wal_senders";

    let client = config
        .connect("max_wal_senders_check", ssh_tunnel_manager)
        .await?;
    let row = client.query_one(MAX_WAL_SENDERS_QUERY, &[]).await?;
    let senders: i64 = row.get("max_wal_senders");
    Ok(senders)
}

/// Opens a client via `config` and returns how many more replication slots
/// can be created.
///
/// This is the server's `max_replication_slots` setting minus the number of
/// rows currently in `pg_catalog.pg_replication_slots`.
///
/// # Errors
///
/// Returns an error if connecting or running the query fails.
pub async fn available_replication_slots(
    ssh_tunnel_manager: &SshTunnelManager,
    config: &Config,
) -> Result<i64, PostgresError> {
    let client = config
        .connect("postgres_check_replication_slots", ssh_tunnel_manager)
        .await?;

    // Compute the headroom server-side so it comes back as a single int8.
    let row = client
        .query_one(
            "SELECT
            CAST(current_setting('max_replication_slots') AS int8)
              - (SELECT count(*) FROM pg_catalog.pg_replication_slots)
              AS available_replication_slots;",
            &[],
        )
        .await?;

    let headroom: i64 = row.get("available_replication_slots");
    Ok(headroom)
}

/// Drops each named replication slot, skipping names with no matching slot.
///
/// Existence is checked over a regular connection by querying
/// `pg_replication_slots`. The drop itself is issued as
/// `DROP_REPLICATION_SLOT "<name>" WAIT` over a replication connection.
///
/// # Errors
///
/// Returns an error if either connection or any command fails, or if
/// `pg_replication_slots` reports more than one row for a slot name.
pub async fn drop_replication_slots(
    ssh_tunnel_manager: &SshTunnelManager,
    config: Config,
    slots: &[&str],
) -> Result<(), PostgresError> {
    let client = config
        .connect("postgres_drop_replication_slots", ssh_tunnel_manager)
        .await?;
    let replication_client = config.connect_replication(ssh_tunnel_manager).await?;
    for &slot in slots {
        let rows = client
            .query(
                "SELECT active_pid FROM pg_replication_slots WHERE slot_name = $1::TEXT",
                &[&slot],
            )
            .await?;
        match rows.as_slice() {
            [] => {
                // DROP_REPLICATION_SLOT will error if the slot does not exist
                tracing::info!(
                    "drop_replication_slots called on non-existent slot {}",
                    slot
                );
            }
            [_] => {
                // The replication grammar expects an identifier here. Unquoted,
                // a valid slot name that starts with a digit or matches a
                // replication keyword (e.g. `wait`) fails to parse, so always
                // quote it.
                let command = format!(
                    "DROP_REPLICATION_SLOT {} WAIT",
                    quote_replication_identifier(slot)
                );
                replication_client.simple_query(&command).await?;
            }
            _ => {
                return Err(PostgresError::Generic(anyhow::anyhow!(
                    "multiple pg_replication_slots entries for slot {}",
                    slot
                )));
            }
        }
    }
    Ok(())
}

/// Renders `name` as a double-quoted identifier for a replication-protocol
/// command, doubling any embedded `"` so the name is taken literally.
fn quote_replication_identifier(name: &str) -> String {
    format!("\"{}\"", name.replace('"', "\"\""))
}

/// Runs `IDENTIFY_SYSTEM` on `replication_client` and returns the `timeline`
/// column of its result row, parsed as a `u64`.
///
/// # Errors
///
/// Returns an error if the command fails, yields no row, has a missing or
/// NULL `timeline` column, or the value does not parse as a `u64`. Unexpected
/// server output is reported as an error rather than a panic.
pub async fn get_timeline_id(replication_client: &Client) -> Result<u64, PostgresError> {
    let row = simple_query_opt(replication_client, "IDENTIFY_SYSTEM")
        .await?
        .ok_or_else(|| {
            PostgresError::Generic(anyhow::anyhow!(
                "IDENTIFY_SYSTEM did not return a result row"
            ))
        })?;
    // `try_get` (unlike `get`) reports a missing column as an error instead of
    // panicking. The `Option` it returns covers a NULL value.
    let timeline = row.try_get("timeline")?.ok_or_else(|| {
        PostgresError::Generic(anyhow::anyhow!(
            "IDENTIFY_SYSTEM returned a NULL timeline ID"
        ))
    })?;
    timeline.parse::<u64>().map_err(|err| {
        PostgresError::Generic(anyhow::anyhow!(
            "Failed to parse timeline ID from IDENTIFY_SYSTEM: {}",
            err
        ))
    })
}