expression_with_args.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from collections.abc import Callable
  10. from materialize.output_consistency.data_type.data_type import DataType
  11. from materialize.output_consistency.data_type.data_type_category import DataTypeCategory
  12. from materialize.output_consistency.execution.sql_dialect_adjuster import (
  13. SqlDialectAdjuster,
  14. )
  15. from materialize.output_consistency.execution.value_storage_layout import (
  16. ValueStorageLayout,
  17. )
  18. from materialize.output_consistency.expression.expression import (
  19. Expression,
  20. LeafExpression,
  21. )
  22. from materialize.output_consistency.expression.expression_characteristics import (
  23. ExpressionCharacteristics,
  24. )
  25. from materialize.output_consistency.operation.operation import (
  26. EXPRESSION_PLACEHOLDER,
  27. DbOperationOrFunction,
  28. )
  29. from materialize.output_consistency.operation.return_type_spec import (
  30. InputArgTypeHints,
  31. ReturnTypeSpec,
  32. )
  33. from materialize.output_consistency.selection.row_selection import DataRowSelection
  34. from materialize.util import stable_int_hash
  35. class ExpressionWithArgs(Expression):
  36. """An expression representing a usage of a database operation or function"""
  37. def __init__(
  38. self,
  39. operation: DbOperationOrFunction,
  40. args: list[Expression],
  41. is_aggregate: bool = False,
  42. is_expect_error: bool = False,
  43. ):
  44. super().__init__(
  45. operation.derive_characteristics(args),
  46. _determine_storage_layout(args),
  47. is_aggregate,
  48. is_expect_error,
  49. )
  50. self.operation = operation
  51. self.pattern = operation.to_pattern(len(args))
  52. self.return_type_spec = operation.return_type_spec
  53. self.args = args
  54. def hash(self) -> int:
  55. return stable_int_hash(
  56. self.operation.to_pattern(self.count_args()),
  57. *[str(arg.hash()) for arg in self.args],
  58. )
  59. def count_args(self) -> int:
  60. return len(self.args)
  61. def has_args(self) -> bool:
  62. return self.count_args() > 0
  63. def to_sql(
  64. self, sql_adjuster: SqlDialectAdjuster, include_alias: bool, is_root_level: bool
  65. ) -> str:
  66. sql: str = self.pattern
  67. for arg in self.args:
  68. sql = sql.replace(
  69. EXPRESSION_PLACEHOLDER,
  70. arg.to_sql(sql_adjuster, include_alias, False),
  71. 1,
  72. )
  73. if len(self.args) != self.pattern.count(EXPRESSION_PLACEHOLDER):
  74. raise RuntimeError(
  75. f"Not enough arguments to fill all placeholders in pattern {self.pattern}"
  76. )
  77. if (
  78. is_root_level
  79. and self.resolve_return_type_category() == DataTypeCategory.DATE_TIME
  80. ):
  81. # workaround because the max date type in python is smaller than values supported by mz
  82. sql = f"({sql})::TEXT"
  83. return sql
  84. def resolve_return_type_spec(self) -> ReturnTypeSpec:
  85. return self.return_type_spec
  86. def resolve_return_type_category(self) -> DataTypeCategory:
  87. input_type_hints = InputArgTypeHints()
  88. if self.return_type_spec.indices_of_required_input_type_hints is not None:
  89. # provide input types that are required as hints to determine the output type
  90. for arg_index in self.return_type_spec.indices_of_required_input_type_hints:
  91. assert (
  92. 0 <= arg_index <= len(self.args)
  93. ), f"Invalid requested index: {arg_index} as hint for {self.operation}"
  94. input_type_hints.type_category_of_requested_args[arg_index] = self.args[
  95. arg_index
  96. ].resolve_resulting_return_type_category()
  97. if self.return_type_spec.requires_return_type_spec_hints:
  98. input_type_hints.return_type_spec_of_requested_args[arg_index] = (
  99. self.args[arg_index].resolve_return_type_spec()
  100. )
  101. return self.return_type_spec.resolve_type_category(input_type_hints)
  102. def try_resolve_exact_data_type(self) -> DataType | None:
  103. return self.operation.try_resolve_exact_data_type(self.args)
  104. def __str__(self) -> str:
  105. args_desc = ", ".join(arg.__str__() for arg in self.args)
  106. return f"ExpressionWithArgs (pattern='{self.pattern}', args=[{args_desc}])"
  107. def recursively_collect_involved_characteristics(
  108. self, row_selection: DataRowSelection
  109. ) -> set[ExpressionCharacteristics]:
  110. involved_characteristics: set[ExpressionCharacteristics] = set()
  111. involved_characteristics = involved_characteristics.union(
  112. self.own_characteristics
  113. )
  114. for arg in self.args:
  115. involved_characteristics = involved_characteristics.union(
  116. arg.recursively_collect_involved_characteristics(row_selection)
  117. )
  118. return involved_characteristics
  119. def collect_leaves(self) -> list[LeafExpression]:
  120. leaves = []
  121. for arg in self.args:
  122. leaves.extend(arg.collect_leaves())
  123. return leaves
  124. def collect_vertical_table_indices(self) -> set[int]:
  125. vertical_table_indices = set()
  126. for arg in self.args:
  127. vertical_table_indices.update(arg.collect_vertical_table_indices())
  128. return vertical_table_indices
  129. def is_leaf(self) -> bool:
  130. return False
  131. def matches(
  132. self, predicate: Callable[[Expression], bool], apply_recursively: bool
  133. ) -> bool:
  134. if super().matches(predicate, apply_recursively):
  135. return True
  136. if apply_recursively:
  137. for arg in self.args:
  138. if arg.matches(predicate, apply_recursively):
  139. return True
  140. return False
  141. def contains_leaf_not_directly_consumed_by_aggregation(self) -> bool:
  142. for arg in self.args:
  143. if arg.is_leaf() and not self.is_aggregate:
  144. return True
  145. elif (
  146. not arg.is_leaf()
  147. and arg.contains_leaf_not_directly_consumed_by_aggregation()
  148. ):
  149. return True
  150. return False
  151. def operation_to_pattern(self) -> str:
  152. return self.operation.to_pattern(self.count_args())
  153. def recursively_mark_as_shared(self) -> None:
  154. super().recursively_mark_as_shared()
  155. for arg in self.args:
  156. arg.recursively_mark_as_shared()
  157. def _determine_storage_layout(args: list[Expression]) -> ValueStorageLayout:
  158. mutual_storage_layout: ValueStorageLayout | None = None
  159. for arg in args:
  160. if (
  161. mutual_storage_layout is None
  162. or mutual_storage_layout == ValueStorageLayout.ANY
  163. ):
  164. mutual_storage_layout = arg.storage_layout
  165. elif arg.storage_layout == ValueStorageLayout.ANY:
  166. continue
  167. elif mutual_storage_layout != arg.storage_layout:
  168. raise RuntimeError(
  169. f"It is not allowed to mix storage layouts in an expression (current={mutual_storage_layout}, got={arg.storage_layout})"
  170. )
  171. if mutual_storage_layout is None:
  172. # use this as default (in case there are no args)
  173. return ValueStorageLayout.ANY
  174. return mutual_storage_layout