123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567 |
- # Copyright Materialize, Inc. and contributors. All rights reserved.
- #
- # Use of this software is governed by the Business Source License
- # included in the LICENSE file at the root of this repository.
- #
- # As of the Change Date specified in that file, in accordance with
- # the Business Source License, use of this software will be governed
- # by the Apache License, Version 2.0.
- from collections.abc import Callable
- from materialize.output_consistency.common import probability
- from materialize.output_consistency.common.configuration import (
- ConsistencyTestConfiguration,
- )
- from materialize.output_consistency.data_type.data_type_category import DataTypeCategory
- from materialize.output_consistency.data_type.data_type_with_values import (
- DataTypeWithValues,
- )
- from materialize.output_consistency.execution.value_storage_layout import (
- ValueStorageLayout,
- )
- from materialize.output_consistency.expression.expression import (
- Expression,
- LeafExpression,
- )
- from materialize.output_consistency.expression.expression_with_args import (
- ExpressionWithArgs,
- )
- from materialize.output_consistency.generators.arg_context import ArgContext
- from materialize.output_consistency.input_data.operations.equality_operations_provider import (
- EQUALS_OPERATION,
- )
- from materialize.output_consistency.input_data.params.one_of_operation_param import (
- OneOf,
- )
- from materialize.output_consistency.input_data.test_input_data import (
- ConsistencyTestInputData,
- )
- from materialize.output_consistency.operation.operation import (
- DbOperationOrFunction,
- )
- from materialize.output_consistency.operation.operation_param import OperationParam
- from materialize.output_consistency.operation.volatile_data_operation_param import (
- VolatileDataOperationParam,
- )
- from materialize.output_consistency.selection.randomized_picker import RandomizedPicker
# Nesting level of a top-level expression (one that is not an argument of another expression).
NESTING_LEVEL_ROOT: int = 0
# Nesting level of the direct arguments of a top-level expression
# (argument generation runs at the parent's nesting level + 1).
NESTING_LEVEL_OUTERMOST_ARG: int = NESTING_LEVEL_ROOT + 1
# Index of the first parameter/argument of an operation.
FIRST_ARG_INDEX: int = 0
class ExpressionGenerator:
    """Generates expressions based on a random selection of operations.

    Operations and data types come from ``input_data``. Every random decision is
    delegated to ``randomized_picker``: which operation, how many args, which
    storage layout, and whether an argument is a leaf value or a nested
    expression.
    """

    # If the configuration asks to avoid expressions expected to cause a DB error,
    # argument generation is retried as long as try_number <= this value.
    _MAX_TRY_NUMBER_FOR_DB_ERROR_RETRY = 50
    # A nested argument whose type is not supported by the param is regenerated
    # until this many attempts have been made.
    _MAX_ATTEMPTS_FOR_SUPPORTED_COMPLEX_ARG = 5

    def __init__(
        self,
        config: ConsistencyTestConfiguration,
        randomized_picker: RandomizedPicker,
        input_data: ConsistencyTestInputData,
    ):
        self.config = config
        self.randomized_picker = randomized_picker
        self.input_data = input_data
        # Parallel lists: the weights at index i belong to selectable_operations[i].
        self.selectable_operations: list[DbOperationOrFunction] = []
        self.operation_weights: list[float] = []
        # Same as operation_weights, but with weight 0 for aggregation operations.
        self.operation_weights_no_aggregates: list[float] = []
        self.operations_by_return_type_category: dict[
            DataTypeCategory, list[DbOperationOrFunction]
        ] = {}
        self.types_with_values_by_category: dict[
            DataTypeCategory, list[DataTypeWithValues]
        ] = {}

        self._initialize_operations()
        self._initialize_types()

    def _initialize_operations(self) -> None:
        """Fill the selectable operations, their weights and the index by return type category."""
        all_operations = self.input_data.operations_input.all_operation_types
        self.operation_weights = self._get_operation_weights(all_operations)

        for operation, weight in zip(all_operations, self.operation_weights):
            self.selectable_operations.append(operation)
            self.operation_weights_no_aggregates.append(
                0 if operation.is_aggregation else weight
            )
            category = operation.return_type_spec.type_category
            self.operations_by_return_type_category.setdefault(category, []).append(
                operation
            )

    def _initialize_types(self) -> None:
        """Index all data types with values by their data type category."""
        for (
            data_type_with_values
        ) in self.input_data.types_input.all_data_types_with_values:
            category = data_type_with_values.data_type.category
            self.types_with_values_by_category.setdefault(category, []).append(
                data_type_with_values
            )

    def pick_random_operation(
        self,
        include_aggregates: bool,
        accept_op_filter: Callable[[DbOperationOrFunction], bool] | None = None,
    ) -> DbOperationOrFunction:
        """Pick a random operation, weighted by the operations' relevance.

        :param include_aggregates: if False, aggregation operations get weight 0
        :param accept_op_filter: optional predicate; only operations passing it are candidates
        :raises AssertionError: if no operation passes the filter
        """
        all_weights = (
            self.operation_weights
            if include_aggregates
            else self.operation_weights_no_aggregates
        )

        if accept_op_filter is None:
            selected_operations = self.selectable_operations
            weights = all_weights
        else:
            selected_operations = []
            weights = []
            for operation, weight in zip(self.selectable_operations, all_weights):
                if accept_op_filter(operation):
                    selected_operations.append(operation)
                    weights.append(weight)

        assert (
            len(selected_operations) > 0
        ), f"no operations available (include_aggregates={include_aggregates}, accept_op_filter used={accept_op_filter is not None})"
        assert len(selected_operations) == len(weights)

        return self.randomized_picker.random_operation(selected_operations, weights)

    def generate_boolean_expression(
        self,
        use_aggregation: bool,
        storage_layout: ValueStorageLayout | None,
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> ExpressionWithArgs | None:
        """Generate an expression whose operation is declared to return a BOOLEAN value.

        :return: the expression, or None if no arguments could be generated for the picked operation
        """
        return self.generate_expression_for_data_type_category(
            use_aggregation, storage_layout, DataTypeCategory.BOOLEAN, nesting_level
        )

    def generate_expression_for_data_type_category(
        self,
        use_aggregation: bool,
        storage_layout: ValueStorageLayout | None,
        data_type_category: DataTypeCategory,
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> ExpressionWithArgs | None:
        """Generate an expression whose operation is declared to return ``data_type_category``.

        Only operations whose aggregation flag equals ``use_aggregation`` are considered.

        :return: the expression, or None if no arguments could be generated for the picked operation
        """

        def operation_filter(operation: DbOperationOrFunction) -> bool:
            if operation.is_aggregation != use_aggregation:
                return False

            # Simplification: This will only include operations defined to return a value of the
            # requested category but not generic operations that might return such a value
            # depending on the input.
            return operation.return_type_spec.type_category == data_type_category

        return self.generate_expression_with_filter(
            use_aggregation, storage_layout, operation_filter, nesting_level
        )

    def generate_expression_with_filter(
        self,
        use_aggregation: bool,
        storage_layout: ValueStorageLayout | None,
        operation_filter: Callable[[DbOperationOrFunction], bool],
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> ExpressionWithArgs | None:
        """Pick a random operation accepted by ``operation_filter`` and build an expression for it.

        :return: the expression, or None if no arguments could be generated for the picked operation
        """
        operation = self.pick_random_operation(use_aggregation, operation_filter)
        expression, _ = self.generate_expression_for_operation(
            operation, storage_layout, nesting_level
        )
        return expression

    def generate_expression_for_operation(
        self,
        operation: DbOperationOrFunction,
        storage_layout: ValueStorageLayout | None = None,
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> tuple[ExpressionWithArgs | None, int]:
        """
        Build an expression for ``operation`` with a random number of arguments.

        If ``storage_layout`` is None, a layout is chosen randomly (see ``_select_storage_layout``).
        The number of arguments is drawn between the operation's min and max param count. The
        arguments are generated one nesting level deeper than ``nesting_level``.

        :return: the expression or None if it was not possible to create one, and the number of used operation params
        """
        if storage_layout is None:
            storage_layout = self._select_storage_layout(operation)

        number_of_args = self.randomized_picker.random_number(
            operation.min_param_count, operation.max_param_count
        )

        try:
            args = self._generate_args_for_operation(
                operation, number_of_args, storage_layout, nesting_level + 1
            )
        except NoSuitableExpressionFound as ex:
            if self.config.verbose_output:
                print(f"No suitable expression found: {ex.message}")
            return None, number_of_args

        is_aggregate = operation.is_aggregation or self._contains_aggregate_arg(args)
        is_expect_error = operation.is_expected_to_cause_db_error(args)
        expression = ExpressionWithArgs(operation, args, is_aggregate, is_expect_error)
        return expression, number_of_args

    def generate_equals_expression(
        self, arg1: Expression, arg2: Expression
    ) -> ExpressionWithArgs:
        """Build ``arg1 = arg2`` using ``EQUALS_OPERATION``.

        The result is an aggregate if any of the two args is one.
        """
        operation = EQUALS_OPERATION
        args = [arg1, arg2]
        is_aggregate = self._contains_aggregate_arg(args)
        is_expect_error = operation.is_expected_to_cause_db_error(args)
        return ExpressionWithArgs(operation, args, is_aggregate, is_expect_error)

    def generate_leaf_expression(
        self,
        storage_layout: ValueStorageLayout,
        types_with_values: list[DataTypeWithValues],
    ) -> LeafExpression:
        """Create a leaf expression of a randomly picked type from ``types_with_values``.

        VERTICAL layout yields an unassigned vertical storage column of the type; HORIZONTAL
        layout yields a randomly picked raw value of the type.

        :raises NoSuitableExpressionFound: if HORIZONTAL layout is used and the type has no values
        :raises RuntimeError: for any other storage layout
        """
        assert len(types_with_values) > 0, "No suitable types with values"

        type_with_values = self.randomized_picker.random_type_with_values(
            types_with_values
        )

        if storage_layout == ValueStorageLayout.VERTICAL:
            return type_with_values.create_unassigned_vertical_storage_column()
        elif storage_layout == ValueStorageLayout.HORIZONTAL:
            if len(type_with_values.raw_values) == 0:
                raise NoSuitableExpressionFound("No value in type")

            return self.randomized_picker.random_value(type_with_values.raw_values)
        else:
            raise RuntimeError(f"Unsupported storage layout: {storage_layout}")

    def _select_storage_layout(
        self, operation: DbOperationOrFunction
    ) -> ValueStorageLayout:
        """Randomly choose a storage layout, biased by whether the operation aggregates."""
        if not operation.is_aggregation:
            # Prefer the horizontal row format for non-aggregate expressions. (It makes it less likely that a query
            # results in (an unexpected) error. Furthermore, in case of an error, error messages of non-aggregate
            # expressions can only be compared in HORIZONTAL layout (because the row processing order of an
            # evaluation strategy is not defined).)
            if self.randomized_picker.random_boolean(
                probability.HORIZONTAL_LAYOUT_WHEN_NOT_AGGREGATED
            ):
                return ValueStorageLayout.HORIZONTAL
            else:
                return ValueStorageLayout.VERTICAL

        # strongly prefer vertical storage for aggregations but allow some variance
        if self.randomized_picker.random_boolean(
            probability.HORIZONTAL_LAYOUT_WHEN_AGGREGATED
        ):
            # Use horizontal layout in some cases
            return ValueStorageLayout.HORIZONTAL

        return ValueStorageLayout.VERTICAL

    def _contains_aggregate_arg(self, args: list[Expression]) -> bool:
        """Return True if at least one of ``args`` is an aggregate expression."""
        return any(arg.is_aggregate for arg in args)

    def _generate_args_for_operation(
        self,
        operation: DbOperationOrFunction,
        number_of_args: int,
        storage_layout: ValueStorageLayout,
        nesting_level: int,
        try_number: int = 1,
    ) -> list[Expression]:
        """Generate ``number_of_args`` arguments for the first params of ``operation``.

        If ``config.avoid_expressions_expecting_db_error`` is set, a complete argument list
        for which the operation is expected to cause a DB error is discarded and regenerated,
        while ``try_number`` does not exceed the retry limit. After that, the last list is
        returned as is. (Implemented as a loop instead of recursion to keep the stack flat.)

        :raises NoSuitableExpressionFound: propagated from argument generation
        """
        if number_of_args == 0:
            return []

        while True:
            arg_context = ArgContext()

            for arg_index in range(FIRST_ARG_INDEX, number_of_args):
                param = operation.params[arg_index]
                # nesting_level was already incremented before invoking this function
                arg = self._generate_arg_for_param(
                    operation,
                    param,
                    storage_layout,
                    arg_context,
                    nesting_level,
                )
                arg_context.append(arg)

            if (
                self.config.avoid_expressions_expecting_db_error
                and try_number <= self._MAX_TRY_NUMBER_FOR_DB_ERROR_RETRY
                and operation.is_expected_to_cause_db_error(arg_context.args)
            ):
                # retry with a freshly generated set of args
                try_number += 1
                continue

            return arg_context.args

    def _generate_arg_for_param(
        self,
        operation: DbOperationOrFunction,
        param: OperationParam,
        storage_layout: ValueStorageLayout,
        arg_context: ArgContext,
        nesting_level: int,
    ) -> Expression:
        """Generate a single argument for ``param``.

        The argument is either a leaf value or a nested expression. A nested expression is
        forced when the context requires aggregation; otherwise it is chosen with probability
        ``CREATE_COMPLEX_EXPRESSION``.
        """
        # this one must be at the top: the alternative picked from OneOf is what the checks below apply to
        if isinstance(param, OneOf):
            param = param.pick(self.randomized_picker)

        if isinstance(param, VolatileDataOperationParam):
            return param.generate_expression(arg_context, self.randomized_picker)

        create_complex_arg = (
            arg_context.requires_aggregation()
            or self.randomized_picker.random_boolean(
                probability.CREATE_COMPLEX_EXPRESSION
            )
        )

        if create_complex_arg:
            return self._generate_complex_arg_for_param(
                param,
                storage_layout,
                arg_context,
                operation.is_aggregation,
                nesting_level,
            )
        else:
            return self._generate_simple_arg_for_param(
                param, arg_context, storage_layout
            )

    def _generate_simple_arg_for_param(
        self,
        param: OperationParam,
        arg_context: ArgContext,
        storage_layout: ValueStorageLayout,
    ) -> LeafExpression:
        """Generate a leaf argument of a type suitable for ``param``.

        :raises NoSuitableExpressionFound: if no type is suitable or (HORIZONTAL) the picked type has no values
        """
        # only consider the data type category, do not check incompatibilities and other validations at this point
        suitable_types_with_values = self._get_data_type_values_of_category(
            param, arg_context
        )

        if len(suitable_types_with_values) == 0:
            raise NoSuitableExpressionFound("No suitable type")

        return self.generate_leaf_expression(storage_layout, suitable_types_with_values)

    def _generate_complex_arg_for_param(
        self,
        param: OperationParam,
        storage_layout: ValueStorageLayout,
        arg_context: ArgContext,
        is_aggregation_operation: bool,
        nesting_level: int,
        try_number: int = 1,
    ) -> ExpressionWithArgs:
        """Generate a nested expression as argument for ``param``.

        If the generated expression's type is not supported by ``param``, a new one is
        generated until the attempt limit is reached. (Implemented as a loop instead of
        recursion; every attempt re-evaluates the same steps as before.)

        :raises NoSuitableExpressionFound: if no operation fits, the nested expression cannot be
            built, or no supported expression was found within the attempt limit
        """
        while True:
            must_use_aggregation = arg_context.requires_aggregation()

            # currently allow an aggregation function as argument if all applies:
            # * the operation is not an aggregation (nested aggregations are impossible)
            # * it is the first param (all subsequent params will then require aggregation)
            # * we are not already nested (to avoid nested aggregations spread across several levels)
            allow_aggregation = must_use_aggregation or (
                not is_aggregation_operation
                and arg_context.has_no_args()
                and nesting_level == NESTING_LEVEL_OUTERMOST_ARG
            )
            suitable_operations = self._get_operations_of_category(
                param, arg_context, must_use_aggregation, allow_aggregation
            )

            if len(suitable_operations) == 0:
                raise NoSuitableExpressionFound(
                    f"No suitable operation for {param}"
                    f" (layout={storage_layout},"
                    f" allow_aggregation={allow_aggregation},"
                    f" must_use_aggregation={must_use_aggregation})"
                )

            weights = self._get_operation_weights(suitable_operations)
            operation = self.randomized_picker.random_operation(
                suitable_operations, weights
            )

            nested_expression, _ = self.generate_expression_for_operation(
                operation, storage_layout, nesting_level
            )

            if nested_expression is None:
                raise NoSuitableExpressionFound(
                    f"No nested expression for {param} in {storage_layout}"
                )

            data_type = nested_expression.try_resolve_exact_data_type()
            is_unsupported = (
                data_type is not None
                and not param.supports_type(data_type, arg_context.args)
            ) or not param.might_support_type_as_input_assuming_category_matches(
                nested_expression.operation.return_type_spec
            )

            if not is_unsupported:
                return nested_expression

            if try_number >= self._MAX_ATTEMPTS_FOR_SUPPORTED_COMPLEX_ARG:
                raise NoSuitableExpressionFound("No supported data type")

            try_number += 1

    def _get_data_type_values_of_category(
        self, param: OperationParam, arg_context: ArgContext
    ) -> list[DataTypeWithValues]:
        """Return the types with values of the param's resolved category that the param supports.

        For category ANY, all types with values are returned without a support check.
        """
        category = param.resolve_type_category(arg_context.args)
        if category == DataTypeCategory.ANY:
            return self.input_data.types_input.all_data_types_with_values

        self._assert_valid_type_category_for_param(param, category)

        preselected_types_with_values = self.types_with_values_by_category.get(
            category, []
        )
        return [
            type_with_values
            for type_with_values in preselected_types_with_values
            if param.supports_type(type_with_values.data_type, arg_context.args)
        ]

    def _assert_valid_type_category_for_param(
        self, param: OperationParam, category: DataTypeCategory
    ) -> None:
        """Assert that ``category`` is a category a param may resolve to (i.e. not DYNAMIC)."""
        assert category not in {
            DataTypeCategory.DYNAMIC,
        }, f"Type category {category} not allowed for parameters (param={param})"

    def _get_operations_of_category(
        self,
        param: OperationParam,
        arg_context: ArgContext,
        must_use_aggregation: bool,
        allow_aggregation: bool,
    ) -> list[DbOperationOrFunction]:
        """Return operations returning the param's category, restricted by the aggregation flags."""
        category = param.resolve_type_category(arg_context.args)
        suitable_operations = self._get_all_operations_of_category(param, category)

        if must_use_aggregation:
            return self._get_only_aggregate_operations(suitable_operations)
        elif not allow_aggregation:
            return self._get_without_aggregate_operations(suitable_operations)
        else:
            return suitable_operations

    def _get_all_operations_of_category(
        self, param: OperationParam, category: DataTypeCategory
    ) -> list[DbOperationOrFunction]:
        """Return all operations declared to return ``category`` (all operations for ANY)."""
        if category == DataTypeCategory.ANY:
            return self.input_data.operations_input.all_operation_types

        self._assert_valid_type_category_for_param(param, category)

        return self.operations_by_return_type_category.get(category, [])

    def _get_without_aggregate_operations(
        self, operations: list[DbOperationOrFunction]
    ) -> list[DbOperationOrFunction]:
        """Return the non-aggregation operations of ``operations``."""
        return self._get_operations_with_filter(
            operations, lambda op: not op.is_aggregation
        )

    def _get_only_aggregate_operations(
        self, operations: list[DbOperationOrFunction]
    ) -> list[DbOperationOrFunction]:
        """Return the aggregation operations of ``operations``."""
        return self._get_operations_with_filter(
            operations, lambda op: op.is_aggregation
        )

    def _get_operations_with_filter(
        self,
        operations: list[DbOperationOrFunction],
        op_filter: Callable[[DbOperationOrFunction], bool],
    ) -> list[DbOperationOrFunction]:
        """Return a new list with the operations accepted by ``op_filter``, in original order."""
        return [operation for operation in operations if op_filter(operation)]

    def _get_operation_weights(
        self, operations: list[DbOperationOrFunction]
    ) -> list[float]:
        """Return one selection weight per operation, derived from its relevance."""
        return [
            self.randomized_picker.convert_operation_relevance_to_number(
                operation.relevance
            )
            for operation in operations
        ]

    def find_operations_by_predicate(
        self, match_op: Callable[[DbOperationOrFunction], bool]
    ) -> list[DbOperationOrFunction]:
        """Return all selectable operations matching ``match_op``."""
        return [op for op in self.selectable_operations if match_op(op)]

    def find_exactly_one_operation_by_predicate(
        self, match_op: Callable[[DbOperationOrFunction], bool]
    ) -> DbOperationOrFunction:
        """Return the single selectable operation matching ``match_op``.

        :raises RuntimeError: if no operation or more than one operation matches
        """
        operations = self.find_operations_by_predicate(match_op)
        if len(operations) == 0:
            raise RuntimeError("No operation matches!")
        if len(operations) > 1:
            raise RuntimeError(f"More than one operation matches: {operations}")

        return operations[0]

    def find_data_type_with_values_by_type_identifier(
        self, type_identifier: str
    ) -> DataTypeWithValues:
        """Return the data type with values whose ``internal_identifier`` equals ``type_identifier``.

        :raises RuntimeError: if there is no such data type
        """
        for (
            data_type_with_values
        ) in self.input_data.types_input.all_data_types_with_values:
            if data_type_with_values.data_type.internal_identifier == type_identifier:
                return data_type_with_values

        raise RuntimeError(f"No data type found with identifier {type_identifier}")
class NoSuitableExpressionFound(Exception):
    """Raised when no expression satisfying the current constraints could be generated.

    The reason is kept in ``message``. It is also passed on to ``Exception``, so that
    ``str(ex)``, tracebacks and ``ex.args`` show it and the exception can be pickled.
    """

    def __init__(self, message: str):
        super().__init__(message)
        self.message = message
|