# query_template.py
#
# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.
  9. from __future__ import annotations
  10. from collections.abc import Callable
  11. from materialize.output_consistency.execution.evaluation_strategy import (
  12. EvaluationStrategy,
  13. )
  14. from materialize.output_consistency.execution.query_output_mode import (
  15. QueryOutputMode,
  16. query_output_mode_to_sql,
  17. )
  18. from materialize.output_consistency.execution.sql_dialect_adjuster import (
  19. SqlDialectAdjuster,
  20. )
  21. from materialize.output_consistency.execution.value_storage_layout import (
  22. ROW_INDEX_COL_NAME,
  23. ValueStorageLayout,
  24. )
  25. from materialize.output_consistency.expression.expression import Expression
  26. from materialize.output_consistency.expression.expression_characteristics import (
  27. ExpressionCharacteristics,
  28. )
  29. from materialize.output_consistency.query.additional_source import (
  30. AdditionalSource,
  31. as_data_sources,
  32. )
  33. from materialize.output_consistency.query.data_source import (
  34. DataSource,
  35. )
  36. from materialize.output_consistency.query.query_format import QueryOutputFormat
  37. from materialize.output_consistency.selection.column_selection import (
  38. QueryColumnByIndexSelection,
  39. )
  40. from materialize.output_consistency.selection.row_selection import (
  41. DataRowSelection,
  42. )
  43. class QueryTemplate:
  44. """Query template as base for creating SQL for different evaluation strategies"""
  45. def __init__(
  46. self,
  47. expect_error: bool,
  48. select_expressions: list[Expression],
  49. where_expression: Expression | None,
  50. storage_layout: ValueStorageLayout,
  51. data_source: DataSource,
  52. contains_aggregations: bool,
  53. row_selection: DataRowSelection,
  54. offset: int | None = None,
  55. limit: int | None = None,
  56. additional_sources: list[AdditionalSource] = [],
  57. custom_order_expressions: list[Expression] | None = None,
  58. ) -> None:
  59. assert storage_layout != ValueStorageLayout.ANY
  60. self.expect_error = expect_error
  61. self.select_expressions: list[Expression] = select_expressions
  62. self.where_expression = where_expression
  63. self.storage_layout = storage_layout
  64. self.data_source = data_source
  65. self.additional_sources = additional_sources
  66. self.contains_aggregations = contains_aggregations
  67. self.row_selection = row_selection
  68. self.offset = offset
  69. self.limit = limit
  70. self.custom_order_expressions = custom_order_expressions
  71. self.disable_error_message_validation = not self.__can_compare_error_messages()
  72. def get_all_data_sources(self) -> list[DataSource]:
  73. all_data_sources = [self.data_source]
  74. all_data_sources.extend(as_data_sources(self.additional_sources))
  75. return all_data_sources
  76. def get_all_expressions(
  77. self,
  78. include_select_expressions: bool = True,
  79. include_join_constraints: bool = False,
  80. ) -> list[Expression]:
  81. all_expressions = []
  82. if include_select_expressions:
  83. all_expressions.extend(self.select_expressions)
  84. if self.has_where_condition():
  85. all_expressions.append(self.where_expression)
  86. if self.custom_order_expressions is not None:
  87. all_expressions.extend(self.custom_order_expressions)
  88. if include_join_constraints:
  89. for additional_source in self.additional_sources:
  90. all_expressions.append(additional_source.join_constraint)
  91. return all_expressions
  92. def to_sql(
  93. self,
  94. strategy: EvaluationStrategy,
  95. output_format: QueryOutputFormat,
  96. query_column_selection: QueryColumnByIndexSelection,
  97. query_output_mode: QueryOutputMode,
  98. override_db_object_base_name: str | None = None,
  99. ) -> str:
  100. space_separator = self._get_space_separator(output_format)
  101. column_sql = self._create_column_sql(
  102. query_column_selection, space_separator, strategy.sql_adjuster
  103. )
  104. from_clause = self._create_from_clause(
  105. strategy,
  106. override_db_object_base_name,
  107. space_separator,
  108. )
  109. join_clauses = self._create_join_clauses(
  110. strategy,
  111. override_db_object_base_name,
  112. strategy.sql_adjuster,
  113. space_separator,
  114. )
  115. where_clause = self._create_where_clause(strategy.sql_adjuster)
  116. order_by_clause = self._create_order_by_clause(strategy.sql_adjuster)
  117. limit_clause = self._create_limit_clause()
  118. offset_clause = self._create_offset_clause()
  119. explain_mode = query_output_mode_to_sql(query_output_mode)
  120. sql = f"""
  121. {explain_mode} SELECT{space_separator}{column_sql}
  122. {from_clause}
  123. {join_clauses}
  124. {where_clause}
  125. {order_by_clause}
  126. {limit_clause}
  127. {offset_clause}
  128. """.strip()
  129. sql = f"{sql};"
  130. return self._post_format_sql(sql, output_format)
  131. def uses_join(self) -> bool:
  132. return self.count_joins() > 0
  133. def count_joins(self) -> int:
  134. return len(self.additional_sources)
  135. def has_where_condition(self) -> bool:
  136. return self.where_expression is not None
  137. def has_row_selection(self) -> bool:
  138. return self.row_selection.has_selection()
  139. def has_offset(self) -> bool:
  140. return self.limit is not None
  141. def has_limit(self) -> bool:
  142. return self.limit is not None
  143. def _get_space_separator(self, output_format: QueryOutputFormat) -> str:
  144. return "\n " if output_format == QueryOutputFormat.MULTI_LINE else " "
  145. def _create_column_sql(
  146. self,
  147. query_column_selection: QueryColumnByIndexSelection,
  148. space_separator: str,
  149. sql_adjuster: SqlDialectAdjuster,
  150. ) -> str:
  151. expressions_as_sql = []
  152. for index, expression in enumerate(self.select_expressions):
  153. if query_column_selection.is_included(index):
  154. expressions_as_sql.append(
  155. expression.to_sql(sql_adjuster, self.uses_join(), True)
  156. )
  157. return f",{space_separator}".join(expressions_as_sql)
  158. def _create_from_clause(
  159. self,
  160. strategy: EvaluationStrategy,
  161. override_db_object_base_name: str | None,
  162. space_separator: str,
  163. ) -> str:
  164. db_object_name = strategy.get_db_object_name(
  165. self.storage_layout,
  166. data_source=self.data_source,
  167. override_base_name=override_db_object_base_name,
  168. )
  169. alias = f" {self.data_source.alias()}" if self.uses_join() else ""
  170. return f"FROM{space_separator}{db_object_name}{alias}"
  171. def _create_join_clauses(
  172. self,
  173. strategy: EvaluationStrategy,
  174. override_db_object_base_name: str | None,
  175. sql_adjuster: SqlDialectAdjuster,
  176. space_separator: str,
  177. ) -> str:
  178. if len(self.additional_sources) == 0:
  179. # no JOIN necessary
  180. return ""
  181. join_clauses = ""
  182. for additional_source in self.additional_sources:
  183. join_clauses = (
  184. f"{join_clauses}"
  185. f"\n{self._create_join_clause(strategy, additional_source, override_db_object_base_name, sql_adjuster, space_separator)}"
  186. )
  187. return join_clauses
  188. def _create_join_clause(
  189. self,
  190. strategy: EvaluationStrategy,
  191. additional_source_to_join: AdditionalSource,
  192. override_db_object_base_name: str | None,
  193. sql_adjuster: SqlDialectAdjuster,
  194. space_separator: str,
  195. ) -> str:
  196. db_object_name_to_join = strategy.get_db_object_name(
  197. self.storage_layout,
  198. data_source=additional_source_to_join.data_source,
  199. override_base_name=override_db_object_base_name,
  200. )
  201. join_operator_sql = additional_source_to_join.join_operator.to_sql()
  202. return (
  203. f"{join_operator_sql} {db_object_name_to_join} {additional_source_to_join.data_source.alias()}"
  204. f"{space_separator}ON {additional_source_to_join.join_constraint.to_sql(sql_adjuster, True, True)}"
  205. )
  206. def _create_where_clause(self, sql_adjuster: SqlDialectAdjuster) -> str:
  207. where_conditions = []
  208. row_filter_clauses = self._create_row_filter_clauses()
  209. where_conditions.extend(row_filter_clauses)
  210. if self.where_expression:
  211. where_conditions.append(
  212. self.where_expression.to_sql(sql_adjuster, self.uses_join(), True)
  213. )
  214. if len(where_conditions) == 0:
  215. return ""
  216. # It is important that the condition parts are in parentheses so that they are connected with AND.
  217. # Otherwise, a generated condition containing OR at the top level may lift the row filter clause.
  218. all_conditions_sql = " AND ".join(
  219. [f"({condition})" for condition in where_conditions]
  220. )
  221. return f"WHERE {all_conditions_sql}"
  222. def _create_row_filter_clauses(self) -> list[str]:
  223. """Create a SQL clause to only include rows of certain indices"""
  224. row_filter_clauses = []
  225. for data_source in self.get_all_data_sources():
  226. if self.row_selection.includes_all_of_source(data_source):
  227. continue
  228. if len(self.row_selection.get_row_indices(data_source)) == 0:
  229. row_index_string = "-1"
  230. else:
  231. row_index_string = ", ".join(
  232. str(index)
  233. for index in sorted(self.row_selection.get_row_indices(data_source))
  234. )
  235. row_filter_clauses.append(
  236. f"{self._row_index_col_name(data_source)} IN ({row_index_string})"
  237. )
  238. return row_filter_clauses
  239. def _create_order_by_clause(self, sql_adjuster: SqlDialectAdjuster) -> str:
  240. if self.custom_order_expressions is not None:
  241. order_by_specs_str = ", ".join(
  242. [
  243. f"{expr.to_sql(sql_adjuster, self.uses_join(), True)} ASC"
  244. for expr in self.custom_order_expressions
  245. ]
  246. )
  247. return f"ORDER BY {order_by_specs_str}"
  248. if (
  249. self.storage_layout == ValueStorageLayout.VERTICAL
  250. and not self.contains_aggregations
  251. ):
  252. order_by_columns = []
  253. for data_source in self.get_all_data_sources():
  254. order_by_columns.append(f"{self._row_index_col_name(data_source)} ASC")
  255. order_by_columns_str = ", ".join(order_by_columns)
  256. return f"ORDER BY {order_by_columns_str}"
  257. return ""
  258. def _create_offset_clause(self) -> str:
  259. if self.offset is not None:
  260. return f"OFFSET {self.offset}"
  261. return ""
  262. def _create_limit_clause(self) -> str:
  263. if self.limit is not None:
  264. return f"LIMIT {self.limit}"
  265. return ""
  266. def _row_index_col_name(self, data_source: DataSource) -> str:
  267. if self.uses_join():
  268. return f"{data_source.alias()}.{ROW_INDEX_COL_NAME}"
  269. return ROW_INDEX_COL_NAME
  270. def _post_format_sql(self, sql: str, output_format: QueryOutputFormat) -> str:
  271. # apply this replacement twice
  272. sql = sql.replace("\n\n", "\n").replace("\n\n", "\n")
  273. sql = sql.replace("\n;", ";")
  274. if output_format == QueryOutputFormat.SINGLE_LINE:
  275. sql = sql.replace("\n", " ")
  276. return sql
  277. def collect_involved_vertical_table_indices(self) -> set[int] | None:
  278. if self.storage_layout == ValueStorageLayout.HORIZONTAL:
  279. return None
  280. assert self.storage_layout == ValueStorageLayout.VERTICAL
  281. table_indices = set()
  282. all_expressions = []
  283. all_expressions.extend(self.select_expressions)
  284. all_expressions.append(self.where_expression)
  285. all_expressions.extend(self.custom_order_expressions or [])
  286. for expression in all_expressions:
  287. table_indices.update(expression.collect_vertical_table_indices())
  288. return table_indices
  289. def column_count(self) -> int:
  290. return len(self.select_expressions)
  291. def __can_compare_error_messages(self) -> bool:
  292. if self.storage_layout == ValueStorageLayout.HORIZONTAL:
  293. return True
  294. for expression in self.select_expressions:
  295. if expression.contains_leaf_not_directly_consumed_by_aggregation():
  296. # The query operates on multiple rows and contains at least one non-aggregate function directly
  297. # operating on the value. Since the row processing order is not fixed, different evaluation
  298. # strategies may yield different error messages (depending on the first invalid value they
  299. # encounter). Therefore, error messages shall not be compared in case of a query failure.
  300. return False
  301. return True
  302. def matches_any_select_expression(
  303. self, predicate: Callable[[Expression], bool], check_recursively: bool
  304. ) -> bool:
  305. for expression in self.select_expressions:
  306. if expression.matches(predicate, check_recursively):
  307. return True
  308. return False
  309. def matches_any_expression(
  310. self, predicate: Callable[[Expression], bool], check_recursively: bool
  311. ) -> bool:
  312. return self.matches_any_select_expression(predicate, check_recursively) or (
  313. self.where_expression is not None
  314. and self.where_expression.matches(predicate, check_recursively)
  315. )
  316. def matches_specific_select_or_filter_expression(
  317. self,
  318. select_column_index: int,
  319. predicate: Callable[[Expression], bool],
  320. check_recursively: bool,
  321. ) -> bool:
  322. assert 0 <= select_column_index <= self.column_count()
  323. return self.select_expressions[select_column_index].matches(
  324. predicate, check_recursively
  325. ) or (
  326. self.where_expression is not None
  327. and self.where_expression.matches(predicate, check_recursively)
  328. )
  329. def get_involved_characteristics(
  330. self,
  331. query_column_selection: QueryColumnByIndexSelection,
  332. ) -> set[ExpressionCharacteristics]:
  333. all_involved_characteristics: set[ExpressionCharacteristics] = set()
  334. for index, expression in enumerate(self.select_expressions):
  335. if not query_column_selection.is_included(index):
  336. continue
  337. characteristics = expression.recursively_collect_involved_characteristics(
  338. self.row_selection
  339. )
  340. all_involved_characteristics.update(characteristics)
  341. for further_expression in self.get_all_expressions(
  342. include_select_expressions=False, include_join_constraints=True
  343. ):
  344. characteristics = (
  345. further_expression.recursively_collect_involved_characteristics(
  346. self.row_selection
  347. )
  348. )
  349. all_involved_characteristics.update(characteristics)
  350. return all_involved_characteristics
  351. def clone(
  352. self, expect_error: bool, select_expressions: list[Expression]
  353. ) -> QueryTemplate:
  354. return QueryTemplate(
  355. expect_error=expect_error,
  356. select_expressions=select_expressions,
  357. where_expression=self.where_expression,
  358. storage_layout=self.storage_layout,
  359. data_source=self.data_source,
  360. contains_aggregations=self.contains_aggregations,
  361. row_selection=self.row_selection,
  362. offset=self.offset,
  363. limit=self.limit,
  364. additional_sources=self.additional_sources,
  365. custom_order_expressions=self.custom_order_expressions,
  366. )