expression_generator.py

# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.

from collections.abc import Callable

from materialize.output_consistency.common import probability
from materialize.output_consistency.common.configuration import (
    ConsistencyTestConfiguration,
)
from materialize.output_consistency.data_type.data_type_category import DataTypeCategory
from materialize.output_consistency.data_type.data_type_with_values import (
    DataTypeWithValues,
)
from materialize.output_consistency.execution.value_storage_layout import (
    ValueStorageLayout,
)
from materialize.output_consistency.expression.expression import (
    Expression,
    LeafExpression,
)
from materialize.output_consistency.expression.expression_with_args import (
    ExpressionWithArgs,
)
from materialize.output_consistency.generators.arg_context import ArgContext
from materialize.output_consistency.input_data.operations.equality_operations_provider import (
    EQUALS_OPERATION,
)
from materialize.output_consistency.input_data.params.one_of_operation_param import (
    OneOf,
)
from materialize.output_consistency.input_data.test_input_data import (
    ConsistencyTestInputData,
)
from materialize.output_consistency.operation.operation import (
    DbOperationOrFunction,
)
from materialize.output_consistency.operation.operation_param import OperationParam
from materialize.output_consistency.operation.volatile_data_operation_param import (
    VolatileDataOperationParam,
)
from materialize.output_consistency.selection.randomized_picker import RandomizedPicker

NESTING_LEVEL_ROOT = 0
NESTING_LEVEL_OUTERMOST_ARG = 1

FIRST_ARG_INDEX = 0


class ExpressionGenerator:
    """Generates expressions based on a random selection of operations"""

    def __init__(
        self,
        config: ConsistencyTestConfiguration,
        randomized_picker: RandomizedPicker,
        input_data: ConsistencyTestInputData,
    ):
        self.config = config
        self.randomized_picker = randomized_picker
        self.input_data = input_data

        self.selectable_operations: list[DbOperationOrFunction] = []
        self.operation_weights: list[float] = []
        self.operation_weights_no_aggregates: list[float] = []
        self.operations_by_return_type_category: dict[
            DataTypeCategory, list[DbOperationOrFunction]
        ] = dict()
        self.types_with_values_by_category: dict[
            DataTypeCategory, list[DataTypeWithValues]
        ] = dict()

        self._initialize_operations()
        self._initialize_types()

    def _initialize_operations(self) -> None:
        self.operation_weights = self._get_operation_weights(
            self.input_data.operations_input.all_operation_types
        )

        for index, operation in enumerate(
            self.input_data.operations_input.all_operation_types
        ):
            self.selectable_operations.append(operation)
            self.operation_weights_no_aggregates.append(
                0 if operation.is_aggregation else self.operation_weights[index]
            )

            category = operation.return_type_spec.type_category
            operations_with_return_category = (
                self.operations_by_return_type_category.get(category, [])
            )
            operations_with_return_category.append(operation)
            self.operations_by_return_type_category[category] = (
                operations_with_return_category
            )

    def _initialize_types(self) -> None:
        for (
            data_type_with_values
        ) in self.input_data.types_input.all_data_types_with_values:
            category = data_type_with_values.data_type.category
            types_with_values = self.types_with_values_by_category.get(category, [])
            types_with_values.append(data_type_with_values)
            self.types_with_values_by_category[category] = types_with_values

    def pick_random_operation(
        self,
        include_aggregates: bool,
        accept_op_filter: Callable[[DbOperationOrFunction], bool] | None = None,
    ) -> DbOperationOrFunction:
        all_weights = (
            self.operation_weights
            if include_aggregates
            else self.operation_weights_no_aggregates
        )

        if accept_op_filter is not None:
            selected_operations = []
            weights = []

            for index, operation in enumerate(self.selectable_operations):
                if accept_op_filter(operation):
                    selected_operations.append(operation)
                    weights.append(all_weights[index])
        else:
            selected_operations = self.selectable_operations
            weights = all_weights

        assert (
            len(selected_operations) > 0
        ), f"no operations available (include_aggregates={include_aggregates}, accept_op_filter used={accept_op_filter is not None})"
        assert len(selected_operations) == len(weights)

        return self.randomized_picker.random_operation(selected_operations, weights)

    def generate_boolean_expression(
        self,
        use_aggregation: bool,
        storage_layout: ValueStorageLayout | None,
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> ExpressionWithArgs | None:
        return self.generate_expression_for_data_type_category(
            use_aggregation, storage_layout, DataTypeCategory.BOOLEAN, nesting_level
        )

    def generate_expression_for_data_type_category(
        self,
        use_aggregation: bool,
        storage_layout: ValueStorageLayout | None,
        data_type_category: DataTypeCategory,
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> ExpressionWithArgs | None:
        def operation_filter(operation: DbOperationOrFunction) -> bool:
            if operation.is_aggregation != use_aggregation:
                return False
            # Simplification: this will only include operations defined to return a value of the requested
            # category, not generic operations that might return such a value depending on the input.
            return operation.return_type_spec.type_category == data_type_category

        return self.generate_expression_with_filter(
            use_aggregation, storage_layout, operation_filter, nesting_level
        )

    def generate_expression_with_filter(
        self,
        use_aggregation: bool,
        storage_layout: ValueStorageLayout | None,
        operation_filter: Callable[[DbOperationOrFunction], bool],
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> ExpressionWithArgs | None:
        operation = self.pick_random_operation(use_aggregation, operation_filter)
        expression, _ = self.generate_expression_for_operation(
            operation, storage_layout, nesting_level
        )
        return expression

    def generate_expression_for_operation(
        self,
        operation: DbOperationOrFunction,
        storage_layout: ValueStorageLayout | None = None,
        nesting_level: int = NESTING_LEVEL_ROOT,
    ) -> tuple[ExpressionWithArgs | None, int]:
        """
        :return: the expression or None if it was not possible to create one, and the number of used operation params
        """
        if storage_layout is None:
            storage_layout = self._select_storage_layout(operation)

        number_of_args = self.randomized_picker.random_number(
            operation.min_param_count, operation.max_param_count
        )

        try:
            args = self._generate_args_for_operation(
                operation, number_of_args, storage_layout, nesting_level + 1
            )
        except NoSuitableExpressionFound as ex:
            if self.config.verbose_output:
                print(f"No suitable expression found: {ex.message}")
            return None, number_of_args

        is_aggregate = operation.is_aggregation or self._contains_aggregate_arg(args)
        is_expect_error = operation.is_expected_to_cause_db_error(args)
        expression = ExpressionWithArgs(operation, args, is_aggregate, is_expect_error)

        return expression, number_of_args

    def generate_equals_expression(
        self, arg1: Expression, arg2: Expression
    ) -> ExpressionWithArgs:
        operation = EQUALS_OPERATION
        args = [arg1, arg2]
        is_aggregate = self._contains_aggregate_arg(args)
        is_expect_error = operation.is_expected_to_cause_db_error(args)
        return ExpressionWithArgs(operation, args, is_aggregate, is_expect_error)

    def generate_leaf_expression(
        self,
        storage_layout: ValueStorageLayout,
        types_with_values: list[DataTypeWithValues],
    ) -> LeafExpression:
        assert len(types_with_values) > 0, "No suitable types with values"
        type_with_values = self.randomized_picker.random_type_with_values(
            types_with_values
        )

        if storage_layout == ValueStorageLayout.VERTICAL:
            return type_with_values.create_unassigned_vertical_storage_column()
        elif storage_layout == ValueStorageLayout.HORIZONTAL:
            if len(type_with_values.raw_values) == 0:
                raise NoSuitableExpressionFound("No value in type")

            return self.randomized_picker.random_value(type_with_values.raw_values)
        else:
            raise RuntimeError(f"Unsupported storage layout: {storage_layout}")

    def _select_storage_layout(
        self, operation: DbOperationOrFunction
    ) -> ValueStorageLayout:
        if not operation.is_aggregation:
            # Prefer the horizontal row format for non-aggregate expressions. It makes it less likely that a
            # query results in an unexpected error. Furthermore, in case of an error, error messages of
            # non-aggregate expressions can only be compared in the HORIZONTAL layout (because the row
            # processing order of an evaluation strategy is not defined).
            if self.randomized_picker.random_boolean(
                probability.HORIZONTAL_LAYOUT_WHEN_NOT_AGGREGATED
            ):
                return ValueStorageLayout.HORIZONTAL
            else:
                return ValueStorageLayout.VERTICAL

        # strongly prefer vertical storage for aggregations but allow some variance
        if self.randomized_picker.random_boolean(
            probability.HORIZONTAL_LAYOUT_WHEN_AGGREGATED
        ):
            # Use horizontal layout in some cases
            return ValueStorageLayout.HORIZONTAL

        return ValueStorageLayout.VERTICAL

    def _contains_aggregate_arg(self, args: list[Expression]) -> bool:
        for arg in args:
            if arg.is_aggregate:
                return True

        return False

    def _generate_args_for_operation(
        self,
        operation: DbOperationOrFunction,
        number_of_args: int,
        storage_layout: ValueStorageLayout,
        nesting_level: int,
        try_number: int = 1,
    ) -> list[Expression]:
        if number_of_args == 0:
            return []

        arg_context = ArgContext()

        for arg_index in range(FIRST_ARG_INDEX, number_of_args):
            param = operation.params[arg_index]

            # nesting_level was already incremented before invoking this function
            arg = self._generate_arg_for_param(
                operation,
                param,
                storage_layout,
                arg_context,
                nesting_level,
            )

            arg_context.append(arg)

        if (
            self.config.avoid_expressions_expecting_db_error
            and try_number <= 50
            and operation.is_expected_to_cause_db_error(arg_context.args)
        ):
            # retry
            return self._generate_args_for_operation(
                operation,
                number_of_args,
                storage_layout,
                nesting_level=nesting_level,
                try_number=try_number + 1,
            )

        return arg_context.args

    def _generate_arg_for_param(
        self,
        operation: DbOperationOrFunction,
        param: OperationParam,
        storage_layout: ValueStorageLayout,
        arg_context: ArgContext,
        nesting_level: int,
    ) -> Expression:
        # the OneOf check must come first so that the checks below operate on the picked, concrete param
        if isinstance(param, OneOf):
            param = param.pick(self.randomized_picker)

        if isinstance(param, VolatileDataOperationParam):
            return param.generate_expression(arg_context, self.randomized_picker)

        create_complex_arg = (
            arg_context.requires_aggregation()
            or self.randomized_picker.random_boolean(
                probability.CREATE_COMPLEX_EXPRESSION
            )
        )

        if create_complex_arg:
            return self._generate_complex_arg_for_param(
                param,
                storage_layout,
                arg_context,
                operation.is_aggregation,
                nesting_level,
            )
        else:
            return self._generate_simple_arg_for_param(
                param, arg_context, storage_layout
            )

    def _generate_simple_arg_for_param(
        self,
        param: OperationParam,
        arg_context: ArgContext,
        storage_layout: ValueStorageLayout,
    ) -> LeafExpression:
        # only consider the data type category, do not check incompatibilities and other validations at this point
        suitable_types_with_values = self._get_data_type_values_of_category(
            param, arg_context
        )

        if len(suitable_types_with_values) == 0:
            raise NoSuitableExpressionFound("No suitable type")

        return self.generate_leaf_expression(storage_layout, suitable_types_with_values)

    def _generate_complex_arg_for_param(
        self,
        param: OperationParam,
        storage_layout: ValueStorageLayout,
        arg_context: ArgContext,
        is_aggregation_operation: bool,
        nesting_level: int,
        try_number: int = 1,
    ) -> ExpressionWithArgs:
        must_use_aggregation = arg_context.requires_aggregation()
        # Currently, an aggregation function is allowed as an argument only if all of the following apply:
        # * the operation is not an aggregation itself (nested aggregations are not possible)
        # * it is the first param (all subsequent params will require aggregation)
        # * we are not already nested (to avoid nested aggregations spread across several levels)
        allow_aggregation = must_use_aggregation or (
            not is_aggregation_operation
            and arg_context.has_no_args()
            and nesting_level == NESTING_LEVEL_OUTERMOST_ARG
        )

        suitable_operations = self._get_operations_of_category(
            param, arg_context, must_use_aggregation, allow_aggregation
        )

        if len(suitable_operations) == 0:
            raise NoSuitableExpressionFound(
                f"No suitable operation for {param}"
                f" (layout={storage_layout},"
                f" allow_aggregation={allow_aggregation},"
                f" must_use_aggregation={must_use_aggregation})"
            )

        weights = self._get_operation_weights(suitable_operations)
        operation = self.randomized_picker.random_operation(
            suitable_operations, weights
        )

        nested_expression, _ = self.generate_expression_for_operation(
            operation, storage_layout, nesting_level
        )

        if nested_expression is None:
            raise NoSuitableExpressionFound(
                f"No nested expression for {param} in {storage_layout}"
            )

        data_type = nested_expression.try_resolve_exact_data_type()
        is_unsupported = data_type is not None and not param.supports_type(
            data_type, arg_context.args
        )
        is_unsupported = (
            is_unsupported
            or not param.might_support_type_as_input_assuming_category_matches(
                nested_expression.operation.return_type_spec
            )
        )

        if is_unsupported:
            if try_number < 5:
                return self._generate_complex_arg_for_param(
                    param,
                    storage_layout,
                    arg_context,
                    is_aggregation_operation,
                    nesting_level,
                    try_number=try_number + 1,
                )
            else:
                raise NoSuitableExpressionFound("No supported data type")

        return nested_expression

    def _get_data_type_values_of_category(
        self, param: OperationParam, arg_context: ArgContext
    ) -> list[DataTypeWithValues]:
        category = param.resolve_type_category(arg_context.args)

        if category == DataTypeCategory.ANY:
            return self.input_data.types_input.all_data_types_with_values

        self._assert_valid_type_category_for_param(param, category)

        preselected_types_with_values = self.types_with_values_by_category.get(
            category, []
        )

        suitable_types_with_values = []

        for type_with_values in preselected_types_with_values:
            if param.supports_type(type_with_values.data_type, arg_context.args):
                suitable_types_with_values.append(type_with_values)

        return suitable_types_with_values

    def _assert_valid_type_category_for_param(
        self, param: OperationParam, category: DataTypeCategory
    ) -> None:
        assert category not in {
            DataTypeCategory.DYNAMIC,
        }, f"Type category {category} not allowed for parameters (param={param})"

    def _get_operations_of_category(
        self,
        param: OperationParam,
        arg_context: ArgContext,
        must_use_aggregation: bool,
        allow_aggregation: bool,
    ) -> list[DbOperationOrFunction]:
        category = param.resolve_type_category(arg_context.args)
        suitable_operations = self._get_all_operations_of_category(param, category)

        if must_use_aggregation:
            return self._get_only_aggregate_operations(suitable_operations)
        elif not allow_aggregation:
            return self._get_without_aggregate_operations(suitable_operations)
        else:
            return suitable_operations

    def _get_all_operations_of_category(
        self, param: OperationParam, category: DataTypeCategory
    ) -> list[DbOperationOrFunction]:
        if category == DataTypeCategory.ANY:
            return self.input_data.operations_input.all_operation_types

        self._assert_valid_type_category_for_param(param, category)

        return self.operations_by_return_type_category.get(category, [])

    def _get_without_aggregate_operations(
        self, operations: list[DbOperationOrFunction]
    ) -> list[DbOperationOrFunction]:
        return self._get_operations_with_filter(
            operations, lambda op: not op.is_aggregation
        )

    def _get_only_aggregate_operations(
        self, operations: list[DbOperationOrFunction]
    ) -> list[DbOperationOrFunction]:
        return self._get_operations_with_filter(
            operations, lambda op: op.is_aggregation
        )

    def _get_operations_with_filter(
        self,
        operations: list[DbOperationOrFunction],
        op_filter: Callable[[DbOperationOrFunction], bool],
    ) -> list[DbOperationOrFunction]:
        matching_operations = []

        for operation in operations:
            if op_filter(operation):
                matching_operations.append(operation)

        return matching_operations

    def _get_operation_weights(
        self, operations: list[DbOperationOrFunction]
    ) -> list[float]:
        weights = []

        for operation in operations:
            weight = self.randomized_picker.convert_operation_relevance_to_number(
                operation.relevance
            )
            weights.append(weight)

        return weights

    def find_operations_by_predicate(
        self, match_op: Callable[[DbOperationOrFunction], bool]
    ) -> list[DbOperationOrFunction]:
        matched_ops = list()

        for op in self.selectable_operations:
            if match_op(op):
                matched_ops.append(op)

        return matched_ops

    def find_exactly_one_operation_by_predicate(
        self, match_op: Callable[[DbOperationOrFunction], bool]
    ) -> DbOperationOrFunction:
        operations = self.find_operations_by_predicate(match_op)

        if len(operations) == 0:
            raise RuntimeError("No operation matches!")
        if len(operations) > 1:
            raise RuntimeError(f"More than one operation matches: {operations}")

        return operations[0]

    def find_data_type_with_values_by_type_identifier(
        self, type_identifier: str
    ) -> DataTypeWithValues:
        for (
            data_type_with_values
        ) in self.input_data.types_input.all_data_types_with_values:
            if data_type_with_values.data_type.internal_identifier == type_identifier:
                return data_type_with_values

        raise RuntimeError(f"No data type found with identifier {type_identifier}")


class NoSuitableExpressionFound(Exception):
    def __init__(self, message: str):
        super().__init__()
        self.message = message