123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311 |
- # Copyright Materialize, Inc. and contributors. All rights reserved.
- #
- # Use of this software is governed by the Business Source License
- # included in the LICENSE file at the root of this repository.
- #
- # As of the Change Date specified in that file, in accordance with
- # the Business Source License, use of this software will be governed
- # by the Apache License, Version 2.0.
- from enum import Enum
- from materialize.mz_version import MzVersion
- from materialize.output_consistency.data_type.data_type import DataType
- from materialize.output_consistency.expression.expression import Expression
- from materialize.output_consistency.expression.expression_characteristics import (
- ExpressionCharacteristics,
- )
- from materialize.output_consistency.operation.operation_args_validator import (
- OperationArgsValidator,
- )
- from materialize.output_consistency.operation.operation_param import OperationParam
- from materialize.output_consistency.operation.return_type_spec import ReturnTypeSpec
- EXPRESSION_PLACEHOLDER = "$"  # marker for one argument slot in an operation pattern; DbOperation counts its occurrences against the number of params
class OperationRelevance(Enum):
    """Relevance level that can be assigned to an operation or function.

    `DbOperationOrFunction` and its subclasses accept one of these members
    through their `relevance` parameter, defaulting to `DEFAULT`.
    NOTE(review): how the levels are consumed (e.g. weighting during
    selection) is not visible in this module -- confirm against the callers.
    """

    # for testing
    EXTREME_HIGH = 1

    HIGH = 2
    DEFAULT = 3
    LOW = 4
class DbOperationOrFunction:
    """Base class of `DbOperation` and `DbFunction`.

    Stores the parameter definitions, the accepted argument-count range, the
    return type specification and assorted metadata flags. Subclasses must
    implement `to_pattern`, `__str__` and `operation_type_name`.
    """

    def __init__(
        self,
        params: list[OperationParam],
        min_param_count: int,
        max_param_count: int,
        return_type_spec: ReturnTypeSpec,
        args_validators: set[OperationArgsValidator] | None = None,
        is_aggregation: bool = False,
        is_table_function: bool = False,
        relevance: OperationRelevance = OperationRelevance.DEFAULT,
        comment: str | None = None,
        is_enabled: bool = True,
        is_pg_compatible: bool = True,
        tags: set[str] | None = None,
        since_mz_version: MzVersion | None = None,
    ):
        """
        :param params: parameter definitions, in positional order
        :param min_param_count: smallest accepted number of arguments (inclusive)
        :param max_param_count: largest accepted number of arguments (inclusive)
        :param return_type_spec: specification of the return type
        :param args_validators: validators consulted by `is_expected_to_cause_db_error`;
                                `None` is replaced by a fresh empty set
        :param is_enabled: an operation should only be disabled if its execution causes problems;
                           if it just fails, it should be ignored
        :param tags: optional tags queried through `is_tagged`
        """
        self.params = params
        self.min_param_count = min_param_count
        self.max_param_count = max_param_count
        self.return_type_spec = return_type_spec
        # A new set per instance avoids sharing a mutable default between instances.
        self.args_validators: set[OperationArgsValidator] = (
            set() if args_validators is None else args_validators
        )
        self.is_aggregation = is_aggregation
        self.is_table_function = is_table_function
        self.relevance = relevance
        self.comment = comment
        self.is_enabled = is_enabled
        self.is_pg_compatible = is_pg_compatible
        self.tags = tags
        self.since_mz_version = since_mz_version
        self.added_characteristics: set[ExpressionCharacteristics] = set()

    def to_pattern(self, args_count: int) -> str:
        """Return the pattern for `args_count` arguments; implemented by subclasses."""
        raise NotImplementedError

    def validate_args_count_in_range(self, args_count: int) -> None:
        """Ensure `args_count` lies within `[min_param_count, max_param_count]`.

        :raises RuntimeError: if the count is below the minimum or above the maximum
        """
        if args_count < self.min_param_count:
            raise RuntimeError(
                f"Too few arguments (got {args_count}, expected at least {self.min_param_count})"
            )
        if args_count > self.max_param_count:
            raise RuntimeError(
                f"Too many arguments (got {args_count}, expected at most {self.max_param_count})"
            )

    def derive_characteristics(
        self, args: list[Expression]
    ) -> set[ExpressionCharacteristics]:
        """Return the characteristics in `added_characteristics`.

        This base implementation ignores `args` and returns the stored set
        itself (not a copy).
        """
        return self.added_characteristics

    def __str__(self) -> str:
        raise NotImplementedError

    def operation_type_name(self) -> str:
        """Return a short name of the kind of operation; implemented by subclasses."""
        raise NotImplementedError

    def to_description(self, param_count: int) -> str:
        """Return the type name followed by the quoted pattern with named parameters.

        `param_count` must lie within the accepted argument-count range.
        """
        assert self.min_param_count <= param_count <= self.max_param_count
        return f"{self.operation_type_name()} '{self._to_pattern_with_named_params(param_count)}'"

    def _to_pattern_with_named_params(self, param_count: int) -> str:
        """Return the pattern with each placeholder replaced by its parameter's class name.

        Placeholders are substituted left to right, one per parameter. Only
        the first `param_count` parameters are considered; placeholders without
        a matching parameter definition are left untouched instead of raising.
        """
        pattern = self.to_pattern(param_count)
        for param in self.params[:param_count]:
            # count=1: consume exactly the next unreplaced placeholder
            pattern = pattern.replace(EXPRESSION_PLACEHOLDER, type(param).__name__, 1)
        return pattern

    def is_tagged(self, tag: str) -> bool:
        """Return whether `tag` is among the tags; `False` when no tags are set."""
        if self.tags is None:
            return False
        return tag in self.tags

    def try_resolve_exact_data_type(self, args: list[Expression]) -> DataType | None:
        """Return the exact data type for `args` if known; this base implementation returns `None`."""
        return None

    def is_expected_to_cause_db_error(self, args: list[Expression]) -> bool:
        """Check incompatibilities (e.g., division by zero) and potential error
        scenarios (e.g., addition of two max data_type).

        Returns `True` if any validator expects an error for `args` or if any
        parameter does not support its corresponding argument.

        :raises RuntimeError: if the number of arguments is out of range
        """
        self.validate_args_count_in_range(len(args))

        if any(
            validator.is_expected_to_cause_error(args)
            for validator in self.args_validators
        ):
            return True

        # zip stops at the shorter sequence, so omitted optional params are skipped
        return any(
            not param.supports_expression(arg)
            for param, arg in zip(self.params, args)
        )

    def count_variants(self) -> int:
        """Return how many different argument counts are accepted."""
        return self.max_param_count - self.min_param_count + 1
class DbOperation(DbOperationOrFunction):
    """A database operation (e.g., `a + b`).

    The operation is described by a pattern string in which every argument
    slot is marked with `EXPRESSION_PLACEHOLDER`. The number of arguments is
    fixed: minimum and maximum parameter count both equal `len(params)`.
    """

    def __init__(
        self,
        pattern: str,
        params: list[OperationParam],
        return_type_spec: ReturnTypeSpec,
        args_validators: set[OperationArgsValidator] | None = None,
        relevance: OperationRelevance = OperationRelevance.DEFAULT,
        comment: str | None = None,
        is_enabled: bool = True,
        is_pg_compatible: bool = True,
        tags: set[str] | None = None,
        since_mz_version: MzVersion | None = None,
    ):
        """
        :param pattern: pattern containing one placeholder per parameter
        :raises RuntimeError: if the placeholder count differs from the number of params
        """
        param_count = len(params)
        super().__init__(
            params,
            min_param_count=param_count,
            max_param_count=param_count,
            return_type_spec=return_type_spec,
            args_validators=args_validators,
            # operations are never registered as aggregations
            is_aggregation=False,
            relevance=relevance,
            comment=comment,
            is_enabled=is_enabled,
            is_pg_compatible=is_pg_compatible,
            tags=tags,
            since_mz_version=since_mz_version,
        )
        self.pattern = pattern

        placeholder_count = self.pattern.count(EXPRESSION_PLACEHOLDER)
        if placeholder_count != param_count:
            raise RuntimeError(
                f"Operation has pattern {self.pattern} but has only {param_count} parameters"
            )

    def to_pattern(self, args_count: int) -> str:
        """Return the pattern wrapped in parentheses after validating `args_count`."""
        self.validate_args_count_in_range(args_count)
        # wrap in parentheses
        return "(" + self.pattern + ")"

    def __str__(self) -> str:
        suffix = "" if self.comment is None else f" (comment: {self.comment})"
        return f"DbOperation: {self.pattern}{suffix}"

    def operation_type_name(self) -> str:
        return "operation"
class DbFunction(DbOperationOrFunction):
    """A database function (e.g., `SUM(x)`).

    Optional parameters must come last. The minimum argument count is the
    index of the first optional parameter, the maximum is `len(params)`.
    The function name is stored in lower case.
    """

    def __init__(
        self,
        function_name: str,
        params: list[OperationParam],
        return_type_spec: ReturnTypeSpec,
        args_validators: set[OperationArgsValidator] | None = None,
        is_aggregation: bool = False,
        is_table_function: bool = False,
        relevance: OperationRelevance = OperationRelevance.DEFAULT,
        comment: str | None = None,
        is_enabled: bool = True,
        is_pg_compatible: bool = True,
        tags: set[str] | None = None,
        since_mz_version: MzVersion | None = None,
    ):
        """
        :raises RuntimeError: if a non-optional parameter follows an optional one
        """
        self.validate_params(params)
        required_count = self.get_min_param_count(params)
        total_count = len(params)
        super().__init__(
            params,
            min_param_count=required_count,
            max_param_count=total_count,
            return_type_spec=return_type_spec,
            args_validators=args_validators,
            is_aggregation=is_aggregation,
            is_table_function=is_table_function,
            relevance=relevance,
            comment=comment,
            is_enabled=is_enabled,
            is_pg_compatible=is_pg_compatible,
            tags=tags,
            since_mz_version=since_mz_version,
        )
        self.function_name_in_lower_case = function_name.lower()

    def validate_params(self, params: list[OperationParam]) -> None:
        """Ensure that no mandatory parameter follows an optional one.

        :raises RuntimeError: if the ordering rule is violated
        """
        seen_optional = False
        for candidate in params:
            if candidate.optional:
                seen_optional = True
            elif seen_optional:
                raise RuntimeError("Optional parameters must be at the end")

    def get_min_param_count(self, params: list[OperationParam]) -> int:
        """Return the position of the first optional parameter, or `len(params)` if there is none."""
        return next(
            (position for position, candidate in enumerate(params) if candidate.optional),
            len(params),
        )

    def to_pattern(self, args_count: int) -> str:
        """Return `name($, $, ...)` with one placeholder per argument after validating `args_count`."""
        self.validate_args_count_in_range(args_count)
        placeholders = ", ".join("$" for _ in range(args_count))
        return f"{self.function_name_in_lower_case}({placeholders})"

    def __str__(self) -> str:
        suffix = "" if self.comment is None else f" (comment: {self.comment})"
        return f"DbFunction: {self.function_name_in_lower_case}{suffix}"

    def operation_type_name(self) -> str:
        return "function"
class DbFunctionWithCustomPattern(DbFunction):
    """A database function whose pattern is given explicitly per argument count.

    `pattern_per_param_count` maps a number of arguments to the pattern to use
    for that count. The accepted argument-count range is taken from the
    smallest and largest key of that mapping, replacing the range derived
    from `params` by `DbFunction`.
    """

    def __init__(
        self,
        function_name: str,
        pattern_per_param_count: dict[int, str],
        params: list[OperationParam],
        return_type_spec: ReturnTypeSpec,
        args_validators: set[OperationArgsValidator] | None = None,
        is_aggregation: bool = False,
        is_table_function: bool = False,
        relevance: OperationRelevance = OperationRelevance.DEFAULT,
        comment: str | None = None,
        is_enabled: bool = True,
        tags: set[str] | None = None,
        is_pg_compatible: bool = True,
        since_mz_version: MzVersion | None = None,
    ):
        """
        :param pattern_per_param_count: pattern to use for each supported argument count;
                                        must not be empty
        :param is_pg_compatible: forwarded to `DbFunction` (appended after `tags` so that
                                 existing positional callers keep working)
        :param since_mz_version: forwarded to `DbFunction`
        :raises ValueError: if `pattern_per_param_count` is empty
        :raises RuntimeError: if a non-optional parameter follows an optional one
        """
        if not pattern_per_param_count:
            # min()/max() below would fail with an unhelpful message otherwise
            raise ValueError(
                f"No patterns specified for function {function_name.lower()}"
            )

        super().__init__(
            function_name,
            params,
            return_type_spec,
            args_validators=args_validators,
            is_aggregation=is_aggregation,
            is_table_function=is_table_function,
            relevance=relevance,
            comment=comment,
            is_enabled=is_enabled,
            is_pg_compatible=is_pg_compatible,
            tags=tags,
            since_mz_version=since_mz_version,
        )
        self.pattern_per_param_count = pattern_per_param_count
        # override the range computed from the params with the range of the patterns
        self.min_param_count = min(pattern_per_param_count)
        self.max_param_count = max(pattern_per_param_count)

    def to_pattern(self, args_count: int) -> str:
        """Return the pattern registered for `args_count` arguments.

        :raises RuntimeError: if `args_count` is out of range or has no pattern
        """
        self.validate_args_count_in_range(args_count)

        # the keys need not be contiguous, so a count within range may still lack a pattern
        if args_count not in self.pattern_per_param_count:
            raise RuntimeError(
                f"No pattern specified for {self.function_name_in_lower_case} with {args_count} params"
            )

        return self.pattern_per_param_count[args_count]

    def count_variants(self) -> int:
        """Return the number of argument counts for which a pattern exists."""
        return len(self.pattern_per_param_count)
def match_function_by_name(
    op: DbOperationOrFunction, function_name_in_lower_case: str
) -> bool:
    """Return whether `op` is a `DbFunction` with exactly the given lower-case name.

    Anything that is not a `DbFunction` never matches.
    """
    if not isinstance(op, DbFunction):
        return False

    return op.function_name_in_lower_case == function_name_in_lower_case
|