util.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. """Various utilities"""
  10. from __future__ import annotations
  11. import filecmp
  12. import hashlib
  13. import json
  14. import os
  15. import pathlib
  16. import random
  17. import subprocess
  18. from collections.abc import Iterator
  19. from dataclasses import dataclass
  20. from enum import Enum
  21. from pathlib import Path
  22. from threading import Thread
  23. from typing import Protocol, TypeVar
  24. from urllib.parse import parse_qs, quote, unquote, urlparse
  25. import psycopg
  26. import xxhash
  27. import zstandard
# Base path taken from the MZ_ROOT environment variable; other paths in this
# module (e.g. the naughty-strings JSON file) are resolved relative to it.
# Importing this module raises KeyError when MZ_ROOT is not set.
# NOTE(review): presumably the root of the repository checkout -- confirm.
MZ_ROOT = Path(os.environ["MZ_ROOT"])
  29. def nonce(digits: int) -> str:
  30. return "".join(random.choice("0123456789abcdef") for _ in range(digits))
  31. T = TypeVar("T")
  32. def all_subclasses(cls: type[T]) -> set[type[T]]:
  33. """Returns a recursive set of all subclasses of a class"""
  34. sc = cls.__subclasses__()
  35. return set(sc).union([subclass for c in sc for subclass in all_subclasses(c)])
  36. NAUGHTY_STRINGS = None
  37. def naughty_strings() -> list[str]:
  38. # Naughty strings taken from https://github.com/minimaxir/big-list-of-naughty-strings
  39. # Under MIT license, Copyright (c) 2015-2020 Max Woolf
  40. global NAUGHTY_STRINGS
  41. if not NAUGHTY_STRINGS:
  42. with open(MZ_ROOT / "misc" / "python" / "materialize" / "blns.json") as f:
  43. NAUGHTY_STRINGS = json.load(f)
  44. return NAUGHTY_STRINGS
class YesNoOnce(Enum):
    """Three-valued option with the members ``YES``, ``NO`` and ``ONCE``.

    NOTE(review): what each member means is decided by the code that consumes
    this enum, which is not visible in this module.
    """

    YES = 1
    NO = 2
    ONCE = 3
  49. class PropagatingThread(Thread):
  50. def run(self):
  51. self.exc = None
  52. try:
  53. self.ret = self._target(*self._args, **self._kwargs) # type: ignore
  54. except BaseException as e:
  55. self.exc = e
  56. def join(self, timeout=None):
  57. super().join(timeout)
  58. if self.exc:
  59. raise self.exc
  60. if hasattr(self, "ret"):
  61. return self.ret
  62. def decompress_zst_to_directory(
  63. zst_file_path: str, destination_dir_path: str
  64. ) -> list[str]:
  65. """
  66. :return: file paths in destination dir
  67. """
  68. input_file = pathlib.Path(zst_file_path)
  69. output_paths = []
  70. with open(input_file, "rb") as compressed:
  71. decompressor = zstandard.ZstdDecompressor()
  72. output_path = pathlib.Path(destination_dir_path) / input_file.stem
  73. output_paths.append(str(output_path))
  74. with open(output_path, "wb") as destination:
  75. decompressor.copy_stream(compressed, destination)
  76. return output_paths
  77. def ensure_dir_exists(path_to_dir: str) -> None:
  78. subprocess.run(
  79. [
  80. "mkdir",
  81. "-p",
  82. f"{path_to_dir}",
  83. ],
  84. check=True,
  85. )
  86. def sha256_of_file(path: str | Path) -> str:
  87. sha256 = hashlib.sha256()
  88. with open(path, "rb") as f:
  89. for block in iter(lambda: f.read(filecmp.BUFSIZE), b""):
  90. sha256.update(block)
  91. return sha256.hexdigest()
  92. def sha256_of_utf8_string(value: str) -> str:
  93. return hashlib.sha256(bytes(value, encoding="utf-8")).hexdigest()
  94. def stable_int_hash(*values: str) -> int:
  95. if len(values) == 1:
  96. return xxhash.xxh64(values[0], seed=0).intdigest()
  97. return stable_int_hash(",".join([str(stable_int_hash(entry)) for entry in values]))
  98. class HasName(Protocol):
  99. name: str
  100. U = TypeVar("U", bound=HasName)
  101. def selected_by_name(selected: list[str], objs: list[U]) -> Iterator[U]:
  102. for name in selected:
  103. for obj in objs:
  104. if obj.name == name:
  105. yield obj
  106. break
  107. else:
  108. raise ValueError(
  109. f"Unknown object with name {name} in {[obj.name for obj in objs]}"
  110. )
  111. @dataclass
  112. class PgConnInfo:
  113. user: str
  114. host: str
  115. port: int
  116. database: str
  117. password: str | None = None
  118. ssl: bool = False
  119. cluster: str | None = None
  120. autocommit: bool = False
  121. def connect(self) -> psycopg.Connection:
  122. conn = psycopg.connect(
  123. host=self.host,
  124. port=self.port,
  125. user=self.user,
  126. password=self.password,
  127. dbname=self.database,
  128. sslmode="require" if self.ssl else None,
  129. )
  130. if self.autocommit:
  131. conn.autocommit = True
  132. if self.cluster:
  133. with conn.cursor() as cur:
  134. cur.execute(f"SET cluster = {self.cluster}".encode())
  135. return conn
  136. def to_conn_string(self) -> str:
  137. return (
  138. f"postgres://{quote(self.user)}:{quote(self.password)}@{self.host}:{self.port}/{quote(self.database)}"
  139. if self.password
  140. else f"postgres://{quote(self.user)}@{self.host}:{self.port}/{quote(self.database)}"
  141. )
  142. def parse_pg_conn_string(conn_string: str) -> PgConnInfo:
  143. """Not supported natively by pg8000, so we have to parse ourselves"""
  144. url = urlparse(conn_string)
  145. query_params = parse_qs(url.query)
  146. assert url.username
  147. assert url.hostname
  148. return PgConnInfo(
  149. user=unquote(url.username),
  150. password=unquote(url.password) if url.password else url.password,
  151. host=url.hostname,
  152. port=url.port or 5432,
  153. database=url.path.lstrip("/"),
  154. ssl=query_params.get("sslmode", ["disable"])[-1] != "disable",
  155. )