expression.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from __future__ import annotations
  10. from collections.abc import Callable
  11. from materialize.output_consistency.data_type.data_type import DataType
  12. from materialize.output_consistency.data_type.data_type_category import DataTypeCategory
  13. from materialize.output_consistency.data_value.source_column_identifier import (
  14. SourceColumnIdentifier,
  15. )
  16. from materialize.output_consistency.execution.sql_dialect_adjuster import (
  17. SqlDialectAdjuster,
  18. )
  19. from materialize.output_consistency.execution.value_storage_layout import (
  20. ValueStorageLayout,
  21. )
  22. from materialize.output_consistency.expression.expression_characteristics import (
  23. ExpressionCharacteristics,
  24. )
  25. from materialize.output_consistency.operation.return_type_spec import ReturnTypeSpec
  26. from materialize.output_consistency.query.data_source import DataSource
  27. from materialize.output_consistency.selection.row_selection import (
  28. ALL_ROWS_SELECTION,
  29. DataRowSelection,
  30. )
  31. from materialize.util import stable_int_hash
  32. class Expression:
  33. """An expression is either a `LeafExpression` or a `ExpressionWithArgs`"""
  34. def __init__(
  35. self,
  36. characteristics: set[ExpressionCharacteristics],
  37. storage_layout: ValueStorageLayout,
  38. is_aggregate: bool,
  39. is_expect_error: bool,
  40. ):
  41. # own characteristics without the ones of children
  42. self.own_characteristics = characteristics
  43. self.storage_layout = storage_layout
  44. self.is_aggregate = is_aggregate
  45. self.is_expect_error = is_expect_error
  46. self.is_shared = False
  47. def to_sql(
  48. self, sql_adjuster: SqlDialectAdjuster, include_alias: bool, is_root_level: bool
  49. ) -> str:
  50. raise NotImplementedError
  51. def hash(self) -> int:
  52. """
  53. The primary purpose of this method is to allow conditional breakpoints when debugging.
  54. """
  55. raise NotImplementedError
  56. def resolve_return_type_spec(self) -> ReturnTypeSpec:
  57. """
  58. Ignore filters should favor #resolve_return_type_category over this method whenever possible because that method
  59. resolves dynamic types.
  60. :return: the return type spec
  61. """
  62. raise NotImplementedError
  63. def resolve_return_type_category(self) -> DataTypeCategory:
  64. """
  65. :return: the data type category of this value
  66. """
  67. raise NotImplementedError
  68. def resolve_resulting_return_type_category(self) -> DataTypeCategory:
  69. """
  70. :return: the data type category that the use of this value will lead to
  71. """
  72. return self.resolve_return_type_category()
  73. def try_resolve_exact_data_type(self) -> DataType | None:
  74. raise NotImplementedError
  75. def recursively_collect_involved_characteristics(
  76. self, row_selection: DataRowSelection
  77. ) -> set[ExpressionCharacteristics]:
  78. """Get all involved characteristics through recursion"""
  79. raise NotImplementedError
  80. def collect_leaves(self) -> list[LeafExpression]:
  81. raise NotImplementedError
  82. def collect_data_sources(self) -> list[DataSource]:
  83. data_sources = []
  84. for leaf in self.collect_leaves():
  85. data_source = leaf.get_data_source()
  86. if data_source is not None:
  87. data_sources.append(data_source)
  88. return data_sources
  89. def collect_vertical_table_indices(self) -> set[int]:
  90. raise NotImplementedError
  91. def __str__(self) -> str:
  92. raise NotImplementedError
  93. def has_all_characteristics(
  94. self,
  95. characteristics: set[ExpressionCharacteristics],
  96. recursive: bool = True,
  97. row_selection: DataRowSelection = ALL_ROWS_SELECTION,
  98. ) -> bool:
  99. """True if this expression itself exhibits all characteristics."""
  100. present_characteristics = (
  101. self.own_characteristics
  102. if not recursive
  103. else self.recursively_collect_involved_characteristics(row_selection)
  104. )
  105. overlap = present_characteristics & characteristics
  106. return len(overlap) == len(characteristics)
  107. def has_any_characteristic(
  108. self,
  109. characteristics: set[ExpressionCharacteristics],
  110. recursive: bool = True,
  111. row_selection: DataRowSelection = ALL_ROWS_SELECTION,
  112. ) -> bool:
  113. """True if this expression itself exhibits any of the characteristics."""
  114. present_characteristics = (
  115. self.own_characteristics
  116. if not recursive
  117. else self.recursively_collect_involved_characteristics(row_selection)
  118. )
  119. overlap = present_characteristics & characteristics
  120. return len(overlap) > 0
  121. def is_leaf(self) -> bool:
  122. raise NotImplementedError
  123. def contains_leaf_not_directly_consumed_by_aggregation(self) -> bool:
  124. """
  125. True if any leaf is not directly consumed by an aggregation,
  126. hence false if all leaves of this expression are directly consumed by an aggregation.
  127. This is relevant because when using non-aggregate functions on multiple rows, different evaluation strategies may yield different error messages due to a different row processing order.
  128. """
  129. raise NotImplementedError
  130. def matches(
  131. self, predicate: Callable[[Expression], bool], apply_recursively: bool
  132. ) -> bool:
  133. # recursion is implemented in ExpressionWithArgs
  134. return predicate(self)
  135. def contains(
  136. self, predicate: Callable[[Expression], bool], check_recursively: bool
  137. ) -> bool:
  138. return self.matches(predicate, check_recursively)
  139. def recursively_mark_as_shared(self) -> None:
  140. """
  141. Mark that this expression is used multiple times within a query.
  142. All instances will use the same data source.
  143. """
  144. self.is_shared = True
  145. class LeafExpression(Expression):
  146. def __init__(
  147. self,
  148. column_name: str,
  149. data_type: DataType,
  150. characteristics: set[ExpressionCharacteristics],
  151. storage_layout: ValueStorageLayout,
  152. data_source: DataSource | None,
  153. is_aggregate: bool = False,
  154. is_expect_error: bool = False,
  155. ):
  156. super().__init__(characteristics, storage_layout, is_aggregate, is_expect_error)
  157. self.column_name = column_name
  158. self.data_type = data_type
  159. self.data_source = data_source
  160. def hash(self) -> int:
  161. return stable_int_hash(self.column_name)
  162. def resolve_data_type_category(self) -> DataTypeCategory:
  163. return self.data_type.category
  164. def try_resolve_exact_data_type(self) -> DataType | None:
  165. return self.data_type
  166. def to_sql(
  167. self, sql_adjuster: SqlDialectAdjuster, include_alias: bool, is_root_level: bool
  168. ) -> str:
  169. return self.to_sql_as_column(
  170. sql_adjuster, include_alias, self.column_name, self.get_data_source()
  171. )
  172. def to_sql_as_column(
  173. self,
  174. sql_adjuster: SqlDialectAdjuster,
  175. include_alias: bool,
  176. column_name: str,
  177. data_source: DataSource | None,
  178. ) -> str:
  179. if include_alias:
  180. assert data_source is not None, "data source is None"
  181. return f"{data_source.alias()}.{column_name}"
  182. return column_name
  183. def collect_leaves(self) -> list[LeafExpression]:
  184. return [self]
  185. def is_leaf(self) -> bool:
  186. return True
  187. def contains_leaf_not_directly_consumed_by_aggregation(self) -> bool:
  188. # This is not decided at leaf level.
  189. return False
  190. def recursively_collect_involved_characteristics(
  191. self, row_selection: DataRowSelection
  192. ) -> set[ExpressionCharacteristics]:
  193. return self.own_characteristics
  194. def get_data_source(self) -> DataSource | None:
  195. return self.data_source
  196. def get_source_column_identifier(self) -> SourceColumnIdentifier | None:
  197. data_source = self.get_data_source()
  198. if data_source is None:
  199. return None
  200. return SourceColumnIdentifier(
  201. data_source_alias=data_source.alias(),
  202. column_name=self.column_name,
  203. )