sql.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. import logging
  10. import re
  11. from enum import Enum
  12. from pathlib import Path
  13. from typing import Any, cast
  14. import numpy as np
  15. import psycopg
  16. import sqlparse
  17. from . import Scenario, util
  18. class Dialect(Enum):
  19. PG = 0
  20. MZ = 1
  21. class Query:
  22. """An API for manipulating workload queries."""
  23. def __init__(self, query: str) -> None:
  24. self.query = query
  25. def __str__(self) -> str:
  26. return self.query
  27. def name(self) -> str:
  28. """Extracts and returns the name of this query from a '-- name: {name}' comment.
  29. Returns 'anonymous' if the name is not set."""
  30. p = r"-- name\: (?P<name>.+)"
  31. m = re.search(p, self.query, re.MULTILINE)
  32. return m.group("name") if m else "anonoymous"
  33. def explain(self, timing: bool, dialect: Dialect = Dialect.MZ) -> str:
  34. """Prepends 'EXPLAIN ...' to the query respecting the given dialect."""
  35. if dialect == Dialect.PG:
  36. if timing:
  37. return "\n".join(["EXPLAIN (ANALYZE, TIMING TRUE)", self.query])
  38. else:
  39. return "\n".join(["EXPLAIN", self.query])
  40. else:
  41. if timing:
  42. return "\n".join(["EXPLAIN WITH(timing)", self.query])
  43. else:
  44. return "\n".join(["EXPLAIN", self.query])
  45. class ExplainOutput:
  46. """An API for manipulating 'EXPLAIN ... PLAN FOR' results."""
  47. def __init__(self, output: str) -> None:
  48. self.output = output
  49. def __str__(self) -> str:
  50. return self.output
  51. def optimization_time(self) -> np.timedelta64 | None:
  52. """Optionally, returns the optimization_time time for an 'EXPLAIN' output."""
  53. p = r"(Optimization time|Planning Time)\: (?P<time>[0-9]+(\.[0-9]+)?\s?\S+)"
  54. m = re.search(p, self.output, re.MULTILINE)
  55. return util.duration_to_timedelta(m["time"]) if m else None
  56. class Database:
  57. """An API to the database under test."""
  58. def __init__(
  59. self,
  60. port: int,
  61. host: str,
  62. user: str,
  63. password: str | None,
  64. database: str | None,
  65. require_ssl: bool,
  66. ) -> None:
  67. logging.debug(f"Initialize Database with host={host} port={port}, user={user}")
  68. self.conn = psycopg.connect(
  69. host=host,
  70. port=port,
  71. user=user,
  72. password=password,
  73. dbname=database,
  74. sslmode="require" if require_ssl else "disable",
  75. )
  76. self.conn.autocommit = True
  77. self.dialect = Dialect.MZ if "Materialize" in self.version() else Dialect.PG
  78. def close(self) -> None:
  79. self.conn.close()
  80. def version(self) -> str:
  81. result = self.query_one("SELECT version()")
  82. return cast(str, result[0])
  83. def mz_version(self) -> str | None:
  84. if self.dialect == Dialect.MZ:
  85. result = self.query_one("SELECT mz_version()")
  86. return cast(str, result[0])
  87. else:
  88. return None
  89. def drop_database(self, scenario: Scenario) -> None:
  90. logging.debug(f'Drop database "{scenario}"')
  91. self.execute(f"DROP DATABASE IF EXISTS {scenario}")
  92. def create_database(self, scenario: Scenario) -> None:
  93. logging.debug(f'Create database "{scenario}"')
  94. self.execute(f"CREATE DATABASE {scenario}")
  95. def explain(self, query: Query, timing: bool) -> "ExplainOutput":
  96. result = self.query_all(query.explain(timing, self.dialect))
  97. return ExplainOutput("\n".join([col for line in result for col in line]))
  98. def execute(self, statement: str) -> None:
  99. with self.conn.cursor() as cursor:
  100. cursor.execute(statement.encode())
  101. def execute_all(self, statements: list[str]) -> None:
  102. with self.conn.cursor() as cursor:
  103. for statement in statements:
  104. cursor.execute(statement.encode())
  105. def query_one(self, query: str) -> dict[Any, Any]:
  106. with self.conn.cursor() as cursor:
  107. cursor.execute(query.encode())
  108. return cast(dict[Any, Any], cursor.fetchone())
  109. def query_all(self, query: str) -> dict[Any, Any]:
  110. with self.conn.cursor() as cursor:
  111. cursor.execute(query.encode())
  112. return cast(dict[Any, Any], cursor.fetchall())
  113. # Utility functions
  114. # -----------------
  115. def parse_from_file(path: Path) -> list[str]:
  116. """Parses a *.sql file to a list of queries."""
  117. return sqlparse.split(path.read_text())