column_selection.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from typing import Generic, TypeVar
  10. from materialize.output_consistency.data_value.source_column_identifier import (
  11. SourceColumnIdentifier,
  12. )
  13. from materialize.output_consistency.query.data_source import DataSource
  14. T = TypeVar("T")
  15. class SelectionByKey(Generic[T]):
  16. def __init__(self, keys: set[T] | None = None):
  17. self.keys = keys
  18. def includes_all(self) -> bool:
  19. return self.keys is None
  20. def is_included(self, key: T) -> bool:
  21. if self.keys is None:
  22. return True
  23. return key in self.keys
  24. def __str__(self) -> str:
  25. filter_string = ""
  26. if self.keys is not None:
  27. filter_string = ", ".join(str(key) for key in self.keys)
  28. return f"{type(self).__name__}({filter_string})"
  29. class QueryColumnByIndexSelection(SelectionByKey[int]):
  30. def __init__(self, column_indices: set[int] | None = None):
  31. """
  32. :param column_indices: name of selected columns; all columns if not specified
  33. """
  34. super().__init__(column_indices)
  35. class TableColumnByNameSelection(SelectionByKey[SourceColumnIdentifier]):
  36. def __init__(self, column_identifiers: set[SourceColumnIdentifier] | None = None):
  37. """
  38. :param column_identifiers: identifiers of selected columns; all columns if not specified
  39. """
  40. super().__init__(column_identifiers)
  41. def requires_data_source(self, data_source: DataSource) -> bool:
  42. if self.includes_all():
  43. return True
  44. assert self.keys is not None
  45. for column_identifier in self.keys:
  46. if data_source.alias() == column_identifier.data_source_alias:
  47. return True
  48. return False
  49. ALL_QUERY_COLUMNS_BY_INDEX_SELECTION = QueryColumnByIndexSelection()
  50. ALL_TABLE_COLUMNS_BY_NAME_SELECTION = TableColumnByNameSelection()