123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122 |
- # Copyright Materialize, Inc. and contributors. All rights reserved.
- #
- # Use of this software is governed by the Business Source License
- # included in the LICENSE file at the root of this repository.
- #
- # As of the Change Date specified in that file, in accordance with
- # the Business Source License, use of this software will be governed
- # by the Apache License, Version 2.0.
- from psycopg import Cursor, ProgrammingError
- from materialize.scalability.operation.operation_data import OperationData
class Operation:
    """Base class for one step that consumes and returns an OperationData.

    Subclasses implement `_execute` and may override `required_keys` and
    `produced_keys` to declare which data keys they need and provide.
    """

    def required_keys(self) -> set[str]:
        """
        Keys in the data dictionary that must be present for this operation.
        The base implementation requires none.
        """
        no_keys: set[str] = set()
        return no_keys

    def produced_keys(self) -> set[str]:
        """
        Keys in the data dictionary that this operation adds or updates.
        The base implementation produces none.
        """
        no_keys: set[str] = set()
        return no_keys

    def execute(self, data: OperationData) -> OperationData:
        """
        Run the operation: validate the required keys on the incoming data,
        delegate to `_execute`, then validate the produced keys on the data
        that `_execute` returned, and return that data.
        """
        operation_class = self.__class__
        data.validate_requirements(self.required_keys(), operation_class, "requires")
        result = self._execute(data)
        # Validation runs on the returned object, which need not be the input.
        result.validate_requirements(
            self.produced_keys(), operation_class, "produces"
        )
        return result

    def _execute(self, data: OperationData) -> OperationData:
        """
        Perform the actual work. Must be overridden; the base implementation
        always raises NotImplementedError.
        """
        raise NotImplementedError

    def __str__(self) -> str:
        """Return the name of the concrete class."""
        class_name = self.__class__.__name__
        return class_name
class SqlOperationWithInput(Operation):
    """Operation that executes one SQL statement built from the operation data.

    Subclasses declare the extra data keys they need via
    `required_input_keys` and build the statement in
    `sql_statement_based_on_input`. The statement is run on the cursor
    obtained from the data.
    """

    def required_keys(self) -> set[str]:
        """Return "cursor" plus the keys named by `required_input_keys`."""
        return {"cursor"}.union(self.required_input_keys())

    def required_input_keys(self) -> set[str]:
        """
        Keys (besides the cursor) that the statement is built from.
        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError

    def _execute(self, data: OperationData) -> OperationData:
        """
        Execute the statement on the data's cursor and fetch all rows.

        A ProgrammingError whose message states that the last operation
        didn't produce a result is tolerated, so statements without a result
        set can be run through the same path. Any other ProgrammingError is
        re-raised unchanged.

        Returns the same data object that was passed in.
        """
        cursor: Cursor = data.cursor()
        statement = self.sql_statement_based_on_input(data)
        try:
            cursor.execute(statement.encode("utf8"))
            cursor.fetchall()
        except ProgrammingError as e:
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would silently swallow every database error,
            # and a failing assert would hide the real error behind an
            # AssertionError.
            if "the last operation didn't produce a result" not in str(e):
                raise
        return data

    def sql_statement_based_on_input(self, input: OperationData) -> str:
        """
        Build the SQL statement from the operation data.
        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError
class SqlOperationWithSeed(SqlOperationWithInput):
    """SQL operation whose statement is derived from a single data value."""

    def __init__(self, seed_key: str):
        """Store the data key under which the seed value is looked up."""
        self.seed_key = seed_key

    def required_input_keys(self) -> set[str]:
        """Return a set containing only the configured seed key."""
        return set((self.seed_key,))

    def sql_statement_based_on_input(self, input: OperationData) -> str:
        """Look up the seed in the data, stringify it, and build the statement."""
        seed_value = input.get(self.seed_key)
        seed_text = str(seed_value)
        return self.sql_statement(seed_text)

    def sql_statement(self, seed: str) -> str:
        """
        Build the SQL statement for the given seed.
        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError
class SqlOperationWithTwoSeeds(SqlOperationWithInput):
    """SQL operation whose statement is derived from two data values."""

    def __init__(self, seed_key1: str, seed_key2: str):
        """Store the two data keys under which the seed values are looked up."""
        self.seed_key1 = seed_key1
        self.seed_key2 = seed_key2

    def required_input_keys(self) -> set[str]:
        """Return the set made up of both configured seed keys."""
        return {self.seed_key1} | {self.seed_key2}

    def sql_statement_based_on_input(self, input: OperationData) -> str:
        """Look up both seeds in the data, stringify them, and build the statement."""
        first_seed = str(input.get(self.seed_key1))
        second_seed = str(input.get(self.seed_key2))
        return self.sql_statement(first_seed, second_seed)

    def sql_statement(self, seed1: str, seed2: str) -> str:
        """
        Build the SQL statement for the given pair of seeds.
        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError
class SimpleSqlOperation(SqlOperationWithInput):
    """SQL operation whose statement does not depend on the operation data."""

    def required_input_keys(self) -> set[str]:
        """Return an empty set: no input keys are needed to build the statement."""
        no_keys: set[str] = set()
        return no_keys

    def sql_statement_based_on_input(self, _input: OperationData) -> str:
        """Ignore the data and return the fixed statement from `sql_statement`."""
        statement = self.sql_statement()
        return statement

    def sql_statement(self) -> str:
        """
        Return the SQL statement to execute.
        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError
class OperationChainWithDataExchange(Operation):
    """Composite operation that runs its operations in order.

    The data returned by each operation is passed as input to the next one.
    """

    def __init__(self, operations: list[Operation]):
        """Store the operations to run; at least one is required."""
        assert len(operations) >= 1, "requires at least one operation"
        self.ops = operations

    def required_keys(self) -> set[str]:
        """Return the required keys of the first operation in the chain."""
        first_operation = self.operations()[0]
        return first_operation.required_keys()

    def operations(self) -> list[Operation]:
        """Return the operations of this chain in execution order."""
        return self.ops

    def _execute(self, data: OperationData) -> OperationData:
        """Execute every operation in turn, threading the data through them."""
        current = data
        for step in self.operations():
            current = step.execute(current)
        return current

    def __str__(self) -> str:
        """Return the class name followed by the comma-separated operations."""
        class_name = self.__class__.__name__
        operation_names = ", ".join(map(str, self.operations()))
        return f"{class_name} with {operation_names}"
|