misc.python.materialize.util

Various utilities

  1# Copyright Materialize, Inc. and contributors. All rights reserved.
  2#
  3# Use of this software is governed by the Business Source License
  4# included in the LICENSE file at the root of this repository.
  5#
  6# As of the Change Date specified in that file, in accordance with
  7# the Business Source License, use of this software will be governed
  8# by the Apache License, Version 2.0.
  9
 10"""Various utilities"""
 11
 12from __future__ import annotations
 13
 14import hashlib
 15import json
 16import os
 17import pathlib
 18import random
 19import subprocess
 20from collections.abc import Iterator
 21from dataclasses import dataclass
 22from enum import Enum
 23from pathlib import Path
 24from threading import Thread
 25from typing import Protocol, TypeVar
 26from urllib.parse import parse_qs, quote, unquote, urlparse
 27
 28import psycopg
 29import xxhash
 30import zstandard
 31
# Absolute path to the repository checkout root, read from the MZ_ROOT
# environment variable (os.environ[...] raises KeyError at import time if
# unset) — presumably exported by the build/test harness; TODO confirm.
MZ_ROOT = Path(os.environ["MZ_ROOT"])
 33
 34
 35def nonce(digits: int) -> str:
 36    return "".join(random.choice("0123456789abcdef") for _ in range(digits))
 37
 38
 39T = TypeVar("T")
 40
 41
 42def all_subclasses(cls: type[T]) -> set[type[T]]:
 43    """Returns a recursive set of all subclasses of a class"""
 44    sc = cls.__subclasses__()
 45    return set(sc).union([subclass for c in sc for subclass in all_subclasses(c)])
 46
 47
 48NAUGHTY_STRINGS = None
 49
 50
 51def naughty_strings() -> list[str]:
 52    # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
 53    # Under MIT license, Copyright (c) 2015-2020 Max Woolf
 54    global NAUGHTY_STRINGS
 55    if not NAUGHTY_STRINGS:
 56        with open(MZ_ROOT / "misc" / "python" / "materialize" / "blns.json") as f:
 57            NAUGHTY_STRINGS = json.load(f)
 58    return NAUGHTY_STRINGS
 59
 60
 61class YesNoOnce(Enum):
 62    YES = 1
 63    NO = 2
 64    ONCE = 3
 65
 66
 67class PropagatingThread(Thread):
 68    def run(self):
 69        self.exc = None
 70        try:
 71            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
 72        except BaseException as e:
 73            self.exc = e
 74
 75    def join(self, timeout=None):
 76        super().join(timeout)
 77        if self.exc:
 78            raise self.exc
 79        if hasattr(self, "ret"):
 80            return self.ret
 81
 82
 83def decompress_zst_to_directory(
 84    zst_file_path: str, destination_dir_path: str
 85) -> list[str]:
 86    """
 87    :return: file paths in destination dir
 88    """
 89    input_file = pathlib.Path(zst_file_path)
 90    output_paths = []
 91
 92    with open(input_file, "rb") as compressed:
 93        decompressor = zstandard.ZstdDecompressor()
 94        output_path = pathlib.Path(destination_dir_path) / input_file.stem
 95        output_paths.append(str(output_path))
 96        with open(output_path, "wb") as destination:
 97            decompressor.copy_stream(compressed, destination)
 98
 99    return output_paths
100
101
def ensure_dir_exists(path_to_dir: str) -> None:
    """Create ``path_to_dir`` (including missing parents) if absent.

    Uses os.makedirs instead of shelling out to ``mkdir -p``: no process
    spawn, portable beyond POSIX, and failures raise OSError with the
    offending path instead of an opaque CalledProcessError.
    """
    os.makedirs(path_to_dir, exist_ok=True)
111
112
def sha256_of_utf8_string(value: str) -> str:
    """Return the hex-encoded SHA-256 digest of ``value`` encoded as UTF-8."""
    utf8_payload = value.encode("utf-8")
    return hashlib.sha256(utf8_payload).hexdigest()
115
116
def stable_int_hash(*values: str) -> int:
    """Hash strings to an int that is stable across processes and runs.

    (Unlike the builtin ``hash``, which is salted per-process for str.)
    A single value is hashed directly with xxh64 (seed 0); multiple values
    are reduced by hashing the comma-joined individual hashes.
    """
    if len(values) != 1:
        combined = ",".join(str(stable_int_hash(entry)) for entry in values)
        return stable_int_hash(combined)
    return xxhash.xxh64(values[0], seed=0).intdigest()
122
123
class HasName(Protocol):
    """Structural type for any object carrying a ``name`` attribute."""

    # the lookup key used by selected_by_name
    name: str


U = TypeVar("U", bound=HasName)


def selected_by_name(selected: list[str], objs: list[U]) -> Iterator[U]:
    """Yield, for each requested name in order, the first object bearing it.

    Raises ValueError as soon as a requested name matches nothing in objs.
    """
    for name in selected:
        match = next((obj for obj in objs if obj.name == name), None)
        if match is None:
            raise ValueError(
                f"Unknown object with name {name} in {[obj.name for obj in objs]}"
            )
        yield match
141
142
@dataclass
class PgConnInfo:
    """Connection parameters for a Postgres-protocol endpoint."""

    user: str
    host: str
    port: int
    database: str
    password: str | None = None
    ssl: bool = False
    cluster: str | None = None
    autocommit: bool = False

    def connect(self) -> psycopg.Connection:
        """Open a psycopg connection described by this object."""
        conn = psycopg.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=self.database,
            sslmode="require" if self.ssl else None,
        )
        # Set SO_LINGER(1, 0) so close() sends RST instead of FIN, bypassing
        # TIME_WAIT. Prevents exhausting the ~28k ephemeral port range under
        # high connection churn (e.g. benchmarks doing rapid connect/disconnect).
        self._set_linger(conn)
        if self.autocommit:
            conn.autocommit = True
        if self.cluster:
            # NOTE(review): cluster is interpolated into SQL unescaped —
            # assumed to come from trusted test configuration; confirm.
            with conn.cursor() as cursor:
                cursor.execute(f"SET cluster = {self.cluster}".encode())
        return conn

    @staticmethod
    def _set_linger(conn: psycopg.Connection) -> None:
        """Enable SO_LINGER(onoff=1, linger=0) on the connection's socket."""
        import socket
        import struct

        raw_fd = conn.pgconn.socket
        if raw_fd < 0:
            # No live socket behind this connection; nothing to configure.
            return
        # Wrap the fd without taking ownership; detach() below ensures the
        # wrapper never closes the connection's underlying descriptor.
        wrapper = socket.socket(fileno=raw_fd)
        try:
            linger = struct.pack("ii", 1, 0)
            wrapper.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
        finally:
            wrapper.detach()

    def to_conn_string(self) -> str:
        """Render as a postgres:// URL, percent-encoding the dynamic parts."""
        credentials = quote(self.user)
        if self.password:
            credentials += f":{quote(self.password)}"
        return f"postgres://{credentials}@{self.host}:{self.port}/{quote(self.database)}"
196
197
def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
    """Parse a postgres:// URL into a PgConnInfo.

    Hand-rolled via urllib because we only need a small subset of the
    connection-string syntax: user, password, host, port, database, and the
    ``sslmode`` query parameter (any value other than "disable" enables SSL).
    (The old docstring referenced pg8000; this module now uses psycopg.)

    Raises AssertionError when the URL lacks a username or hostname.
    """
    url = urlparse(conn_string)
    query_params = parse_qs(url.query)
    assert url.username, "connection string must include a username"
    assert url.hostname, "connection string must include a hostname"
    return PgConnInfo(
        user=unquote(url.username),
        # Preserve a missing/empty password unchanged rather than decoding it.
        password=unquote(url.password) if url.password else url.password,
        host=url.hostname,
        port=url.port or 5432,  # default Postgres port
        database=url.path.lstrip("/"),
        # parse_qs yields a list per key; the last occurrence wins.
        ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
    )
212
213
# Substrings whose presence marks a command-line argument as unsafe (or too
# noisy) to echo into test logs; consumed by filter_cmd below.
FILTERED_ARGS = [
    # Secrets
    "mzp_",
    "-----BEGIN PRIVATE KEY-----",
    "-----BEGIN CERTIFICATE-----",
    "confluent-api-key=",
    "confluent-api-secret=",
    "aws-access-key-id=",
    "aws-secret-access-key=",
    "default-sql-server-password=",
    # Not a secret, but too spammy, filter too
    "CLUSTER_REPLICA_SIZES",
    "cluster-replica-sizes=",
]


def filter_cmd(args: list[str]) -> list[str]:
    """Don't print out secrets in test logs"""
    sanitized = []
    for arg in args:
        is_sensitive = any(marker in arg for marker in FILTERED_ARGS)
        sanitized.append("[REDACTED]" if is_sensitive else arg)
    return sanitized
MZ_ROOT = PosixPath('/var/lib/buildkite-agent/builds/buildkite-15f2293-i-04151a3597e2b37a7-1/materialize/deploy')
def nonce(digits: int) -> str:
36def nonce(digits: int) -> str:
37    return "".join(random.choice("0123456789abcdef") for _ in range(digits))
def all_subclasses(cls: type[~T]) -> set[type[~T]]:
43def all_subclasses(cls: type[T]) -> set[type[T]]:
44    """Returns a recursive set of all subclasses of a class"""
45    sc = cls.__subclasses__()
46    return set(sc).union([subclass for c in sc for subclass in all_subclasses(c)])

Returns a recursive set of all subclasses of a class

NAUGHTY_STRINGS = None
def naughty_strings() -> list[str]:
52def naughty_strings() -> list[str]:
53    # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
54    # Under MIT license, Copyright (c) 2015-2020 Max Woolf
55    global NAUGHTY_STRINGS
56    if not NAUGHTY_STRINGS:
57        with open(MZ_ROOT / "misc" / "python" / "materialize" / "blns.json") as f:
58            NAUGHTY_STRINGS = json.load(f)
59    return NAUGHTY_STRINGS
class YesNoOnce(enum.Enum):
62class YesNoOnce(Enum):
63    YES = 1
64    NO = 2
65    ONCE = 3
YES = <YesNoOnce.YES: 1>
NO = <YesNoOnce.NO: 2>
ONCE = <YesNoOnce.ONCE: 3>
class PropagatingThread(threading.Thread):
68class PropagatingThread(Thread):
69    def run(self):
70        self.exc = None
71        try:
72            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
73        except BaseException as e:
74            self.exc = e
75
76    def join(self, timeout=None):
77        super().join(timeout)
78        if self.exc:
79            raise self.exc
80        if hasattr(self, "ret"):
81            return self.ret

A class that represents a thread of control.

This class can be safely subclassed in a limited fashion. There are two ways to specify the activity: by passing a callable object to the constructor, or by overriding the run() method in a subclass.

def run(self):
69    def run(self):
70        self.exc = None
71        try:
72            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
73        except BaseException as e:
74            self.exc = e

Method representing the thread's activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

def join(self, timeout=None):
76    def join(self, timeout=None):
77        super().join(timeout)
78        if self.exc:
79            raise self.exc
80        if hasattr(self, "ret"):
81            return self.ret

Wait until the thread terminates.

This blocks the calling thread until the thread whose join() method is called terminates -- either normally or through an unhandled exception or until the optional timeout occurs.

When the timeout argument is present and not None, it should be a floating point number specifying a timeout for the operation in seconds (or fractions thereof). As join() always returns None, you must call is_alive() after join() to decide whether a timeout happened -- if the thread is still alive, the join() call timed out.

When the timeout argument is not present or None, the operation will block until the thread terminates.

A thread can be join()ed many times.

join() raises a RuntimeError if an attempt is made to join the current thread as that would cause a deadlock. It is also an error to join() a thread before it has been started and attempts to do so raises the same exception.

def decompress_zst_to_directory(zst_file_path: str, destination_dir_path: str) -> list[str]:
 84def decompress_zst_to_directory(
 85    zst_file_path: str, destination_dir_path: str
 86) -> list[str]:
 87    """
 88    :return: file paths in destination dir
 89    """
 90    input_file = pathlib.Path(zst_file_path)
 91    output_paths = []
 92
 93    with open(input_file, "rb") as compressed:
 94        decompressor = zstandard.ZstdDecompressor()
 95        output_path = pathlib.Path(destination_dir_path) / input_file.stem
 96        output_paths.append(str(output_path))
 97        with open(output_path, "wb") as destination:
 98            decompressor.copy_stream(compressed, destination)
 99
100    return output_paths
Returns

file paths in destination dir

def ensure_dir_exists(path_to_dir: str) -> None:
103def ensure_dir_exists(path_to_dir: str) -> None:
104    subprocess.run(
105        [
106            "mkdir",
107            "-p",
108            f"{path_to_dir}",
109        ],
110        check=True,
111    )
def sha256_of_utf8_string(value: str) -> str:
114def sha256_of_utf8_string(value: str) -> str:
115    return hashlib.sha256(bytes(value, encoding="utf-8")).hexdigest()
def stable_int_hash(*values: str) -> int:
118def stable_int_hash(*values: str) -> int:
119    if len(values) == 1:
120        return xxhash.xxh64(values[0], seed=0).intdigest()
121
122    return stable_int_hash(",".join([str(stable_int_hash(entry)) for entry in values]))
class HasName(typing.Protocol):
125class HasName(Protocol):
126    name: str

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto[T](Protocol):
    def meth(self) -> T:
        ...
HasName(*args, **kwargs)
1739def _no_init_or_replace_init(self, *args, **kwargs):
1740    cls = type(self)
1741
1742    if cls._is_protocol:
1743        raise TypeError('Protocols cannot be instantiated')
1744
1745    # Already using a custom `__init__`. No need to calculate correct
1746    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1747    if cls.__init__ is not _no_init_or_replace_init:
1748        return
1749
1750    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1751    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1752    # searches for a proper new `__init__` in the MRO. The new `__init__`
1753    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1754    # instantiation of the protocol subclass will thus use the new
1755    # `__init__` and no longer call `_no_init_or_replace_init`.
1756    for base in cls.__mro__:
1757        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1758        if init is not _no_init_or_replace_init:
1759            cls.__init__ = init
1760            break
1761    else:
1762        # should not happen
1763        cls.__init__ = object.__init__
1764
1765    cls.__init__(self, *args, **kwargs)
name: str
def selected_by_name(selected: list[str], objs: list[~U]) -> Iterator[~U]:
132def selected_by_name(selected: list[str], objs: list[U]) -> Iterator[U]:
133    for name in selected:
134        for obj in objs:
135            if obj.name == name:
136                yield obj
137                break
138        else:
139            raise ValueError(
140                f"Unknown object with name {name} in {[obj.name for obj in objs]}"
141            )
@dataclass
class PgConnInfo:
144@dataclass
145class PgConnInfo:
146    user: str
147    host: str
148    port: int
149    database: str
150    password: str | None = None
151    ssl: bool = False
152    cluster: str | None = None
153    autocommit: bool = False
154
155    def connect(self) -> psycopg.Connection:
156        conn = psycopg.connect(
157            host=self.host,
158            port=self.port,
159            user=self.user,
160            password=self.password,
161            dbname=self.database,
162            sslmode="require" if self.ssl else None,
163        )
164        # Set SO_LINGER(1, 0) so close() sends RST instead of FIN, bypassing
165        # TIME_WAIT. Prevents exhausting the ~28k ephemeral port range under
166        # high connection churn (e.g. benchmarks doing rapid connect/disconnect).
167        self._set_linger(conn)
168        if self.autocommit:
169            conn.autocommit = True
170        if self.cluster:
171            with conn.cursor() as cur:
172                cur.execute(f"SET cluster = {self.cluster}".encode())
173        return conn
174
175    @staticmethod
176    def _set_linger(conn: psycopg.Connection) -> None:
177        import socket
178        import struct
179
180        fd = conn.pgconn.socket
181        if fd < 0:
182            return
183        sock = socket.socket(fileno=fd)
184        try:
185            sock.setsockopt(
186                socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
187            )
188        finally:
189            sock.detach()
190
191    def to_conn_string(self) -> str:
192        return (
193            f"postgres://{quote(self.user)}:{quote(self.password)}@{self.host}:{self.port}/{quote(self.database)}"
194            if self.password
195            else f"postgres://{quote(self.user)}@{self.host}:{self.port}/{quote(self.database)}"
196        )
PgConnInfo( user: str, host: str, port: int, database: str, password: str | None = None, ssl: bool = False, cluster: str | None = None, autocommit: bool = False)
user: str
host: str
port: int
database: str
password: str | None = None
ssl: bool = False
cluster: str | None = None
autocommit: bool = False
def connect(self) -> psycopg.Connection:
155    def connect(self) -> psycopg.Connection:
156        conn = psycopg.connect(
157            host=self.host,
158            port=self.port,
159            user=self.user,
160            password=self.password,
161            dbname=self.database,
162            sslmode="require" if self.ssl else None,
163        )
164        # Set SO_LINGER(1, 0) so close() sends RST instead of FIN, bypassing
165        # TIME_WAIT. Prevents exhausting the ~28k ephemeral port range under
166        # high connection churn (e.g. benchmarks doing rapid connect/disconnect).
167        self._set_linger(conn)
168        if self.autocommit:
169            conn.autocommit = True
170        if self.cluster:
171            with conn.cursor() as cur:
172                cur.execute(f"SET cluster = {self.cluster}".encode())
173        return conn
def to_conn_string(self) -> str:
191    def to_conn_string(self) -> str:
192        return (
193            f"postgres://{quote(self.user)}:{quote(self.password)}@{self.host}:{self.port}/{quote(self.database)}"
194            if self.password
195            else f"postgres://{quote(self.user)}@{self.host}:{self.port}/{quote(self.database)}"
196        )
def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
199def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
200    """Not supported natively by pg8000, so we have to parse ourselves"""
201    url = urlparse(conn_string)
202    query_params = parse_qs(url.query)
203    assert url.username
204    assert url.hostname
205    return PgConnInfo(
206        user=unquote(url.username),
207        password=unquote(url.password) if url.password else url.password,
208        host=url.hostname,
209        port=url.port or 5432,
210        database=url.path.lstrip("/"),
211        ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
212    )

Connection-string parsing is done by hand here via urllib. (Note: the code's docstring mentions pg8000, but this module connects with psycopg — the reference is stale.)

FILTERED_ARGS = ['mzp_', '-----BEGIN PRIVATE KEY-----', '-----BEGIN CERTIFICATE-----', 'confluent-api-key=', 'confluent-api-secret=', 'aws-access-key-id=', 'aws-secret-access-key=', 'default-sql-server-password=', 'CLUSTER_REPLICA_SIZES', 'cluster-replica-sizes=']
def filter_cmd(args: list[str]) -> list[str]:
231def filter_cmd(args: list[str]) -> list[str]:
232    """Don't print out secrets in test logs"""
233    return [
234        (
235            "[REDACTED]"
236            if any(filtered_arg in arg for filtered_arg in FILTERED_ARGS)
237            else arg
238        )
239        for arg in args
240    ]

Don't print out secrets in test logs