row_indices_expression.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from __future__ import annotations
  10. from materialize.output_consistency.data_type.data_type_category import DataTypeCategory
  11. from materialize.output_consistency.execution.sql_dialect_adjuster import (
  12. SqlDialectAdjuster,
  13. )
  14. from materialize.output_consistency.execution.value_storage_layout import (
  15. ROW_INDEX_COL_NAME,
  16. ValueStorageLayout,
  17. )
  18. from materialize.output_consistency.expression.expression import (
  19. Expression,
  20. LeafExpression,
  21. )
  22. from materialize.output_consistency.input_data.types.array_type_provider import (
  23. ArrayDataType,
  24. )
  25. from materialize.output_consistency.operation.return_type_spec import ReturnTypeSpec
  26. from materialize.output_consistency.query.data_source import DataSource
  27. _INT_ARRAY_TYPE = ArrayDataType(
  28. "INT_ARRAY",
  29. type_name="INT[]",
  30. array_entry_value_1="1",
  31. array_entry_value_2="2",
  32. value_type_category=DataTypeCategory.NUMERIC,
  33. )
  34. class RowIndicesExpression(LeafExpression):
  35. def __init__(self, expression_to_share_data_source: Expression):
  36. # data source will be derived dynamically
  37. super().__init__(
  38. column_name="<row_indices>",
  39. data_type=_INT_ARRAY_TYPE,
  40. characteristics=set(),
  41. storage_layout=ValueStorageLayout.ANY,
  42. data_source=None,
  43. )
  44. self.expression_to_share_data_source = expression_to_share_data_source
  45. def resolve_return_type_spec(self) -> ReturnTypeSpec:
  46. return self.data_type.resolve_return_type_spec(self.own_characteristics)
  47. def resolve_return_type_category(self) -> DataTypeCategory:
  48. return self.data_type.category
  49. def get_data_source(self) -> DataSource | None:
  50. data_sources = self.expression_to_share_data_source.collect_data_sources()
  51. if len(data_sources) == 0:
  52. # this happens when the expression is a constant
  53. return None
  54. # we can only return one data source here but that does not really matter because we only reuse already used
  55. # data sources
  56. return data_sources[0]
  57. def to_sql(
  58. self, sql_adjuster: SqlDialectAdjuster, include_alias: bool, is_root_level: bool
  59. ) -> str:
  60. data_sources = self.expression_to_share_data_source.collect_data_sources()
  61. if len(data_sources) == 0:
  62. # We won't use row_index in this case but a constant instead to avoid a potentially ambiguous column
  63. # reference
  64. return "0"
  65. expressions = []
  66. for data_source in data_sources:
  67. expressions.append(
  68. super().to_sql_as_column(
  69. sql_adjuster, include_alias, ROW_INDEX_COL_NAME, data_source
  70. )
  71. )
  72. array_elements = ",".join(expressions)
  73. return f"ARRAY[{array_elements}]::INT[]"
  74. def collect_vertical_table_indices(self) -> set[int]:
  75. # not relevant because this is already handled by the column sharing the data source
  76. return set()
  77. def __str__(self) -> str:
  78. return f"RowIndicesExpression (expression_to_share_data_source={self.expression_to_share_data_source})"