operation_args_validator.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from collections.abc import Callable
  10. from materialize.output_consistency.expression.expression import Expression
  11. from materialize.output_consistency.expression.expression_characteristics import (
  12. ExpressionCharacteristics,
  13. )
  14. class OperationArgsValidator:
  15. """Validator that performs heuristic checks to determine if a database error is to be expected"""
  16. def is_expected_to_cause_error(self, args: list[Expression]) -> bool:
  17. raise NotImplementedError
  18. def index_of(
  19. self,
  20. args: list[Expression],
  21. match_argument_fn: Callable[
  22. [Expression, set[ExpressionCharacteristics], int], bool
  23. ],
  24. skip_argument_indices: set[int] | None = None,
  25. ) -> int:
  26. if skip_argument_indices is None:
  27. skip_argument_indices = set()
  28. for index, arg in enumerate(args):
  29. if index in skip_argument_indices:
  30. continue
  31. if match_argument_fn(arg, arg.own_characteristics, index):
  32. return index
  33. return -1
  34. def index_of_characteristic_combination(
  35. self,
  36. args: list[Expression],
  37. characteristic_combination: set[ExpressionCharacteristics],
  38. skip_argument_indices: set[int] | None = None,
  39. ) -> int:
  40. def match_fn(
  41. _arg: Expression,
  42. arg_characteristics: set[ExpressionCharacteristics],
  43. _index: int,
  44. ) -> bool:
  45. return len(characteristic_combination & arg_characteristics) == len(
  46. characteristic_combination
  47. )
  48. return self.index_of(args, match_fn, skip_argument_indices)