misc.python.materialize.util
Various utilities
# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.

"""Various utilities"""

from __future__ import annotations

import hashlib
import json
import os
import pathlib
import random
import subprocess
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from threading import Thread
from typing import Protocol, TypeVar
from urllib.parse import parse_qs, quote, unquote, urlparse

import psycopg
import xxhash
import zstandard

MZ_ROOT = Path(os.environ["MZ_ROOT"])


def nonce(digits: int) -> str:
    """Return a random string of `digits` lowercase hexadecimal characters.

    Built on the `random` module, so it is not suitable as a secret.
    """
    return "".join(random.choice("0123456789abcdef") for _ in range(digits))


T = TypeVar("T")


def all_subclasses(cls: type[T]) -> set[type[T]]:
    """Returns a recursive set of all subclasses of a class"""
    sc = cls.__subclasses__()
    return set(sc).union([subclass for c in sc for subclass in all_subclasses(c)])


NAUGHTY_STRINGS = None


def naughty_strings() -> list[str]:
    """Return the contents of blns.json, loading the file on first use.

    The parsed list is cached in the module-level NAUGHTY_STRINGS.
    """
    # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
    # Under MIT license, Copyright (c) 2015-2020 Max Woolf
    global NAUGHTY_STRINGS
    if NAUGHTY_STRINGS is None:
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(
            MZ_ROOT / "misc" / "python" / "materialize" / "blns.json",
            encoding="utf-8",
        ) as f:
            NAUGHTY_STRINGS = json.load(f)
    return NAUGHTY_STRINGS


class YesNoOnce(Enum):
    YES = 1
    NO = 2
    ONCE = 3


class PropagatingThread(Thread):
    """Thread whose join() re-raises the target's exception or returns its result."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created up front so join() never hits a missing attribute, e.g. when
        # a timed join() returns before run() has started executing.
        self.exc: BaseException | None = None
        self.ret = None

    def run(self):
        """Call the target, recording its return value or the exception it raised."""
        self.exc = None
        try:
            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
        except BaseException as e:
            self.exc = e

    def join(self, timeout=None):
        """Wait for the thread, then raise the recorded exception or return the result.

        Returns None if no result has been recorded (for example when the
        timeout elapsed before the target finished).
        """
        super().join(timeout)
        if self.exc is not None:
            raise self.exc
        return self.ret


def decompress_zst_to_directory(
    zst_file_path: str, destination_dir_path: str
) -> list[str]:
    """
    Decompress one zstd file into the destination directory. The output file is
    named after the input file with its last suffix removed.

    :return: file paths in destination dir
    """
    input_file = pathlib.Path(zst_file_path)
    output_paths = []

    with open(input_file, "rb") as compressed:
        decompressor = zstandard.ZstdDecompressor()
        output_path = pathlib.Path(destination_dir_path) / input_file.stem
        output_paths.append(str(output_path))
        with open(output_path, "wb") as destination:
            decompressor.copy_stream(compressed, destination)

    return output_paths


def ensure_dir_exists(path_to_dir: str) -> None:
    """Create the directory (and missing parents) by running `mkdir -p`.

    Raises subprocess.CalledProcessError if the command exits non-zero.
    """
    subprocess.run(
        [
            "mkdir",
            "-p",
            f"{path_to_dir}",
        ],
        check=True,
    )


def sha256_of_utf8_string(value: str) -> str:
    """Return the hex SHA-256 digest of `value` encoded as UTF-8."""
    return hashlib.sha256(bytes(value, encoding="utf-8")).hexdigest()


def stable_int_hash(*values: str) -> int:
    """Return an xxh64 (seed 0) integer hash of the given strings.

    A single value is hashed directly. Otherwise each value is hashed on its
    own, the decimal digests are joined with commas, and that string is hashed.
    """
    if len(values) == 1:
        return xxhash.xxh64(values[0], seed=0).intdigest()

    return stable_int_hash(",".join([str(stable_int_hash(entry)) for entry in values]))


class HasName(Protocol):
    name: str


U = TypeVar("U", bound=HasName)


def selected_by_name(selected: list[str], objs: list[U]) -> Iterator[U]:
    """Yield, for each name in `selected`, the first object in `objs` with that name.

    Raises ValueError (while iterating) when a name matches no object.
    """
    for name in selected:
        for obj in objs:
            if obj.name == name:
                yield obj
                break
        else:
            raise ValueError(
                f"Unknown object with name {name} in {[obj.name for obj in objs]}"
            )


@dataclass
class PgConnInfo:
    """Parameters for opening a psycopg connection."""

    user: str
    host: str
    port: int
    database: str
    password: str | None = None
    # When true, connect() passes sslmode="require".
    ssl: bool = False
    # When set, connect() runs `SET cluster = <cluster>` on the new connection.
    cluster: str | None = None
    # When true, connect() enables autocommit on the new connection.
    autocommit: bool = False

    def connect(self) -> psycopg.Connection:
        """Open a connection and apply the linger, autocommit and cluster settings.

        If any of the post-connect steps fails, the connection is closed
        before the exception propagates.
        """
        conn = psycopg.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=self.database,
            sslmode="require" if self.ssl else None,
        )
        try:
            # Set SO_LINGER(1, 0) so close() sends RST instead of FIN, bypassing
            # TIME_WAIT. Prevents exhausting the ~28k ephemeral port range under
            # high connection churn (e.g. benchmarks doing rapid connect/disconnect).
            self._set_linger(conn)
            if self.autocommit:
                conn.autocommit = True
            if self.cluster:
                with conn.cursor() as cur:
                    cur.execute(f"SET cluster = {self.cluster}".encode())
        except BaseException:
            # Do not leak the half-configured connection.
            conn.close()
            raise
        return conn

    @staticmethod
    def _set_linger(conn: psycopg.Connection) -> None:
        """Set SO_LINGER (on, 0 seconds) on the connection's socket, if it has one."""
        import socket
        import struct

        fd = conn.pgconn.socket
        if fd < 0:
            return
        sock = socket.socket(fileno=fd)
        try:
            sock.setsockopt(
                socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
            )
        finally:
            # Release the descriptor without closing it; psycopg still owns it.
            sock.detach()

    def to_conn_string(self) -> str:
        """Render a `postgres://` URL from user, optional password, host, port and database.

        User, password and database are fully percent-encoded (including "/"),
        so reserved characters cannot break the URL structure.
        """
        user = quote(self.user, safe="")
        database = quote(self.database, safe="")
        if self.password:
            credentials = f"{user}:{quote(self.password, safe='')}"
        else:
            credentials = user
        return f"postgres://{credentials}@{self.host}:{self.port}/{database}"


def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
    """Parse a `postgres://` URL into a PgConnInfo.

    Not supported natively by pg8000, so we have to parse ourselves.
    The port defaults to 5432; ssl is enabled unless the last `sslmode`
    query parameter is absent or "disable".
    """
    url = urlparse(conn_string)
    query_params = parse_qs(url.query)
    assert url.username, "connection string must contain a user name"
    assert url.hostname, "connection string must contain a host"
    return PgConnInfo(
        user=unquote(url.username),
        password=unquote(url.password) if url.password else url.password,
        host=url.hostname,
        port=url.port or 5432,
        # Percent-decode so the value round-trips with to_conn_string().
        database=unquote(url.path.lstrip("/")),
        ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
    )


FILTERED_ARGS = [
    # Secrets
    "mzp_",
    "-----BEGIN PRIVATE KEY-----",
    "-----BEGIN CERTIFICATE-----",
    "confluent-api-key=",
    "confluent-api-secret=",
    "aws-access-key-id=",
    "aws-secret-access-key=",
    "default-sql-server-password=",
    # Not a secret, but too spammy, filter too
    "CLUSTER_REPLICA_SIZES",
    "cluster-replica-sizes=",
]
def filter_cmd(args: list[str]) -> list[str]:
    """Don't print out secrets in test logs

    Returns a copy of `args` in which every argument containing one of the
    FILTERED_ARGS substrings is replaced by "[REDACTED]".
    """
    redacted: list[str] = []
    for arg in args:
        is_sensitive = any(marker in arg for marker in FILTERED_ARGS)
        redacted.append("[REDACTED]" if is_sensitive else arg)
    return redacted
def all_subclasses(cls: type[T]) -> set[type[T]]:
    """Returns a recursive set of all subclasses of a class"""
    found: set[type[T]] = set()
    # Explicit worklist instead of recursion: direct subclasses first, then theirs.
    pending = list(cls.__subclasses__())
    while pending:
        current = pending.pop()
        if current in found:
            # Reached again through another base (diamond); already expanded.
            continue
        found.add(current)
        pending.extend(current.__subclasses__())
    return found
Returns a recursive set of all subclasses of a class
def naughty_strings() -> list[str]:
    """Return the contents of blns.json, loading the file on first use.

    The parsed list is cached in the module-level NAUGHTY_STRINGS, so the file
    is read at most once per process.
    """
    # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
    # Under MIT license, Copyright (c) 2015-2020 Max Woolf
    global NAUGHTY_STRINGS
    if NAUGHTY_STRINGS is None:
        blns_path = MZ_ROOT / "misc" / "python" / "materialize" / "blns.json"
        # JSON text is UTF-8; do not depend on the locale's default encoding.
        with open(blns_path, encoding="utf-8") as f:
            NAUGHTY_STRINGS = json.load(f)
    return NAUGHTY_STRINGS
class PropagatingThread(Thread):
    """Thread whose join() re-raises the target's exception or returns its result."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created up front so join() never hits a missing attribute, e.g. when
        # a timed join() returns before run() has started executing.
        self.exc: BaseException | None = None
        self.ret = None

    def run(self):
        """Call the target, recording its return value or the exception it raised."""
        self.exc = None
        try:
            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
        except BaseException as e:
            self.exc = e

    def join(self, timeout=None):
        """Wait for the thread, then raise the recorded exception or return the result.

        Returns None if no result has been recorded (for example when the
        timeout elapsed before the target finished).
        """
        super().join(timeout)
        if self.exc is not None:
            raise self.exc
        return self.ret
A class that represents a thread of control.
This class can be safely subclassed in a limited fashion. There are two ways to specify the activity: by passing a callable object to the constructor, or by overriding the run() method in a subclass.
def run(self):
    """Invoke the thread's target and remember the outcome.

    On success the return value is stored in `self.ret`; any exception
    (including BaseException subclasses) is stored in `self.exc` instead
    of propagating.
    """
    self.exc = None
    try:
        outcome = self._target(*self._args, **self._kwargs)  # type: ignore
    except BaseException as err:
        self.exc = err
    else:
        self.ret = outcome
Method representing the thread's activity.
You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.
76 def join(self, timeout=None): 77 super().join(timeout) 78 if self.exc: 79 raise self.exc 80 if hasattr(self, "ret"): 81 return self.ret
Wait until the thread terminates.
This blocks the calling thread until the thread whose join() method is called terminates -- either normally or through an unhandled exception or until the optional timeout occurs.
When the timeout argument is present and not None, it should be a floating point number specifying a timeout for the operation in seconds (or fractions thereof). As join() always returns None, you must call is_alive() after join() to decide whether a timeout happened -- if the thread is still alive, the join() call timed out.
When the timeout argument is not present or None, the operation will block until the thread terminates.
A thread can be join()ed many times.
join() raises a RuntimeError if an attempt is made to join the current thread as that would cause a deadlock. It is also an error to join() a thread before it has been started and attempts to do so raises the same exception.
def decompress_zst_to_directory(
    zst_file_path: str, destination_dir_path: str
) -> list[str]:
    """
    Decompress one zstd file into the destination directory. The output file is
    named after the input file with its last suffix removed.

    :return: file paths in destination dir
    """
    source = pathlib.Path(zst_file_path)
    target = pathlib.Path(destination_dir_path) / source.stem

    # Input is opened first, so a missing source never creates an output file.
    with open(source, "rb") as compressed, open(target, "wb") as destination:
        zstandard.ZstdDecompressor().copy_stream(compressed, destination)

    return [str(target)]
Returns
file paths in destination dir
Base class for protocol classes.
Protocol classes are defined as::
    class Proto(Protocol):
        def meth(self) -> int:
            ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
    class GenProto[T](Protocol):
        def meth(self) -> T:
            ...
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder ``__init__`` that rejects protocols and otherwise installs the real one.

    Raises TypeError when the instance's class has ``_is_protocol`` set.
    Otherwise, if the class still has this function as its ``__init__``, the
    first ``__init__`` found in the MRO that is not this function is assigned
    to the class and called with the given arguments.
    """
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        # Look only at each class's own namespace, not inherited attributes.
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        cls.__init__ = object.__init__

    cls.__init__(self, *args, **kwargs)
@dataclass
class PgConnInfo:
    """Parameters for opening a psycopg connection."""

    user: str
    host: str
    port: int
    database: str
    password: str | None = None
    # When true, connect() passes sslmode="require".
    ssl: bool = False
    # When set, connect() runs `SET cluster = <cluster>` on the new connection.
    cluster: str | None = None
    # When true, connect() enables autocommit on the new connection.
    autocommit: bool = False

    def connect(self) -> psycopg.Connection:
        """Open a connection and apply the linger, autocommit and cluster settings.

        If any of the post-connect steps fails, the connection is closed
        before the exception propagates.
        """
        conn = psycopg.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=self.database,
            sslmode="require" if self.ssl else None,
        )
        try:
            # Set SO_LINGER(1, 0) so close() sends RST instead of FIN, bypassing
            # TIME_WAIT. Prevents exhausting the ~28k ephemeral port range under
            # high connection churn (e.g. benchmarks doing rapid connect/disconnect).
            self._set_linger(conn)
            if self.autocommit:
                conn.autocommit = True
            if self.cluster:
                with conn.cursor() as cur:
                    cur.execute(f"SET cluster = {self.cluster}".encode())
        except BaseException:
            # Do not leak the half-configured connection.
            conn.close()
            raise
        return conn

    @staticmethod
    def _set_linger(conn: psycopg.Connection) -> None:
        """Set SO_LINGER (on, 0 seconds) on the connection's socket, if it has one."""
        import socket
        import struct

        fd = conn.pgconn.socket
        if fd < 0:
            return
        sock = socket.socket(fileno=fd)
        try:
            sock.setsockopt(
                socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0)
            )
        finally:
            # Release the descriptor without closing it; psycopg still owns it.
            sock.detach()

    def to_conn_string(self) -> str:
        """Render a `postgres://` URL from user, optional password, host, port and database.

        User, password and database are fully percent-encoded (including "/"),
        so reserved characters cannot break the URL structure. The password
        part is omitted when no (or an empty) password is set.
        """
        user = quote(self.user, safe="")
        database = quote(self.database, safe="")
        if self.password:
            credentials = f"{user}:{quote(self.password, safe='')}"
        else:
            credentials = user
        return f"postgres://{credentials}@{self.host}:{self.port}/{database}"
def connect(self) -> psycopg.Connection:
    """Open a connection and apply the linger, autocommit and cluster settings.

    If any of the post-connect steps fails, the connection is closed
    before the exception propagates.
    """
    conn = psycopg.connect(
        host=self.host,
        port=self.port,
        user=self.user,
        password=self.password,
        dbname=self.database,
        sslmode="require" if self.ssl else None,
    )
    try:
        # Set SO_LINGER(1, 0) so close() sends RST instead of FIN, bypassing
        # TIME_WAIT. Prevents exhausting the ~28k ephemeral port range under
        # high connection churn (e.g. benchmarks doing rapid connect/disconnect).
        self._set_linger(conn)
        if self.autocommit:
            conn.autocommit = True
        if self.cluster:
            with conn.cursor() as cur:
                cur.execute(f"SET cluster = {self.cluster}".encode())
    except BaseException:
        # Do not leak the half-configured connection.
        conn.close()
        raise
    return conn
def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
    """Parse a `postgres://` URL into a PgConnInfo.

    Not supported natively by pg8000, so we have to parse ourselves.
    User, password and database are percent-decoded. The port defaults to
    5432; ssl is enabled unless the last `sslmode` query parameter is absent
    or "disable".
    """
    url = urlparse(conn_string)
    query_params = parse_qs(url.query)
    assert url.username, "connection string must contain a user name"
    assert url.hostname, "connection string must contain a host"
    return PgConnInfo(
        user=unquote(url.username),
        password=unquote(url.password) if url.password else url.password,
        host=url.hostname,
        port=url.port or 5432,
        # Percent-decode like user and password, so an encoded database
        # name (e.g. "my%20db") is returned in its real form.
        database=unquote(url.path.lstrip("/")),
        ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
    )
Parse a postgres:// connection URL into a PgConnInfo. URL parsing is not supported natively by pg8000, so we have to parse it ourselves.
def filter_cmd(args: list[str]) -> list[str]:
    """Don't print out secrets in test logs

    Each argument that contains any FILTERED_ARGS substring is swapped for
    "[REDACTED]"; all other arguments are passed through unchanged.
    """

    def _scrub(arg: str) -> str:
        for marker in FILTERED_ARGS:
            if marker in arg:
                return "[REDACTED]"
        return arg

    return [_scrub(arg) for arg in args]
Don't print out secrets in test logs