scalability_operation.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.
  9. from psycopg import Cursor, ProgrammingError
  10. from materialize.scalability.operation.operation_data import OperationData
  11. class Operation:
  12. def required_keys(self) -> set[str]:
  13. """
  14. Keys in the data dictionary that are required.
  15. """
  16. return set()
  17. def produced_keys(self) -> set[str]:
  18. """
  19. Keys in the data dictionary that will be added or updated.
  20. """
  21. return set()
  22. def execute(self, data: OperationData) -> OperationData:
  23. data.validate_requirements(self.required_keys(), self.__class__, "requires")
  24. data = self._execute(data)
  25. data.validate_requirements(self.produced_keys(), self.__class__, "produces")
  26. return data
  27. def _execute(self, data: OperationData) -> OperationData:
  28. raise NotImplementedError
  29. def __str__(self) -> str:
  30. return self.__class__.__name__
  31. class SqlOperationWithInput(Operation):
  32. def required_keys(self) -> set[str]:
  33. return {"cursor"}.union(self.required_input_keys())
  34. def required_input_keys(self) -> set[str]:
  35. raise NotImplementedError
  36. def _execute(self, data: OperationData) -> OperationData:
  37. try:
  38. cursor: Cursor = data.cursor()
  39. cursor.execute(self.sql_statement_based_on_input(data).encode("utf8"))
  40. cursor.fetchall()
  41. except ProgrammingError as e:
  42. assert "the last operation didn't produce a result" in str(e)
  43. return data
  44. def sql_statement_based_on_input(self, input: OperationData) -> str:
  45. raise NotImplementedError
  46. class SqlOperationWithSeed(SqlOperationWithInput):
  47. def __init__(self, seed_key: str):
  48. self.seed_key = seed_key
  49. def required_input_keys(self) -> set[str]:
  50. return {self.seed_key}
  51. def sql_statement_based_on_input(self, input: OperationData) -> str:
  52. return self.sql_statement(str(input.get(self.seed_key)))
  53. def sql_statement(self, seed: str) -> str:
  54. raise NotImplementedError
  55. class SqlOperationWithTwoSeeds(SqlOperationWithInput):
  56. def __init__(self, seed_key1: str, seed_key2: str):
  57. self.seed_key1 = seed_key1
  58. self.seed_key2 = seed_key2
  59. def required_input_keys(self) -> set[str]:
  60. return {self.seed_key1, self.seed_key2}
  61. def sql_statement_based_on_input(self, input: OperationData) -> str:
  62. return self.sql_statement(
  63. str(input.get(self.seed_key1)), str(input.get(self.seed_key2))
  64. )
  65. def sql_statement(self, seed1: str, seed2: str) -> str:
  66. raise NotImplementedError
  67. class SimpleSqlOperation(SqlOperationWithInput):
  68. def required_input_keys(self) -> set[str]:
  69. return set()
  70. def sql_statement_based_on_input(self, _input: OperationData) -> str:
  71. return self.sql_statement()
  72. def sql_statement(self) -> str:
  73. raise NotImplementedError
  74. class OperationChainWithDataExchange(Operation):
  75. def __init__(self, operations: list[Operation]):
  76. assert len(operations) > 0, "requires at least one operation"
  77. self.ops = operations
  78. def required_keys(self) -> set[str]:
  79. return self.operations()[0].required_keys()
  80. def operations(self) -> list[Operation]:
  81. return self.ops
  82. def _execute(self, data: OperationData) -> OperationData:
  83. for operation in self.operations():
  84. data = operation.execute(data)
  85. return data
  86. def __str__(self) -> str:
  87. return f"{self.__class__.__name__} with {', '.join(str(op) for op in self.operations())}"