sql.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. import contextlib
  10. import functools
  11. import logging
  12. import ssl
  13. import subprocess
  14. from collections.abc import Generator
  15. from pathlib import Path
  16. from typing import Any, cast
  17. import pg8000
  18. import sqlparse
  19. from pg8000.exceptions import InterfaceError
  20. from pg8000.native import literal
  21. from materialize.mzexplore.common import resource_path
# Generator type used by the query helpers in this module: it yields one
# dict per result row and neither accepts sent values nor returns a value.
DictGenerator = Generator[dict[Any, Any], None, None]
  23. class Database:
  24. """An API to the database under exploration."""
  25. def __init__(
  26. self,
  27. port: int,
  28. host: str,
  29. user: str,
  30. password: str | None,
  31. database: str | None,
  32. require_ssl: bool,
  33. ) -> None:
  34. logging.debug(f"Initialize Database with host={host} port={port}, user={user}")
  35. if require_ssl:
  36. # verify_mode=ssl.CERT_REQUIRED is the default
  37. ssl_context = ssl.create_default_context()
  38. else:
  39. ssl_context = None
  40. self.conn = pg8000.connect(
  41. host=host,
  42. port=port,
  43. user=user,
  44. password=password,
  45. database=database,
  46. ssl_context=ssl_context,
  47. )
  48. self.conn.autocommit = True
  49. def close(self) -> None:
  50. self.conn.close()
  51. def query_one(self, query: str) -> dict[Any, Any]:
  52. with self.conn.cursor() as cursor:
  53. cursor.execute(query)
  54. cols = [d[0].lower() for d in cursor.description]
  55. row = {key: val for key, val in zip(cols, cursor.fetchone())}
  56. return cast(dict[Any, Any], row)
  57. def query_all(self, query: str) -> DictGenerator:
  58. with self.conn.cursor() as cursor:
  59. cursor.execute(query)
  60. cols = [d[0].lower() for d in cursor.description]
  61. for row in cursor.fetchall():
  62. yield {key: val for key, val in zip(cols, row)}
  63. def execute(self, statement: str) -> None:
  64. with self.conn.cursor() as cursor:
  65. cursor.execute(statement)
  66. def catalog_items(
  67. self,
  68. database: str | None = None,
  69. schema: str | None = None,
  70. name: str | None = None,
  71. system: bool = False,
  72. ) -> DictGenerator:
  73. p = resource_path("catalog/s_items.sql" if system else "catalog/u_items.sql")
  74. q = parse_query(p)
  75. yield from self.query_all(
  76. q.format(
  77. database="'%'" if database is None else literal(database),
  78. schema="'%'" if schema is None else literal(schema),
  79. name="'%'" if name is None else literal(name),
  80. )
  81. )
  82. def object_clusters(
  83. self,
  84. object_ids: list[str],
  85. ) -> DictGenerator:
  86. p = resource_path("catalog/u_object_clusters.sql")
  87. q = parse_query(p)
  88. yield from self.query_all(
  89. q.format(object_ids=", ".join(map(literal, object_ids)))
  90. )
  91. def clone_dependencies(
  92. self,
  93. source_ids: list[str],
  94. cluster_id: str,
  95. ) -> DictGenerator:
  96. p = resource_path("catalog/u_clone_dependencies.sql")
  97. q = parse_query(p)
  98. yield from self.query_all(
  99. q.format(
  100. source_ids=", ".join(map(literal, source_ids)),
  101. cluster_id=literal(cluster_id),
  102. )
  103. )
  104. def arrangement_sizes(self, id: str) -> DictGenerator:
  105. p = resource_path("catalog/u_arrangement_sizes.sql")
  106. q = parse_query(p)
  107. yield from self.query_all(q.format(id=literal(id)))
  108. @contextlib.contextmanager
  109. def update_environment(
  110. db: Database, env: dict[str, str]
  111. ) -> Generator[Database, None, None]:
  112. original = dict()
  113. for e in db.query_all("SHOW ALL"):
  114. key, old_value = e["name"], e["setting"]
  115. if key in env:
  116. original[key] = old_value
  117. new_value = env[key]
  118. db.execute(f"SET {identifier(key)} = {literal(new_value)}")
  119. yield db
  120. for key, old_value in original.items():
  121. db.execute(f"SET {identifier(key)} = {literal(old_value)}")
  122. # Utility functions
  123. # -----------------
  124. def parse_from_file(path: Path) -> list[str]:
  125. """Parses a *.sql file to a list of queries."""
  126. return sqlparse.split(path.read_text())
  127. def parse_query(path: Path) -> str:
  128. """Parses a *.sql file to a single query."""
  129. queries = parse_from_file(path)
  130. assert len(queries) == 1, f"Exactly one query expected in {path}"
  131. return queries[0]
  132. def try_mzfmt(sql: str) -> str:
  133. sql = sql.rstrip().rstrip(";")
  134. result = subprocess.run(
  135. ["mzfmt"],
  136. shell=True,
  137. input=sql.encode(),
  138. capture_output=True,
  139. )
  140. if result.returncode == 0:
  141. return result.stdout.decode("utf-8").rstrip()
  142. else:
  143. return sql.rstrip().rstrip(";")
  144. def identifier(s: str, force_quote: bool = False) -> str:
  145. """
  146. A version of pg8000.native.identifier (1) that is _ACTUALLY_ compatible with
  147. the Postgres code (2).
  148. 1. https://github.com/tlocke/pg8000/blob/017959e97751c35a3d58bc8bd5722cee5c10b656/pg8000/converters.py#L739-L761
  149. 2. https://github.com/postgres/postgres/blob/b0f7dd915bca6243f3daf52a81b8d0682a38ee3b/src/backend/utils/adt/ruleutils.c#L11968-L12050
  150. """
  151. if not isinstance(s, str):
  152. raise InterfaceError("identifier must be a str")
  153. if len(s) == 0:
  154. raise InterfaceError("identifier must be > 0 characters in length")
  155. # Look for characters that require quotation.
  156. def is_alpha(c: str) -> bool:
  157. return ord(c) >= ord("a") and ord(c) <= ord("z") or c == "_"
  158. def is_alphanum(c: str) -> bool:
  159. return is_alpha(c) or ord(c) >= ord("0") and ord(c) <= ord("9")
  160. quote = not (is_alpha(s[0]))
  161. for c in s[1:]:
  162. if not (is_alphanum(c)):
  163. if c == "\u0000":
  164. raise InterfaceError(
  165. "identifier cannot contain the code zero character"
  166. )
  167. quote = True
  168. if quote:
  169. break
  170. # Even if no speciall characters can be found we still want to quote
  171. # keywords.
  172. if s.upper() in keywords():
  173. quote = True
  174. if quote or force_quote:
  175. s = s.replace('"', '""')
  176. return f'"{s}"'
  177. else:
  178. return s
  179. @functools.lru_cache(maxsize=1)
  180. def keywords() -> set[str]:
  181. """
  182. Return a list of keywords reserved by Materialize.
  183. """
  184. with resource_path("sql/keywords.txt").open() as f:
  185. return set(
  186. line.strip().upper()
  187. for line in f.readlines()
  188. if not line.startswith("#") and len(line.strip()) > 0
  189. )