123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- # Copyright Materialize, Inc. and contributors. All rights reserved.
- #
- # Use of this software is governed by the Business Source License
- # included in the LICENSE file at the root of this repository.
- #
- # As of the Change Date specified in that file, in accordance with
- # the Business Source License, use of this software will be governed
- # by the Apache License, Version 2.0.
- from __future__ import annotations
- from collections.abc import Callable
- from materialize.output_consistency.data_type.data_type import DataType
- from materialize.output_consistency.data_type.data_type_category import DataTypeCategory
- from materialize.output_consistency.data_value.source_column_identifier import (
- SourceColumnIdentifier,
- )
- from materialize.output_consistency.execution.sql_dialect_adjuster import (
- SqlDialectAdjuster,
- )
- from materialize.output_consistency.execution.value_storage_layout import (
- ValueStorageLayout,
- )
- from materialize.output_consistency.expression.expression_characteristics import (
- ExpressionCharacteristics,
- )
- from materialize.output_consistency.operation.return_type_spec import ReturnTypeSpec
- from materialize.output_consistency.query.data_source import DataSource
- from materialize.output_consistency.selection.row_selection import (
- ALL_ROWS_SELECTION,
- DataRowSelection,
- )
- from materialize.util import stable_int_hash
class Expression:
    """Common base of all expressions: a `LeafExpression` or an `ExpressionWithArgs`.

    Most methods only declare the interface and raise `NotImplementedError`;
    subclasses are expected to provide the implementation.
    """

    def __init__(
        self,
        characteristics: set[ExpressionCharacteristics],
        storage_layout: ValueStorageLayout,
        is_aggregate: bool,
        is_expect_error: bool,
    ):
        # set to True by recursively_mark_as_shared
        self.is_shared = False
        self.is_expect_error = is_expect_error
        self.is_aggregate = is_aggregate
        self.storage_layout = storage_layout
        # characteristics of this expression only, excluding those of children
        self.own_characteristics = characteristics

    def to_sql(
        self, sql_adjuster: SqlDialectAdjuster, include_alias: bool, is_root_level: bool
    ) -> str:
        """Render this expression as SQL; to be implemented by subclasses."""
        raise NotImplementedError()

    def hash(self) -> int:
        """
        Mainly exists so that conditional breakpoints can be used when debugging.
        """
        raise NotImplementedError()

    def resolve_return_type_spec(self) -> ReturnTypeSpec:
        """
        Ignore filters should prefer #resolve_return_type_category over this method where possible
        because that method resolves dynamic types.
        :return: the return type spec
        """
        raise NotImplementedError()

    def resolve_return_type_category(self) -> DataTypeCategory:
        """
        :return: the data type category of this value
        """
        raise NotImplementedError()

    def resolve_resulting_return_type_category(self) -> DataTypeCategory:
        """
        :return: the data type category that the use of this value will lead to
                 (by default the same as #resolve_return_type_category)
        """
        category = self.resolve_return_type_category()
        return category

    def try_resolve_exact_data_type(self) -> DataType | None:
        """Return the exact data type or None; to be implemented by subclasses."""
        raise NotImplementedError()

    def recursively_collect_involved_characteristics(
        self, row_selection: DataRowSelection
    ) -> set[ExpressionCharacteristics]:
        """Collect all involved characteristics through recursion"""
        raise NotImplementedError()

    def collect_leaves(self) -> list[LeafExpression]:
        """Return the leaf expressions; to be implemented by subclasses."""
        raise NotImplementedError()

    def collect_data_sources(self) -> list[DataSource]:
        """Return the data sources of all leaves, skipping leaves without a data source.

        The result follows the order of #collect_leaves and is not de-duplicated.
        """
        sources_of_leaves = (leaf.get_data_source() for leaf in self.collect_leaves())
        return [source for source in sources_of_leaves if source is not None]

    def collect_vertical_table_indices(self) -> set[int]:
        raise NotImplementedError()

    def __str__(self) -> str:
        raise NotImplementedError()

    def has_all_characteristics(
        self,
        characteristics: set[ExpressionCharacteristics],
        recursive: bool = True,
        row_selection: DataRowSelection = ALL_ROWS_SELECTION,
    ) -> bool:
        """True if every one of the given characteristics is present.

        With `recursive` (the default), the characteristics collected by
        #recursively_collect_involved_characteristics for `row_selection` are considered;
        otherwise only the own characteristics of this expression are considered.
        """
        if recursive:
            available = self.recursively_collect_involved_characteristics(row_selection)
        else:
            available = self.own_characteristics

        # all requested characteristics must be contained in the available ones
        return available >= characteristics

    def has_any_characteristic(
        self,
        characteristics: set[ExpressionCharacteristics],
        recursive: bool = True,
        row_selection: DataRowSelection = ALL_ROWS_SELECTION,
    ) -> bool:
        """True if at least one of the given characteristics is present.

        With `recursive` (the default), the characteristics collected by
        #recursively_collect_involved_characteristics for `row_selection` are considered;
        otherwise only the own characteristics of this expression are considered.
        """
        if recursive:
            available = self.recursively_collect_involved_characteristics(row_selection)
        else:
            available = self.own_characteristics

        return bool(available & characteristics)

    def is_leaf(self) -> bool:
        raise NotImplementedError()

    def contains_leaf_not_directly_consumed_by_aggregation(self) -> bool:
        """
        True if at least one leaf is not directly consumed by an aggregation,
        and therefore False if every leaf of this expression is directly consumed by an aggregation.
        This matters because applying non-aggregate functions to multiple rows may produce different error
        messages under different evaluation strategies due to a different row processing order.
        """
        raise NotImplementedError()

    def matches(
        self, predicate: Callable[[Expression], bool], apply_recursively: bool
    ) -> bool:
        """Apply the predicate to this expression.

        `apply_recursively` is not used here; the recursion is implemented in ExpressionWithArgs.
        """
        return predicate(self)

    def contains(
        self, predicate: Callable[[Expression], bool], check_recursively: bool
    ) -> bool:
        """Delegate to #matches with the same predicate and recursion flag."""
        return self.matches(predicate, check_recursively)

    def recursively_mark_as_shared(self) -> None:
        """
        Flag this expression as being used multiple times within a query.
        All instances will use the same data source.
        """
        self.is_shared = True
class LeafExpression(Expression):
    """An expression identified by a column name, with a data type and an optional data source.

    It reports itself as a leaf and is the only element of its own leaf collection.
    """

    def __init__(
        self,
        column_name: str,
        data_type: DataType,
        characteristics: set[ExpressionCharacteristics],
        storage_layout: ValueStorageLayout,
        data_source: DataSource | None,
        is_aggregate: bool = False,
        is_expect_error: bool = False,
    ):
        super().__init__(
            characteristics=characteristics,
            storage_layout=storage_layout,
            is_aggregate=is_aggregate,
            is_expect_error=is_expect_error,
        )
        self.data_source = data_source
        self.data_type = data_type
        self.column_name = column_name

    def hash(self) -> int:
        """Return a stable integer hash derived from the column name."""
        return stable_int_hash(self.column_name)

    # NOTE(review): the base class declares resolve_return_type_category, which this
    # differently named method does not override - confirm that this is intended.
    def resolve_data_type_category(self) -> DataTypeCategory:
        """Return the category of the data type of this leaf."""
        return self.data_type.category

    def try_resolve_exact_data_type(self) -> DataType | None:
        """Return the data type of this leaf."""
        return self.data_type

    def to_sql(
        self, sql_adjuster: SqlDialectAdjuster, include_alias: bool, is_root_level: bool
    ) -> str:
        """Render this leaf as a column reference using its own column name and data source."""
        own_column_name = self.column_name
        own_data_source = self.get_data_source()
        return self.to_sql_as_column(
            sql_adjuster, include_alias, own_column_name, own_data_source
        )

    def to_sql_as_column(
        self,
        sql_adjuster: SqlDialectAdjuster,
        include_alias: bool,
        column_name: str,
        data_source: DataSource | None,
    ) -> str:
        """Render a column reference, qualified with the data source alias if requested.

        A data source must be given when `include_alias` is set.
        """
        if not include_alias:
            return column_name

        assert data_source is not None, "data source is None"
        return f"{data_source.alias()}.{column_name}"

    def collect_leaves(self) -> list[LeafExpression]:
        """Return a list containing only this leaf."""
        return [self]

    def is_leaf(self) -> bool:
        return True

    def contains_leaf_not_directly_consumed_by_aggregation(self) -> bool:
        # The leaf level cannot decide this.
        return False

    def recursively_collect_involved_characteristics(
        self, row_selection: DataRowSelection
    ) -> set[ExpressionCharacteristics]:
        """Return the own characteristics; `row_selection` is not used at this level."""
        return self.own_characteristics

    def get_data_source(self) -> DataSource | None:
        return self.data_source

    def get_source_column_identifier(self) -> SourceColumnIdentifier | None:
        """Return the identifier built from data source alias and column name.

        :return: None if this leaf has no data source
        """
        source = self.get_data_source()
        return (
            None
            if source is None
            else SourceColumnIdentifier(
                data_source_alias=source.alias(),
                column_name=self.column_name,
            )
        )
|