123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442 |
- # Copyright Materialize, Inc. and contributors. All rights reserved.
- #
- # Use of this software is governed by the Business Source License
- # included in the LICENSE file at the root of this repository.
- #
- # As of the Change Date specified in that file, in accordance with
- # the Business Source License, use of this software will be governed
- # by the Apache License, Version 2.0.
- from __future__ import annotations
- from collections.abc import Callable
- from materialize.output_consistency.execution.evaluation_strategy import (
- EvaluationStrategy,
- )
- from materialize.output_consistency.execution.query_output_mode import (
- QueryOutputMode,
- query_output_mode_to_sql,
- )
- from materialize.output_consistency.execution.sql_dialect_adjuster import (
- SqlDialectAdjuster,
- )
- from materialize.output_consistency.execution.value_storage_layout import (
- ROW_INDEX_COL_NAME,
- ValueStorageLayout,
- )
- from materialize.output_consistency.expression.expression import Expression
- from materialize.output_consistency.expression.expression_characteristics import (
- ExpressionCharacteristics,
- )
- from materialize.output_consistency.query.additional_source import (
- AdditionalSource,
- as_data_sources,
- )
- from materialize.output_consistency.query.data_source import (
- DataSource,
- )
- from materialize.output_consistency.query.query_format import QueryOutputFormat
- from materialize.output_consistency.selection.column_selection import (
- QueryColumnByIndexSelection,
- )
- from materialize.output_consistency.selection.row_selection import (
- DataRowSelection,
- )
class QueryTemplate:
    """Query template as base for creating SQL for different evaluation strategies.

    A template bundles the select expressions, an optional filter expression,
    the data sources (one primary source plus optional joined sources), the
    row selection, and optional ORDER BY / LIMIT / OFFSET settings. `to_sql`
    renders the template for a concrete evaluation strategy.
    """

    def __init__(
        self,
        expect_error: bool,
        select_expressions: list[Expression],
        where_expression: Expression | None,
        storage_layout: ValueStorageLayout,
        data_source: DataSource,
        contains_aggregations: bool,
        row_selection: DataRowSelection,
        offset: int | None = None,
        limit: int | None = None,
        additional_sources: list[AdditionalSource] | None = None,
        custom_order_expressions: list[Expression] | None = None,
    ) -> None:
        """Create a query template.

        Args:
            expect_error: stored on the template as `expect_error`.
            select_expressions: expressions forming the SELECT columns.
            where_expression: optional filter expression.
            storage_layout: concrete storage layout; must not be
                `ValueStorageLayout.ANY`.
            data_source: primary data source used in the FROM clause.
            contains_aggregations: whether the query contains aggregations;
                when set, no implicit ORDER BY on the row index is generated.
            row_selection: selection of the rows to include per data source.
            offset: optional value of the OFFSET clause.
            limit: optional value of the LIMIT clause.
            additional_sources: sources to join; `None` means no joins.
            custom_order_expressions: optional expressions to order by, taking
                precedence over the implicit row-index ordering.
        """
        assert storage_layout != ValueStorageLayout.ANY
        self.expect_error = expect_error
        self.select_expressions: list[Expression] = select_expressions
        self.where_expression = where_expression
        self.storage_layout = storage_layout
        self.data_source = data_source
        # A fresh list per instance: a mutable default in the signature would
        # be shared by every template created without this argument.
        self.additional_sources: list[AdditionalSource] = (
            additional_sources if additional_sources is not None else []
        )
        self.contains_aggregations = contains_aggregations
        self.row_selection = row_selection
        self.offset = offset
        self.limit = limit
        self.custom_order_expressions = custom_order_expressions
        self.disable_error_message_validation = not self.__can_compare_error_messages()

    def get_all_data_sources(self) -> list[DataSource]:
        """Return the primary data source followed by all joined data sources."""
        return [self.data_source, *as_data_sources(self.additional_sources)]

    def get_all_expressions(
        self,
        include_select_expressions: bool = True,
        include_join_constraints: bool = False,
    ) -> list[Expression]:
        """Collect the expressions used by this template.

        The where expression (if present) and the custom order expressions (if
        present) are always included; select expressions and join constraints
        are included depending on the flags.
        """
        all_expressions: list[Expression] = []

        if include_select_expressions:
            all_expressions.extend(self.select_expressions)

        if self.where_expression is not None:
            all_expressions.append(self.where_expression)

        if self.custom_order_expressions is not None:
            all_expressions.extend(self.custom_order_expressions)

        if include_join_constraints:
            all_expressions.extend(
                additional_source.join_constraint
                for additional_source in self.additional_sources
            )

        return all_expressions

    def to_sql(
        self,
        strategy: EvaluationStrategy,
        output_format: QueryOutputFormat,
        query_column_selection: QueryColumnByIndexSelection,
        query_output_mode: QueryOutputMode,
        override_db_object_base_name: str | None = None,
    ) -> str:
        """Render this template as a SQL statement terminated by a semicolon.

        Args:
            strategy: evaluation strategy providing the db object names and
                the SQL dialect adjuster.
            output_format: single-line or multi-line formatting.
            query_column_selection: decides which select expressions (by
                index) are part of the query.
            query_output_mode: converted to a SQL prefix placed before SELECT.
            override_db_object_base_name: passed through to the strategy when
                resolving db object names.
        """
        space_separator = self._get_space_separator(output_format)
        sql_adjuster = strategy.sql_adjuster

        column_sql = self._create_column_sql(
            query_column_selection, space_separator, sql_adjuster
        )
        from_clause = self._create_from_clause(
            strategy,
            override_db_object_base_name,
            space_separator,
        )
        join_clauses = self._create_join_clauses(
            strategy,
            override_db_object_base_name,
            sql_adjuster,
            space_separator,
        )
        where_clause = self._create_where_clause(sql_adjuster)
        order_by_clause = self._create_order_by_clause(sql_adjuster)
        limit_clause = self._create_limit_clause()
        offset_clause = self._create_offset_clause()
        explain_mode = query_output_mode_to_sql(query_output_mode)

        # One clause per line; empty clauses produce blank lines, which are
        # removed again by _post_format_sql.
        sql = "\n".join(
            [
                f"{explain_mode} SELECT{space_separator}{column_sql}",
                from_clause,
                join_clauses,
                where_clause,
                order_by_clause,
                limit_clause,
                offset_clause,
            ]
        ).strip()
        sql = f"{sql};"

        return self._post_format_sql(sql, output_format)

    def uses_join(self) -> bool:
        """Return whether at least one additional source is joined."""
        return self.count_joins() > 0

    def count_joins(self) -> int:
        """Return the number of joined (additional) sources."""
        return len(self.additional_sources)

    def has_where_condition(self) -> bool:
        """Return whether a where expression is set."""
        return self.where_expression is not None

    def has_row_selection(self) -> bool:
        """Return whether the row selection restricts the rows."""
        return self.row_selection.has_selection()

    def has_offset(self) -> bool:
        """Return whether an OFFSET is set."""
        return self.offset is not None

    def has_limit(self) -> bool:
        """Return whether a LIMIT is set."""
        return self.limit is not None

    def _get_space_separator(self, output_format: QueryOutputFormat) -> str:
        """Return the separator placed between a keyword and its operands."""
        return "\n " if output_format == QueryOutputFormat.MULTI_LINE else " "

    def _create_column_sql(
        self,
        query_column_selection: QueryColumnByIndexSelection,
        space_separator: str,
        sql_adjuster: SqlDialectAdjuster,
    ) -> str:
        """Render the selected select expressions as a comma-separated list."""
        uses_join = self.uses_join()
        expressions_as_sql = [
            expression.to_sql(sql_adjuster, uses_join, True)
            for index, expression in enumerate(self.select_expressions)
            if query_column_selection.is_included(index)
        ]

        return f",{space_separator}".join(expressions_as_sql)

    def _create_from_clause(
        self,
        strategy: EvaluationStrategy,
        override_db_object_base_name: str | None,
        space_separator: str,
    ) -> str:
        """Render the FROM clause; the source is aliased only when joins are used."""
        db_object_name = strategy.get_db_object_name(
            self.storage_layout,
            data_source=self.data_source,
            override_base_name=override_db_object_base_name,
        )

        alias = f" {self.data_source.alias()}" if self.uses_join() else ""
        return f"FROM{space_separator}{db_object_name}{alias}"

    def _create_join_clauses(
        self,
        strategy: EvaluationStrategy,
        override_db_object_base_name: str | None,
        sql_adjuster: SqlDialectAdjuster,
        space_separator: str,
    ) -> str:
        """Render all JOIN clauses, each preceded by a newline.

        Returns an empty string when there are no additional sources.
        """
        # Joining over an empty sequence yields "", so no JOIN is emitted when
        # there are no additional sources.
        return "".join(
            "\n"
            + self._create_join_clause(
                strategy,
                additional_source,
                override_db_object_base_name,
                sql_adjuster,
                space_separator,
            )
            for additional_source in self.additional_sources
        )

    def _create_join_clause(
        self,
        strategy: EvaluationStrategy,
        additional_source_to_join: AdditionalSource,
        override_db_object_base_name: str | None,
        sql_adjuster: SqlDialectAdjuster,
        space_separator: str,
    ) -> str:
        """Render a single JOIN clause including its ON constraint."""
        source_to_join = additional_source_to_join.data_source
        db_object_name_to_join = strategy.get_db_object_name(
            self.storage_layout,
            data_source=source_to_join,
            override_base_name=override_db_object_base_name,
        )

        join_operator_sql = additional_source_to_join.join_operator.to_sql()
        join_constraint_sql = additional_source_to_join.join_constraint.to_sql(
            sql_adjuster, True, True
        )

        return (
            f"{join_operator_sql} {db_object_name_to_join} {source_to_join.alias()}"
            f"{space_separator}ON {join_constraint_sql}"
        )

    def _create_where_clause(self, sql_adjuster: SqlDialectAdjuster) -> str:
        """Render the WHERE clause from the row filters and the where expression.

        Returns an empty string when there is nothing to filter on.
        """
        where_conditions = list(self._create_row_filter_clauses())

        if self.where_expression is not None:
            where_conditions.append(
                self.where_expression.to_sql(sql_adjuster, self.uses_join(), True)
            )

        if len(where_conditions) == 0:
            return ""

        # It is important that the condition parts are in parentheses so that they are connected with AND.
        # Otherwise, a generated condition containing OR at the top level may lift the row filter clause.
        all_conditions_sql = " AND ".join(
            f"({condition})" for condition in where_conditions
        )
        return f"WHERE {all_conditions_sql}"

    def _create_row_filter_clauses(self) -> list[str]:
        """Create a SQL clause to only include rows of certain indices"""
        row_filter_clauses = []

        for data_source in self.get_all_data_sources():
            if self.row_selection.includes_all_of_source(data_source):
                continue

            row_indices = self.row_selection.get_row_indices(data_source)

            if len(row_indices) == 0:
                # An empty selection still needs a non-empty IN list; -1 is
                # presumably never a valid row index (TODO confirm).
                row_index_string = "-1"
            else:
                row_index_string = ", ".join(
                    str(index) for index in sorted(row_indices)
                )

            row_filter_clauses.append(
                f"{self._row_index_col_name(data_source)} IN ({row_index_string})"
            )

        return row_filter_clauses

    def _create_order_by_clause(self, sql_adjuster: SqlDialectAdjuster) -> str:
        """Render the ORDER BY clause.

        Custom order expressions take precedence. Otherwise, queries on the
        vertical layout without aggregations are ordered by the row index of
        every data source. In all other cases no ORDER BY is emitted.
        """
        if self.custom_order_expressions is not None:
            uses_join = self.uses_join()
            order_by_specs_str = ", ".join(
                f"{expr.to_sql(sql_adjuster, uses_join, True)} ASC"
                for expr in self.custom_order_expressions
            )
            return f"ORDER BY {order_by_specs_str}"

        if (
            self.storage_layout == ValueStorageLayout.VERTICAL
            and not self.contains_aggregations
        ):
            order_by_columns_str = ", ".join(
                f"{self._row_index_col_name(data_source)} ASC"
                for data_source in self.get_all_data_sources()
            )
            return f"ORDER BY {order_by_columns_str}"

        return ""

    def _create_offset_clause(self) -> str:
        """Render the OFFSET clause, or an empty string if no offset is set."""
        if self.offset is not None:
            return f"OFFSET {self.offset}"

        return ""

    def _create_limit_clause(self) -> str:
        """Render the LIMIT clause, or an empty string if no limit is set."""
        if self.limit is not None:
            return f"LIMIT {self.limit}"

        return ""

    def _row_index_col_name(self, data_source: DataSource) -> str:
        """Return the row index column name, qualified by the alias when joining."""
        if self.uses_join():
            return f"{data_source.alias()}.{ROW_INDEX_COL_NAME}"

        return ROW_INDEX_COL_NAME

    def _post_format_sql(self, sql: str, output_format: QueryOutputFormat) -> str:
        """Remove blank lines and, for single-line output, all line breaks."""
        # Several adjacent empty clauses produce long runs of newlines; a
        # fixed number of replacement passes cannot collapse all of them.
        while "\n\n" in sql:
            sql = sql.replace("\n\n", "\n")

        sql = sql.replace("\n;", ";")

        if output_format == QueryOutputFormat.SINGLE_LINE:
            sql = sql.replace("\n", " ")

        return sql

    def collect_involved_vertical_table_indices(self) -> set[int] | None:
        """Collect the vertical table indices referenced by the expressions.

        Considers the select expressions, the where expression (if present),
        and the custom order expressions (if present).

        Returns:
            `None` for the horizontal storage layout, otherwise the set of
            indices reported by the expressions.
        """
        if self.storage_layout == ValueStorageLayout.HORIZONTAL:
            return None

        assert self.storage_layout == ValueStorageLayout.VERTICAL

        table_indices: set[int] = set()

        # get_all_expressions skips a missing where expression, which must not
        # be dereferenced.
        for expression in self.get_all_expressions():
            table_indices.update(expression.collect_vertical_table_indices())

        return table_indices

    def column_count(self) -> int:
        """Return the number of select expressions."""
        return len(self.select_expressions)

    def __can_compare_error_messages(self) -> bool:
        """Return whether error messages of different strategies are comparable."""
        if self.storage_layout == ValueStorageLayout.HORIZONTAL:
            return True

        # The query operates on multiple rows. If it contains at least one non-aggregate function directly
        # operating on the value, different evaluation strategies may yield different error messages
        # (depending on the first invalid value they encounter) because the row processing order is not
        # fixed. Therefore, error messages shall not be compared in case of a query failure.
        return not any(
            expression.contains_leaf_not_directly_consumed_by_aggregation()
            for expression in self.select_expressions
        )

    def matches_any_select_expression(
        self, predicate: Callable[[Expression], bool], check_recursively: bool
    ) -> bool:
        """Return whether any select expression matches the predicate."""
        return any(
            expression.matches(predicate, check_recursively)
            for expression in self.select_expressions
        )

    def matches_any_expression(
        self, predicate: Callable[[Expression], bool], check_recursively: bool
    ) -> bool:
        """Return whether any select expression or the where expression matches.

        Custom order expressions and join constraints are not considered.
        """
        return self.matches_any_select_expression(predicate, check_recursively) or (
            self.where_expression is not None
            and self.where_expression.matches(predicate, check_recursively)
        )

    def matches_specific_select_or_filter_expression(
        self,
        select_column_index: int,
        predicate: Callable[[Expression], bool],
        check_recursively: bool,
    ) -> bool:
        """Return whether the given select column or the where expression matches.

        Args:
            select_column_index: index into the select expressions; must be in
                the range `0 <= index < column_count()`.
            predicate: predicate applied to the expressions.
            check_recursively: passed through to `Expression.matches`.
        """
        # Valid indices end at column_count() - 1.
        assert 0 <= select_column_index < self.column_count()

        return self.select_expressions[select_column_index].matches(
            predicate, check_recursively
        ) or (
            self.where_expression is not None
            and self.where_expression.matches(predicate, check_recursively)
        )

    def get_involved_characteristics(
        self,
        query_column_selection: QueryColumnByIndexSelection,
    ) -> set[ExpressionCharacteristics]:
        """Collect the characteristics of all expressions taking part in the query.

        Select expressions are only considered when their index is included in
        `query_column_selection`; the where expression, the custom order
        expressions, and the join constraints are always considered.
        """
        all_involved_characteristics: set[ExpressionCharacteristics] = set()

        for index, expression in enumerate(self.select_expressions):
            if not query_column_selection.is_included(index):
                continue

            all_involved_characteristics.update(
                expression.recursively_collect_involved_characteristics(
                    self.row_selection
                )
            )

        for further_expression in self.get_all_expressions(
            include_select_expressions=False, include_join_constraints=True
        ):
            all_involved_characteristics.update(
                further_expression.recursively_collect_involved_characteristics(
                    self.row_selection
                )
            )

        return all_involved_characteristics

    def clone(
        self, expect_error: bool, select_expressions: list[Expression]
    ) -> QueryTemplate:
        """Create a template with new select expressions and error expectation.

        All other settings are taken over from this template; the referenced
        objects (including the list of additional sources) are shared, not
        copied.
        """
        return QueryTemplate(
            expect_error=expect_error,
            select_expressions=select_expressions,
            where_expression=self.where_expression,
            storage_layout=self.storage_layout,
            data_source=self.data_source,
            contains_aggregations=self.contains_aggregations,
            row_selection=self.row_selection,
            offset=self.offset,
            limit=self.limit,
            additional_sources=self.additional_sources,
            custom_order_expressions=self.custom_order_expressions,
        )
|