sql_executor.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from collections import deque
  10. from collections.abc import Sequence
  11. from typing import Any
  12. from psycopg import Connection, DataError
  13. from psycopg.errors import (
  14. DatabaseError,
  15. InternalError_,
  16. OperationalError,
  17. ProgrammingError,
  18. SyntaxError,
  19. )
  20. from materialize.output_consistency.output.output_printer import OutputPrinter
  21. class SqlExecutionError(Exception):
  22. def __init__(self, message: str):
  23. super().__init__(message)
  24. # storing it here as well makes it easier to access the message
  25. self.message = message
  26. class SqlExecutor:
  27. """Base class of `PgWireDatabaseSqlExecutor` and `DryRunSqlExecutor`"""
  28. def __init__(
  29. self,
  30. name: str,
  31. ):
  32. self.name = name
  33. def __str__(self) -> str:
  34. return self.__class__.__name__
  35. def ddl(self, sql: str) -> None:
  36. raise NotImplementedError
  37. def begin_tx(self, isolation_level: str) -> None:
  38. raise NotImplementedError
  39. def commit(self) -> None:
  40. raise NotImplementedError
  41. def rollback(self) -> None:
  42. raise NotImplementedError
  43. def query(self, sql: str) -> Sequence[Sequence[Any]]:
  44. raise NotImplementedError
  45. def query_version(self) -> str:
  46. raise NotImplementedError
  47. def before_query_execution(self) -> None:
  48. pass
  49. def after_query_execution(self) -> None:
  50. pass
  51. def before_new_tx(self):
  52. pass
  53. def after_new_tx(self):
  54. pass
  55. class PgWireDatabaseSqlExecutor(SqlExecutor):
  56. def __init__(
  57. self,
  58. connection: Connection,
  59. use_autocommit: bool,
  60. output_printer: OutputPrinter,
  61. name: str,
  62. ):
  63. super().__init__(name)
  64. connection.autocommit = use_autocommit
  65. self.cursor = connection.cursor()
  66. self.output_printer = output_printer
  67. self.last_statements = deque[str](maxlen=5)
  68. def ddl(self, sql: str) -> None:
  69. self._execute_with_cursor(sql)
  70. def begin_tx(self, isolation_level: str) -> None:
  71. self._execute_with_cursor(f"BEGIN ISOLATION LEVEL {isolation_level};")
  72. def commit(self) -> None:
  73. self._execute_with_cursor("COMMIT;")
  74. def rollback(self) -> None:
  75. self._execute_with_cursor("ROLLBACK;")
  76. def query(self, sql: str) -> Sequence[Sequence[Any]]:
  77. try:
  78. self._execute_with_cursor(sql)
  79. return self.cursor.fetchall()
  80. except (ProgrammingError, DatabaseError) as err:
  81. raise SqlExecutionError(self._extract_message_from_error(err))
  82. def query_version(self) -> str:
  83. return self.query("SELECT version();")[0][0]
  84. def _execute_with_cursor(self, sql: str) -> None:
  85. try:
  86. self.last_statements.append(sql)
  87. self.cursor.execute(sql.encode())
  88. except OperationalError as e:
  89. if "server closed the connection unexpectedly" not in str(e):
  90. raise SqlExecutionError(self._extract_message_from_error(e))
  91. print("A network error occurred! Aborting!")
  92. # The current or one of previous queries might have broken the database.
  93. last_statements_desc = self.last_statements.copy()
  94. last_statements_desc.reverse()
  95. statements_str = "\n".join(
  96. f" {statement}" for statement in last_statements_desc
  97. )
  98. print(
  99. f"Last {len(last_statements_desc)} queries in descending order:\n{statements_str}"
  100. )
  101. exit(1)
  102. except (ProgrammingError, DatabaseError, SyntaxError, InternalError_) as err:
  103. raise SqlExecutionError(self._extract_message_from_error(err))
  104. except DataError as err: # type: ignore
  105. raise SqlExecutionError(err.args[0])
  106. except ValueError as err:
  107. self.output_printer.print_error(f"Query with value error is: {sql}")
  108. raise err
  109. except Exception:
  110. self.output_printer.print_error(f"Query with unexpected error is: {sql}")
  111. raise
  112. def _extract_message_from_error(
  113. self,
  114. error: (
  115. OperationalError
  116. | ProgrammingError
  117. | DataError
  118. | DatabaseError
  119. | SyntaxError
  120. | InternalError_
  121. ),
  122. ) -> str:
  123. if error.diag.message_primary is not None:
  124. result = str(error.diag.message_primary)
  125. if error.diag.message_detail is not None:
  126. result += f" ({error.diag.message_detail})"
  127. return result
  128. if len(error.args) > 0:
  129. return str(error.args[0])
  130. return str(error)
  131. class MzDatabaseSqlExecutor(PgWireDatabaseSqlExecutor):
  132. def __init__(
  133. self,
  134. default_connection: Connection,
  135. mz_system_connection: Connection,
  136. use_autocommit: bool,
  137. output_printer: OutputPrinter,
  138. name: str,
  139. ):
  140. super().__init__(default_connection, use_autocommit, output_printer, name)
  141. self.mz_system_connection = mz_system_connection
  142. def query_version(self) -> str:
  143. return self.query("SELECT mz_version();")[0][0]
  144. class DryRunSqlExecutor(SqlExecutor):
  145. def __init__(self, output_printer: OutputPrinter, name: str):
  146. super().__init__(name)
  147. self.output_printer = output_printer
  148. def consume_sql(self, sql: str) -> None:
  149. self.output_printer.print_sql(sql)
  150. def ddl(self, sql: str) -> None:
  151. self.consume_sql(sql)
  152. def begin_tx(self, isolation_level: str) -> None:
  153. self.consume_sql(f"BEGIN ISOLATION LEVEL {isolation_level};")
  154. def commit(self) -> None:
  155. self.consume_sql("COMMIT;")
  156. def rollback(self) -> None:
  157. self.consume_sql("ROLLBACK;")
  158. def query(self, sql: str) -> Sequence[Sequence[Any]]:
  159. self.consume_sql(sql)
  160. return []
  161. def query_version(self) -> str:
  162. return "(dry-run)"