enum_operation_param.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from materialize.output_consistency.data_type.data_type import DataType
  10. from materialize.output_consistency.data_type.data_type_category import DataTypeCategory
  11. from materialize.output_consistency.enum.enum_constant import EnumConstant
  12. from materialize.output_consistency.enum.enum_data_type import EnumDataType
  13. from materialize.output_consistency.expression.expression import Expression
  14. from materialize.output_consistency.expression.expression_characteristics import (
  15. ExpressionCharacteristics,
  16. )
  17. from materialize.output_consistency.generators.arg_context import ArgContext
  18. from materialize.output_consistency.operation.volatile_data_operation_param import (
  19. VolatileDataOperationParam,
  20. )
  21. from materialize.output_consistency.selection.randomized_picker import RandomizedPicker
  22. _INDEX_OF_NULL_VALUE = 0
  23. class EnumConstantOperationParam(VolatileDataOperationParam):
  24. def __init__(
  25. self,
  26. values: list[str],
  27. add_quotes: bool,
  28. add_invalid_value: bool = True,
  29. add_null_value: bool = True,
  30. optional: bool = False,
  31. invalid_value: str = "invalid_value_123",
  32. tags: set[str] | None = None,
  33. ):
  34. super().__init__(
  35. DataTypeCategory.ENUM,
  36. optional=optional,
  37. )
  38. assert len(values) == len(set(values)), f"Values contain duplicates {values}"
  39. self.values = values
  40. if add_invalid_value:
  41. self.invalid_value = invalid_value
  42. self.values.append(invalid_value)
  43. else:
  44. self.invalid_value = None
  45. self.add_null_value = add_null_value
  46. null_value = "NULL"
  47. if add_null_value:
  48. # NULL value must be at the beginning
  49. self.values.insert(
  50. _INDEX_OF_NULL_VALUE,
  51. null_value,
  52. )
  53. self.add_quotes = add_quotes
  54. self.characteristics_per_value: dict[str, set[ExpressionCharacteristics]] = (
  55. dict()
  56. )
  57. for value in self.values:
  58. self.characteristics_per_value[value] = set()
  59. if add_invalid_value:
  60. self.characteristics_per_value[invalid_value].add(
  61. ExpressionCharacteristics.ENUM_INVALID
  62. )
  63. if add_null_value:
  64. self.characteristics_per_value[null_value].add(
  65. ExpressionCharacteristics.NULL
  66. )
  67. self.tags = tags
  68. def supports_type(
  69. self, data_type: DataType, previous_args: list[Expression]
  70. ) -> bool:
  71. return isinstance(data_type, EnumDataType)
  72. def get_enum_constant(self, index: int) -> EnumConstant:
  73. assert (
  74. 0 <= index < len(self.values)
  75. ), f"Index {index} out of range in list with {len(self.values)} values: {self.values}"
  76. value = self.values[index]
  77. characteristics = self.characteristics_per_value[value]
  78. quote_value = self.add_quotes
  79. if self.add_null_value and index == 0:
  80. quote_value = False
  81. return EnumConstant(value, quote_value, characteristics, self.tags)
  82. def get_valid_values(self) -> list[str]:
  83. return [
  84. value
  85. for index, value in enumerate(self.values)
  86. if value != self.invalid_value
  87. and (index != _INDEX_OF_NULL_VALUE or not self.add_null_value)
  88. ]
  89. def generate_expression(
  90. self, arg_context: ArgContext, randomized_picker: RandomizedPicker
  91. ) -> Expression:
  92. enum_constant_index = randomized_picker.random_number(0, len(self.values) - 1)
  93. return self.get_enum_constant(enum_constant_index)