data_column.py

# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.

from materialize.output_consistency.data_type.data_type import DataType
from materialize.output_consistency.data_type.data_type_category import (
    DataTypeCategory,
)
from materialize.output_consistency.data_value.data_value import DataValue
from materialize.output_consistency.data_value.source_column_identifier import (
    SourceColumnIdentifier,
)
from materialize.output_consistency.execution.value_storage_layout import (
    ValueStorageLayout,
)
from materialize.output_consistency.expression.expression import LeafExpression
from materialize.output_consistency.expression.expression_characteristics import (
    ExpressionCharacteristics,
)
from materialize.output_consistency.operation.return_type_spec import ReturnTypeSpec
from materialize.output_consistency.query.data_source import DataSource
from materialize.output_consistency.selection.row_selection import DataRowSelection


class DataColumn(LeafExpression):
    """A column with a value per row (in contrast to an `ExpressionWithArgs`) for VERTICAL storage.

    See the usage sketch at the end of this module.
    """

    def __init__(self, data_type: DataType, row_values_of_column: list[DataValue]):
        column_name = f"{data_type.internal_identifier.lower()}_val"
        # the data_source will be assigned later
        super().__init__(
            column_name,
            data_type,
            set(),
            ValueStorageLayout.VERTICAL,
            data_source=None,
            is_aggregate=False,
            is_expect_error=False,
        )
        self.values = row_values_of_column

    def assign_data_source(self, data_source: DataSource, force: bool) -> None:
        # `force` allows overwriting an already assigned data source of a non-shared column
        if self.data_source is not None:
            if self.is_shared:
                # the source has already been set
                return
            if not force:
                raise RuntimeError("Data source already assigned")

        self.data_source = data_source

    def resolve_return_type_spec(self) -> ReturnTypeSpec:
        # deliberately provide no characteristics; the spec of this class is not value-specific
        return self.data_type.resolve_return_type_spec(set())

    def resolve_return_type_category(self) -> DataTypeCategory:
        return self.data_type.category

    def recursively_collect_involved_characteristics(
        self, row_selection: DataRowSelection
    ) -> set[ExpressionCharacteristics]:
        involved_characteristics: set[ExpressionCharacteristics] = set()

        selected_values = self.get_values_at_rows(
            row_selection,
            table_index=(
                self.data_source.table_index if self.data_source is not None else None
            ),
        )
        for value in selected_values:
            characteristics_of_value = (
                value.recursively_collect_involved_characteristics(row_selection)
            )
            involved_characteristics = involved_characteristics.union(
                characteristics_of_value
            )

        return involved_characteristics

    def collect_vertical_table_indices(self) -> set[int]:
        return set()

    def get_filtered_values(self, row_selection: DataRowSelection) -> list[DataValue]:
        assert self.data_source is not None

        if row_selection.includes_all_of_source(self.data_source):
            return self.values

        selected_rows = []
        for row_index, row_value in enumerate(self.values):
            if row_selection.is_included_in_source(self.data_source, row_index):
                selected_rows.append(row_value)
        return selected_rows

    def get_values_at_rows(
        self, row_selection: DataRowSelection, table_index: int | None
    ) -> list[DataValue]:
        if row_selection.includes_all_of_all_sources():
            return self.values

        if self.data_source is None:
            # still unknown, provide all values
            return self.values

        if row_selection.includes_all_of_source(self.data_source):
            return self.values

        values = []
        for row_index in row_selection.get_row_indices(self.data_source):
            values.append(self.get_value_at_row(row_index, table_index))

        return values

    def get_value_at_row(
        self,
        row_index: int,
        table_index: int | None,
    ) -> DataValue:
        """All types need to have the same number of rows, but not all have the same number of distinct values.
        After having iterated through all values of the given type, begin repeating values but skip the NULL
        value, which is known to be the first value of all types.

        :param row_index: an arbitrary, non-negative number; it may be out of the value range
        """
        values_of_table = self._get_values_of_table(table_index)
        assert len(values_of_table) > 0, f"No values for table index {table_index}"

        # if there is a NULL value, it will always be at position 0; we can only exclude it if we have other values
        has_null_value_to_exclude = (
            values_of_table[0].is_null_value and len(values_of_table) > 1
        )

        value_index = row_index
        if value_index >= len(values_of_table):
            null_value_offset = 1 if has_null_value_to_exclude else 0
            available_value_count = len(values_of_table) - (
                1 if has_null_value_to_exclude else 0
            )
            value_index = null_value_offset + (
                (value_index - null_value_offset) % available_value_count
            )

        return values_of_table[value_index]

    def _get_values_of_table(self, table_index: int | None) -> list[DataValue]:
        # restrict the values to those belonging to the given vertical table; None means no restriction
        return [
            value
            for value in self.values
            if table_index is None or table_index in value.vertical_table_indices
        ]

    def get_data_source(self) -> DataSource | None:
        assert self.data_source is not None, "Data source not assigned"
        return self.data_source

    def get_source_column_identifier(self) -> SourceColumnIdentifier:
        source_column_identifier = super().get_source_column_identifier()
        assert source_column_identifier is not None
        return source_column_identifier

    def __str__(self) -> str:
        return f"DataColumn (column='{self.column_name}', type={self.data_type})"
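

# Usage sketch (illustrative only; `data_type`, `values_of_type`, `data_source`, and
# `row_selection` stand for instances created elsewhere in the framework):
#
#   column = DataColumn(data_type, values_of_type)
#   column.assign_data_source(data_source, force=False)
#   filtered = column.get_filtered_values(row_selection)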