misc.python.materialize.util
Various utilities
1# Copyright Materialize, Inc. and contributors. All rights reserved. 2# 3# Use of this software is governed by the Business Source License 4# included in the LICENSE file at the root of this repository. 5# 6# As of the Change Date specified in that file, in accordance with 7# the Business Source License, use of this software will be governed 8# by the Apache License, Version 2.0. 9 10"""Various utilities""" 11 12from __future__ import annotations 13 14import filecmp 15import hashlib 16import json 17import os 18import pathlib 19import random 20import subprocess 21from collections.abc import Iterator 22from dataclasses import dataclass 23from enum import Enum 24from pathlib import Path 25from threading import Thread 26from typing import Protocol, TypeVar 27from urllib.parse import parse_qs, quote, unquote, urlparse 28 29import psycopg 30import xxhash 31import zstandard 32 33MZ_ROOT = Path(os.environ["MZ_ROOT"]) 34 35 36def nonce(digits: int) -> str: 37 return "".join(random.choice("0123456789abcdef") for _ in range(digits)) 38 39 40T = TypeVar("T") 41 42 43def all_subclasses(cls: type[T]) -> set[type[T]]: 44 """Returns a recursive set of all subclasses of a class""" 45 sc = cls.__subclasses__() 46 return set(sc).union([subclass for c in sc for subclass in all_subclasses(c)]) 47 48 49NAUGHTY_STRINGS = None 50 51 52def naughty_strings() -> list[str]: 53 # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings 54 # Under MIT license, Copyright (c) 2015-2020 Max Woolf 55 global NAUGHTY_STRINGS 56 if not NAUGHTY_STRINGS: 57 with open(MZ_ROOT / "misc" / "python" / "materialize" / "blns.json") as f: 58 NAUGHTY_STRINGS = json.load(f) 59 return NAUGHTY_STRINGS 60 61 62class YesNoOnce(Enum): 63 YES = 1 64 NO = 2 65 ONCE = 3 66 67 68class PropagatingThread(Thread): 69 def run(self): 70 self.exc = None 71 try: 72 self.ret = self._target(*self._args, **self._kwargs) # type: ignore 73 except BaseException as e: 74 self.exc = e 75 76 def join(self, 
timeout=None): 77 super().join(timeout) 78 if self.exc: 79 raise self.exc 80 if hasattr(self, "ret"): 81 return self.ret 82 83 84def decompress_zst_to_directory( 85 zst_file_path: str, destination_dir_path: str 86) -> list[str]: 87 """ 88 :return: file paths in destination dir 89 """ 90 input_file = pathlib.Path(zst_file_path) 91 output_paths = [] 92 93 with open(input_file, "rb") as compressed: 94 decompressor = zstandard.ZstdDecompressor() 95 output_path = pathlib.Path(destination_dir_path) / input_file.stem 96 output_paths.append(str(output_path)) 97 with open(output_path, "wb") as destination: 98 decompressor.copy_stream(compressed, destination) 99 100 return output_paths 101 102 103def ensure_dir_exists(path_to_dir: str) -> None: 104 subprocess.run( 105 [ 106 "mkdir", 107 "-p", 108 f"{path_to_dir}", 109 ], 110 check=True, 111 ) 112 113 114def sha256_of_file(path: str | Path) -> str: 115 sha256 = hashlib.sha256() 116 with open(path, "rb") as f: 117 for block in iter(lambda: f.read(filecmp.BUFSIZE), b""): 118 sha256.update(block) 119 return sha256.hexdigest() 120 121 122def sha256_of_utf8_string(value: str) -> str: 123 return hashlib.sha256(bytes(value, encoding="utf-8")).hexdigest() 124 125 126def stable_int_hash(*values: str) -> int: 127 if len(values) == 1: 128 return xxhash.xxh64(values[0], seed=0).intdigest() 129 130 return stable_int_hash(",".join([str(stable_int_hash(entry)) for entry in values])) 131 132 133class HasName(Protocol): 134 name: str 135 136 137U = TypeVar("U", bound=HasName) 138 139 140def selected_by_name(selected: list[str], objs: list[U]) -> Iterator[U]: 141 for name in selected: 142 for obj in objs: 143 if obj.name == name: 144 yield obj 145 break 146 else: 147 raise ValueError( 148 f"Unknown object with name {name} in {[obj.name for obj in objs]}" 149 ) 150 151 152@dataclass 153class PgConnInfo: 154 user: str 155 host: str 156 port: int 157 database: str 158 password: str | None = None 159 ssl: bool = False 160 cluster: str | None = 
None 161 autocommit: bool = False 162 163 def connect(self) -> psycopg.Connection: 164 conn = psycopg.connect( 165 host=self.host, 166 port=self.port, 167 user=self.user, 168 password=self.password, 169 dbname=self.database, 170 sslmode="require" if self.ssl else None, 171 ) 172 if self.autocommit: 173 conn.autocommit = True 174 if self.cluster: 175 with conn.cursor() as cur: 176 cur.execute(f"SET cluster = {self.cluster}".encode()) 177 return conn 178 179 def to_conn_string(self) -> str: 180 return ( 181 f"postgres://{quote(self.user)}:{quote(self.password)}@{self.host}:{self.port}/{quote(self.database)}" 182 if self.password 183 else f"postgres://{quote(self.user)}@{self.host}:{self.port}/{quote(self.database)}" 184 ) 185 186 187def parse_pg_conn_string(conn_string: str) -> PgConnInfo: 188 """Not supported natively by pg8000, so we have to parse ourselves""" 189 url = urlparse(conn_string) 190 query_params = parse_qs(url.query) 191 assert url.username 192 assert url.hostname 193 return PgConnInfo( 194 user=unquote(url.username), 195 password=unquote(url.password) if url.password else url.password, 196 host=url.hostname, 197 port=url.port or 5432, 198 database=url.path.lstrip("/"), 199 ssl=query_params.get("sslmode", ["disable"])[-1] != "disable", 200 )
def all_subclasses(cls: type[T]) -> set[type[T]]:
    """Collect every direct and indirect subclass of `cls` into a set."""
    found = set()
    for child in cls.__subclasses__():
        found.add(child)
        # Descend so that grandchildren (and deeper) are included.
        found.update(all_subclasses(child))
    return found
Returns a recursive set of all subclasses of a class
def naughty_strings() -> list[str]:
    """Return the list of naughty strings, loading blns.json on first use.

    The parsed list is cached in the module-level NAUGHTY_STRINGS.
    """
    # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
    # Under MIT license, Copyright (c) 2015-2020 Max Woolf
    global NAUGHTY_STRINGS
    # `is None` (not truthiness) so that an empty list is cached too.
    if NAUGHTY_STRINGS is None:
        # JSON is UTF-8; do not depend on the locale's default encoding.
        with open(
            MZ_ROOT / "misc" / "python" / "materialize" / "blns.json",
            encoding="utf-8",
        ) as f:
            NAUGHTY_STRINGS = json.load(f)
    return NAUGHTY_STRINGS
class PropagatingThread(Thread):
    """Thread whose join() re-raises the target's exception or returns its result."""

    # Class-level default so join() can read `exc` even when run() has not
    # assigned it yet (e.g. join(timeout) returning before run() started).
    exc = None

    def run(self):
        """Call the target, storing its return value in `ret` or its exception in `exc`."""
        self.exc = None
        try:
            self.ret = self._target(*self._args, **self._kwargs)  # type: ignore
        except BaseException as e:
            self.exc = e

    def join(self, timeout=None):
        """Join the thread, then raise the captured exception or return the result.

        Returns None when no result has been stored (for example after a
        timeout while the target is still running).
        """
        super().join(timeout)
        if self.exc:
            raise self.exc
        if hasattr(self, "ret"):
            return self.ret
A class that represents a thread of control.
This class can be safely subclassed in a limited fashion. There are two ways to specify the activity: by passing a callable object to the constructor, or by overriding the run() method in a subclass.
def run(self):
    """Invoke the thread's target and record how it finished.

    On success the return value is stored in `self.ret`; any exception
    (including BaseException subclasses) is stored in `self.exc` instead of
    propagating.
    """
    self.exc = None
    try:
        outcome = self._target(*self._args, **self._kwargs)  # type: ignore
    except BaseException as err:
        self.exc = err
    else:
        self.ret = outcome
Method representing the thread's activity.
You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.
def join(self, timeout=None):
    """Join the thread, then raise the captured exception or return the result.

    Returns None when no result has been stored (for example after a
    timeout while the target is still running).
    """
    super().join(timeout)
    # `exc` and `ret` are only assigned once run() executes; read them
    # defensively so a join(timeout) that returns early cannot raise
    # AttributeError.
    error = getattr(self, "exc", None)
    if error:
        raise error
    return getattr(self, "ret", None)
Wait until the thread terminates.
This blocks the calling thread until the thread whose join() method is called terminates -- either normally or through an unhandled exception or until the optional timeout occurs.
When the timeout argument is present and not None, it should be a floating point number specifying a timeout for the operation in seconds (or fractions thereof). As join() always returns None, you must call is_alive() after join() to decide whether a timeout happened -- if the thread is still alive, the join() call timed out.
When the timeout argument is not present or None, the operation will block until the thread terminates.
A thread can be join()ed many times.
join() raises a RuntimeError if an attempt is made to join the current thread as that would cause a deadlock. It is also an error to join() a thread before it has been started and attempts to do so raises the same exception.
def decompress_zst_to_directory(
    zst_file_path: str, destination_dir_path: str
) -> list[str]:
    """
    Decompress a zstd file into the destination directory. The output file is
    named after the input file's stem (its name without the last suffix).

    :return: file paths in destination dir
    """
    source = pathlib.Path(zst_file_path)
    target = pathlib.Path(destination_dir_path) / source.stem

    with source.open("rb") as compressed:
        unpacker = zstandard.ZstdDecompressor()
        with target.open("wb") as destination:
            unpacker.copy_stream(compressed, destination)

    return [str(target)]
Returns
file paths in destination dir
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto[T](Protocol):
def meth(self) -> T:
...
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder `__init__` that rejects protocol classes or installs the real `__init__`.

    Raises TypeError when the instance's class is marked `_is_protocol`.
    Otherwise it finds the first `__init__` in the MRO that is not this
    placeholder, stores it on the class, and calls it with the given arguments.
    """
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        cls.__init__ = object.__init__

    cls.__init__(self, *args, **kwargs)
@dataclass
class PgConnInfo:
    """Connection parameters for a PostgreSQL-protocol server."""

    user: str
    host: str
    port: int
    database: str
    password: str | None = None
    ssl: bool = False
    cluster: str | None = None
    autocommit: bool = False

    def connect(self) -> psycopg.Connection:
        """Open a psycopg connection and apply the autocommit/cluster settings.

        If applying the settings fails, the connection is closed before the
        exception propagates.
        """
        conn = psycopg.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=self.database,
            sslmode="require" if self.ssl else None,
        )
        try:
            if self.autocommit:
                conn.autocommit = True
            if self.cluster:
                with conn.cursor() as cur:
                    cur.execute(f"SET cluster = {self.cluster}".encode())
        except BaseException:
            # Do not leak the open connection when setup fails.
            conn.close()
            raise
        return conn

    def to_conn_string(self) -> str:
        """Render the user, password, host, port and database as a postgres:// URL.

        The ssl, cluster and autocommit fields are not part of the result.
        """
        # safe="" so that "/" in the user or password is escaped as well;
        # left raw it would end the authority part of the URL.
        user = quote(self.user, safe="")
        if self.password:
            credentials = f"{user}:{quote(self.password, safe='')}"
        else:
            credentials = user
        return f"postgres://{credentials}@{self.host}:{self.port}/{quote(self.database)}"
def connect(self) -> psycopg.Connection:
    """Open a psycopg connection and apply the autocommit/cluster settings.

    If applying the settings fails, the connection is closed before the
    exception propagates.
    """
    conn = psycopg.connect(
        host=self.host,
        port=self.port,
        user=self.user,
        password=self.password,
        dbname=self.database,
        sslmode="require" if self.ssl else None,
    )
    try:
        if self.autocommit:
            conn.autocommit = True
        if self.cluster:
            with conn.cursor() as cur:
                cur.execute(f"SET cluster = {self.cluster}".encode())
    except BaseException:
        # Do not leak the open connection when setup fails.
        conn.close()
        raise
    return conn
def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
    """Parse a postgres:// connection URL into a PgConnInfo.

    The port defaults to 5432. `ssl` is True when the last `sslmode` query
    parameter is anything other than "disable". `cluster` and `autocommit`
    keep their defaults. Asserts that the URL has a username and a hostname.
    """
    url = urlparse(conn_string)
    query_params = parse_qs(url.query)
    assert url.username
    assert url.hostname
    return PgConnInfo(
        user=unquote(url.username),
        password=unquote(url.password) if url.password else url.password,
        host=url.hostname,
        port=url.port or 5432,
        # Connection strings percent-encode the database name, so decode it
        # the same way as the user and password.
        database=unquote(url.path.lstrip("/")),
        ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
    )
Parses a postgres:// connection string into a PgConnInfo. This is not supported natively by pg8000, so we have to parse it ourselves.