misc.python.materialize.util

Various utilities

  1# Copyright Materialize, Inc. and contributors. All rights reserved.
  2#
  3# Use of this software is governed by the Business Source License
  4# included in the LICENSE file at the root of this repository.
  5#
  6# As of the Change Date specified in that file, in accordance with
  7# the Business Source License, use of this software will be governed
  8# by the Apache License, Version 2.0.
  9
 10"""Various utilities"""
 11
 12from __future__ import annotations
 13
 14import filecmp
 15import hashlib
 16import json
 17import os
 18import pathlib
 19import random
 20import subprocess
 21from collections.abc import Iterator
 22from dataclasses import dataclass
 23from enum import Enum
 24from pathlib import Path
 25from threading import Thread
 26from typing import Protocol, TypeVar
 27from urllib.parse import parse_qs, quote, unquote, urlparse
 28
 29import psycopg
 30import xxhash
 31import zstandard
 32
# Base path read from the MZ_ROOT environment variable when this module is
# imported; the import fails with KeyError if the variable is not set.
# NOTE(review): presumably the root of the repository checkout -- confirm.
MZ_ROOT = Path(os.environ["MZ_ROOT"])
 34
 35
 36def nonce(digits: int) -> str:
 37    return "".join(random.choice("0123456789abcdef") for _ in range(digits))
 38
 39
 40T = TypeVar("T")
 41
 42
 43def all_subclasses(cls: type[T]) -> set[type[T]]:
 44    """Returns a recursive set of all subclasses of a class"""
 45    sc = cls.__subclasses__()
 46    return set(sc).union([subclass for c in sc for subclass in all_subclasses(c)])
 47
 48
 49NAUGHTY_STRINGS = None
 50
 51
 52def naughty_strings() -> list[str]:
 53    # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
 54    # Under MIT license, Copyright (c) 2015-2020 Max Woolf
 55    global NAUGHTY_STRINGS
 56    if not NAUGHTY_STRINGS:
 57        with open(MZ_ROOT / "misc" / "python" / "materialize" / "blns.json") as f:
 58            NAUGHTY_STRINGS = json.load(f)
 59    return NAUGHTY_STRINGS
 60
 61
class YesNoOnce(Enum):
    """Three-valued setting: yes, no, or once."""

    # NOTE(review): how consumers interpret ONCE is not visible in this
    # module -- confirm at the call sites.
    YES = 1
    NO = 2
    ONCE = 3
 66
 67
 68class PropagatingThread(Thread):
 69    def run(self):
 70        self.exc = None
 71        try:
 72            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
 73        except BaseException as e:
 74            self.exc = e
 75
 76    def join(self, timeout=None):
 77        super().join(timeout)
 78        if self.exc:
 79            raise self.exc
 80        if hasattr(self, "ret"):
 81            return self.ret
 82
 83
 84def decompress_zst_to_directory(
 85    zst_file_path: str, destination_dir_path: str
 86) -> list[str]:
 87    """
 88    :return: file paths in destination dir
 89    """
 90    input_file = pathlib.Path(zst_file_path)
 91    output_paths = []
 92
 93    with open(input_file, "rb") as compressed:
 94        decompressor = zstandard.ZstdDecompressor()
 95        output_path = pathlib.Path(destination_dir_path) / input_file.stem
 96        output_paths.append(str(output_path))
 97        with open(output_path, "wb") as destination:
 98            decompressor.copy_stream(compressed, destination)
 99
100    return output_paths
101
102
def ensure_dir_exists(path_to_dir: str) -> None:
    """Create `path_to_dir`, including missing parents, via `mkdir -p`.

    Succeeds without changes when the directory already exists.

    :param path_to_dir: directory path to create
    :raises subprocess.CalledProcessError: if `mkdir` exits with a non-zero status
    """
    subprocess.run(
        [
            "mkdir",
            "-p",
            # End-of-options marker: a path that starts with "-" must be
            # treated as an operand, not parsed as a mkdir flag.
            "--",
            f"{path_to_dir}",
        ],
        check=True,
    )
112
113
114def sha256_of_file(path: str | Path) -> str:
115    sha256 = hashlib.sha256()
116    with open(path, "rb") as f:
117        for block in iter(lambda: f.read(filecmp.BUFSIZE), b""):
118            sha256.update(block)
119    return sha256.hexdigest()
120
121
def sha256_of_utf8_string(value: str) -> str:
    """Return the hexadecimal SHA-256 digest of `value` encoded as UTF-8."""
    encoded = value.encode("utf-8")
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()
124
125
def stable_int_hash(*values: str) -> int:
    """Return an xxh64-based integer hash (seed 0) of one or more strings.

    A single value is hashed directly. For any other number of values, each
    one is hashed on its own, the decimal digests are joined with "," and
    the joined string is hashed again.
    """
    if len(values) != 1:
        per_value = (str(stable_int_hash(item)) for item in values)
        return stable_int_hash(",".join(per_value))

    (only,) = values
    return xxhash.xxh64(only, seed=0).intdigest()
131
132
class HasName(Protocol):
    """Structural type for any object exposing a `name` string attribute."""

    name: str


U = TypeVar("U", bound=HasName)


def selected_by_name(selected: list[str], objs: list[U]) -> Iterator[U]:
    """Yield, for each entry of `selected`, the first object with that name.

    Results follow the order of `selected`. This is a generator, so the
    error for an unknown name is raised only once iteration reaches it.

    :raises ValueError: if a name matches none of `objs`
    """
    for name in selected:
        match = next((obj for obj in objs if obj.name == name), None)
        if match is None:
            raise ValueError(
                f"Unknown object with name {name} in {[obj.name for obj in objs]}"
            )
        yield match
150
151
@dataclass
class PgConnInfo:
    """Connection parameters for a PostgreSQL-protocol server."""

    user: str
    host: str
    port: int
    database: str
    password: str | None = None
    # When True, connect() requests sslmode="require".
    ssl: bool = False
    # When set, connect() issues `SET cluster = <cluster>` on the new connection.
    cluster: str | None = None
    # When True, connect() switches the new connection to autocommit.
    autocommit: bool = False

    def connect(self) -> psycopg.Connection:
        """Open a psycopg connection configured from these parameters.

        The connection is closed again if applying the autocommit or cluster
        settings fails, so a failed call does not leak an open connection.

        :return: the open connection
        """
        conn = psycopg.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=self.database,
            sslmode="require" if self.ssl else None,
        )
        try:
            if self.autocommit:
                conn.autocommit = True
            if self.cluster:
                with conn.cursor() as cur:
                    # NOTE(review): the cluster name is interpolated into the
                    # statement unescaped; only pass trusted values.
                    cur.execute(f"SET cluster = {self.cluster}".encode())
        except BaseException:
            # The caller never receives `conn` on this path, so release it here.
            conn.close()
            raise
        return conn

    def to_conn_string(self) -> str:
        """Render these parameters as a `postgres://` URL.

        User, password and database are percent-encoded. The password is
        omitted when it is None or empty. `ssl`, `cluster` and `autocommit`
        are not represented in the result.
        """
        # safe="" so that "/" is escaped as well: a raw "/" in the userinfo
        # part would end the authority section of the URL prematurely.
        user = quote(self.user, safe="")
        return (
            f"postgres://{user}:{quote(self.password, safe='')}@{self.host}:{self.port}/{quote(self.database)}"
            if self.password
            else f"postgres://{user}@{self.host}:{self.port}/{quote(self.database)}"
        )
185
186
def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
    """Parse a `postgres://user[:password]@host[:port]/database` URL.

    User, password and database are percent-decoded. The port defaults to
    5432 when absent. `ssl` is True unless the last `sslmode` query parameter
    is "disable" (also the assumed value when the parameter is missing).
    `cluster` and `autocommit` keep their defaults.

    NOTE: the original rationale for hand-parsing was that pg8000 does not
    support connection strings natively.

    :raises AssertionError: if the URL has no user name or no host name
    """
    url = urlparse(conn_string)
    query_params = parse_qs(url.query)
    assert url.username, f"connection string has no user name: {conn_string!r}"
    assert url.hostname, f"connection string has no host name: {conn_string!r}"
    return PgConnInfo(
        user=unquote(url.username),
        password=unquote(url.password) if url.password else url.password,
        host=url.hostname,
        port=url.port or 5432,
        # Strip the leading "/" first, then decode, so an encoded slash in
        # the database name survives.
        database=unquote(url.path.lstrip("/")),
        ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
    )
MZ_ROOT = PosixPath('/var/lib/buildkite-agent/builds/buildkite-15f2293-i-0fd014cc4b97c0422-1/materialize/deploy')
def nonce(digits: int) -> str:
37def nonce(digits: int) -> str:
38    return "".join(random.choice("0123456789abcdef") for _ in range(digits))
def all_subclasses(cls: type[~T]) -> set[type[~T]]:
44def all_subclasses(cls: type[T]) -> set[type[T]]:
45    """Returns a recursive set of all subclasses of a class"""
46    sc = cls.__subclasses__()
47    return set(sc).union([subclass for c in sc for subclass in all_subclasses(c)])

Returns a recursive set of all subclasses of a class

NAUGHTY_STRINGS = None
def naughty_strings() -> list[str]:
53def naughty_strings() -> list[str]:
54    # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
55    # Under MIT license, Copyright (c) 2015-2020 Max Woolf
56    global NAUGHTY_STRINGS
57    if not NAUGHTY_STRINGS:
58        with open(MZ_ROOT / "misc" / "python" / "materialize" / "blns.json") as f:
59            NAUGHTY_STRINGS = json.load(f)
60    return NAUGHTY_STRINGS
class YesNoOnce(enum.Enum):
63class YesNoOnce(Enum):
64    YES = 1
65    NO = 2
66    ONCE = 3
YES = <YesNoOnce.YES: 1>
NO = <YesNoOnce.NO: 2>
ONCE = <YesNoOnce.ONCE: 3>
class PropagatingThread(threading.Thread):
69class PropagatingThread(Thread):
70    def run(self):
71        self.exc = None
72        try:
73            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
74        except BaseException as e:
75            self.exc = e
76
77    def join(self, timeout=None):
78        super().join(timeout)
79        if self.exc:
80            raise self.exc
81        if hasattr(self, "ret"):
82            return self.ret

A class that represents a thread of control.

This class can be safely subclassed in a limited fashion. There are two ways to specify the activity: by passing a callable object to the constructor, or by overriding the run() method in a subclass.

def run(self):
70    def run(self):
71        self.exc = None
72        try:
73            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
74        except BaseException as e:
75            self.exc = e

Method representing the thread's activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

def join(self, timeout=None):
77    def join(self, timeout=None):
78        super().join(timeout)
79        if self.exc:
80            raise self.exc
81        if hasattr(self, "ret"):
82            return self.ret

Wait until the thread terminates.

This blocks the calling thread until the thread whose join() method is called terminates -- either normally or through an unhandled exception or until the optional timeout occurs.

When the timeout argument is present and not None, it should be a floating point number specifying a timeout for the operation in seconds (or fractions thereof). As join() always returns None, you must call is_alive() after join() to decide whether a timeout happened -- if the thread is still alive, the join() call timed out.

When the timeout argument is not present or None, the operation will block until the thread terminates.

A thread can be join()ed many times.

join() raises a RuntimeError if an attempt is made to join the current thread as that would cause a deadlock. It is also an error to join() a thread before it has been started and attempts to do so raises the same exception.

def decompress_zst_to_directory(zst_file_path: str, destination_dir_path: str) -> list[str]:
 85def decompress_zst_to_directory(
 86    zst_file_path: str, destination_dir_path: str
 87) -> list[str]:
 88    """
 89    :return: file paths in destination dir
 90    """
 91    input_file = pathlib.Path(zst_file_path)
 92    output_paths = []
 93
 94    with open(input_file, "rb") as compressed:
 95        decompressor = zstandard.ZstdDecompressor()
 96        output_path = pathlib.Path(destination_dir_path) / input_file.stem
 97        output_paths.append(str(output_path))
 98        with open(output_path, "wb") as destination:
 99            decompressor.copy_stream(compressed, destination)
100
101    return output_paths
Returns

file paths in destination dir

def ensure_dir_exists(path_to_dir: str) -> None:
104def ensure_dir_exists(path_to_dir: str) -> None:
105    subprocess.run(
106        [
107            "mkdir",
108            "-p",
109            f"{path_to_dir}",
110        ],
111        check=True,
112    )
def sha256_of_file(path: str | pathlib.Path) -> str:
115def sha256_of_file(path: str | Path) -> str:
116    sha256 = hashlib.sha256()
117    with open(path, "rb") as f:
118        for block in iter(lambda: f.read(filecmp.BUFSIZE), b""):
119            sha256.update(block)
120    return sha256.hexdigest()
def sha256_of_utf8_string(value: str) -> str:
123def sha256_of_utf8_string(value: str) -> str:
124    return hashlib.sha256(bytes(value, encoding="utf-8")).hexdigest()
def stable_int_hash(*values: str) -> int:
127def stable_int_hash(*values: str) -> int:
128    if len(values) == 1:
129        return xxhash.xxh64(values[0], seed=0).intdigest()
130
131    return stable_int_hash(",".join([str(stable_int_hash(entry)) for entry in values]))
class HasName(typing.Protocol):
134class HasName(Protocol):
135    name: str

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto[T](Protocol):
    def meth(self) -> T:
        ...
HasName(*args, **kwargs)
1739def _no_init_or_replace_init(self, *args, **kwargs):
1740    cls = type(self)
1741
1742    if cls._is_protocol:
1743        raise TypeError('Protocols cannot be instantiated')
1744
1745    # Already using a custom `__init__`. No need to calculate correct
1746    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1747    if cls.__init__ is not _no_init_or_replace_init:
1748        return
1749
1750    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1751    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1752    # searches for a proper new `__init__` in the MRO. The new `__init__`
1753    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1754    # instantiation of the protocol subclass will thus use the new
1755    # `__init__` and no longer call `_no_init_or_replace_init`.
1756    for base in cls.__mro__:
1757        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1758        if init is not _no_init_or_replace_init:
1759            cls.__init__ = init
1760            break
1761    else:
1762        # should not happen
1763        cls.__init__ = object.__init__
1764
1765    cls.__init__(self, *args, **kwargs)
name: str
def selected_by_name(selected: list[str], objs: list[~U]) -> Iterator[~U]:
141def selected_by_name(selected: list[str], objs: list[U]) -> Iterator[U]:
142    for name in selected:
143        for obj in objs:
144            if obj.name == name:
145                yield obj
146                break
147        else:
148            raise ValueError(
149                f"Unknown object with name {name} in {[obj.name for obj in objs]}"
150            )
@dataclass
class PgConnInfo:
153@dataclass
154class PgConnInfo:
155    user: str
156    host: str
157    port: int
158    database: str
159    password: str | None = None
160    ssl: bool = False
161    cluster: str | None = None
162    autocommit: bool = False
163
164    def connect(self) -> psycopg.Connection:
165        conn = psycopg.connect(
166            host=self.host,
167            port=self.port,
168            user=self.user,
169            password=self.password,
170            dbname=self.database,
171            sslmode="require" if self.ssl else None,
172        )
173        if self.autocommit:
174            conn.autocommit = True
175        if self.cluster:
176            with conn.cursor() as cur:
177                cur.execute(f"SET cluster = {self.cluster}".encode())
178        return conn
179
180    def to_conn_string(self) -> str:
181        return (
182            f"postgres://{quote(self.user)}:{quote(self.password)}@{self.host}:{self.port}/{quote(self.database)}"
183            if self.password
184            else f"postgres://{quote(self.user)}@{self.host}:{self.port}/{quote(self.database)}"
185        )
PgConnInfo( user: str, host: str, port: int, database: str, password: str | None = None, ssl: bool = False, cluster: str | None = None, autocommit: bool = False)
user: str
host: str
port: int
database: str
password: str | None = None
ssl: bool = False
cluster: str | None = None
autocommit: bool = False
def connect(self) -> psycopg.Connection:
164    def connect(self) -> psycopg.Connection:
165        conn = psycopg.connect(
166            host=self.host,
167            port=self.port,
168            user=self.user,
169            password=self.password,
170            dbname=self.database,
171            sslmode="require" if self.ssl else None,
172        )
173        if self.autocommit:
174            conn.autocommit = True
175        if self.cluster:
176            with conn.cursor() as cur:
177                cur.execute(f"SET cluster = {self.cluster}".encode())
178        return conn
def to_conn_string(self) -> str:
180    def to_conn_string(self) -> str:
181        return (
182            f"postgres://{quote(self.user)}:{quote(self.password)}@{self.host}:{self.port}/{quote(self.database)}"
183            if self.password
184            else f"postgres://{quote(self.user)}@{self.host}:{self.port}/{quote(self.database)}"
185        )
def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
188def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
189    """Not supported natively by pg8000, so we have to parse ourselves"""
190    url = urlparse(conn_string)
191    query_params = parse_qs(url.query)
192    assert url.username
193    assert url.hostname
194    return PgConnInfo(
195        user=unquote(url.username),
196        password=unquote(url.password) if url.password else url.password,
197        host=url.hostname,
198        port=url.port or 5432,
199        database=url.path.lstrip("/"),
200        ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
201    )

Not supported natively by pg8000, so we have to parse ourselves