randomized_picker.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. import random
  10. from materialize.output_consistency.common import probability
  11. from materialize.output_consistency.common.configuration import (
  12. ConsistencyTestConfiguration,
  13. )
  14. from materialize.output_consistency.data_type.data_type_with_values import (
  15. DataTypeWithValues,
  16. )
  17. from materialize.output_consistency.data_value.data_value import DataValue
  18. from materialize.output_consistency.operation.operation import (
  19. DbOperationOrFunction,
  20. OperationRelevance,
  21. )
  22. from materialize.output_consistency.query.data_source import (
  23. DataSource,
  24. )
  25. from materialize.output_consistency.query.join import (
  26. JOIN_TARGET_WEIGHTS,
  27. JoinOperator,
  28. JoinTarget,
  29. )
  30. class RandomizedPicker:
  31. def __init__(self, config: ConsistencyTestConfiguration):
  32. self.config = config
  33. random.seed(self.config.random_seed)
  34. def random_boolean(self, probability_for_true: float = 0.5) -> bool:
  35. assert (
  36. 0 <= probability_for_true <= 1
  37. ), f"Invalid probability: {probability_for_true}"
  38. weights = [probability_for_true, 1 - probability_for_true]
  39. return random.choices([True, False], k=1, weights=weights)[0]
  40. def random_number(self, min_value_incl: int, max_value_incl: int) -> int:
  41. return random.randint(min_value_incl, max_value_incl)
  42. def random_operation(
  43. self, operations: list[DbOperationOrFunction], weights: list[float]
  44. ) -> DbOperationOrFunction:
  45. return random.choices(operations, k=1, weights=weights)[0]
  46. def random_type_with_values(
  47. self, types_with_values: list[DataTypeWithValues]
  48. ) -> DataTypeWithValues:
  49. return random.choice(types_with_values)
  50. def random_row_indices(
  51. self, vertical_storage_row_count: int, max_number_of_rows_to_select: int
  52. ) -> set[int]:
  53. selected_rows = random.choices(
  54. range(0, vertical_storage_row_count), k=max_number_of_rows_to_select
  55. )
  56. return set(selected_rows)
  57. def random_value(self, values: list[DataValue]) -> DataValue:
  58. return random.choice(values)
  59. def random_data_source(self, sources: list[DataSource]) -> DataSource:
  60. assert len(sources) > 0, "No data sources available"
  61. if self.random_boolean(
  62. probability.COLUMN_SELECTION_ADDITIONAL_CHANCE_FOR_FIRST_TABLE
  63. ):
  64. # give the first data source a higher chance so that not all queries need a join
  65. return sources[0]
  66. return random.choice(sources)
  67. def convert_operation_relevance_to_number(
  68. self, relevance: OperationRelevance
  69. ) -> float:
  70. if relevance == OperationRelevance.EXTREME_HIGH:
  71. return 100
  72. if relevance == OperationRelevance.HIGH:
  73. return 0.8
  74. elif relevance == OperationRelevance.DEFAULT:
  75. return 0.5
  76. elif relevance == OperationRelevance.LOW:
  77. return 0.2
  78. else:
  79. raise RuntimeError(f"Unexpected value: {relevance}")
  80. def _random_bool(self, probability: float) -> bool:
  81. return random.random() < probability
  82. def random_join_operator(self) -> JoinOperator:
  83. return random.choice(list(JoinOperator))
  84. def random_join_target(self) -> JoinTarget:
  85. return random.choices(list(JoinTarget), k=1, weights=JOIN_TARGET_WEIGHTS)[0]