evaluation_strategy.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. from enum import Enum
  10. from materialize.output_consistency.data_type.data_type_with_values import (
  11. DataTypeWithValues,
  12. )
  13. from materialize.output_consistency.execution.sql_dialect_adjuster import (
  14. MzSqlDialectAdjuster,
  15. SqlDialectAdjuster,
  16. )
  17. from materialize.output_consistency.execution.value_storage_layout import (
  18. ROW_INDEX_COL_NAME,
  19. ValueStorageLayout,
  20. )
  21. from materialize.output_consistency.input_data.test_input_types import (
  22. ConsistencyTestTypesInput,
  23. )
  24. from materialize.output_consistency.query.data_source import DataSource
  25. from materialize.output_consistency.selection.column_selection import (
  26. ALL_TABLE_COLUMNS_BY_NAME_SELECTION,
  27. TableColumnByNameSelection,
  28. )
  29. from materialize.output_consistency.selection.row_selection import (
  30. ALL_ROWS_SELECTION,
  31. DataRowSelection,
  32. )
  33. EVALUATION_STRATEGY_NAME_DFR = "dataflow_rendering"
  34. EVALUATION_STRATEGY_NAME_CTF = "constant_folding"
  35. INTERNAL_EVALUATION_STRATEGY_NAMES = [
  36. EVALUATION_STRATEGY_NAME_DFR,
  37. EVALUATION_STRATEGY_NAME_CTF,
  38. ]
  39. class EvaluationStrategyKey(Enum):
  40. DUMMY = 1
  41. MZ_DATAFLOW_RENDERING = 2
  42. MZ_CONSTANT_FOLDING = 3
  43. POSTGRES = 4
  44. MZ_DATAFLOW_RENDERING_OTHER_DB = 5
  45. MZ_CONSTANT_FOLDING_OTHER_DB = 6
  46. class EvaluationStrategy:
  47. """Strategy how to execute a `QueryTemplate`"""
  48. def __init__(
  49. self,
  50. identifier: EvaluationStrategyKey,
  51. name: str,
  52. object_name_base: str,
  53. simple_db_object_name: str,
  54. sql_adjuster: SqlDialectAdjuster = MzSqlDialectAdjuster(),
  55. ):
  56. """
  57. :param identifier: identifier of this strategy
  58. :param name: readable name
  59. :param object_name_base: the db object name will be derived from this
  60. :param simple_db_object_name: only used by the reproduction code printer
  61. """
  62. self.identifier = identifier
  63. self.name = name
  64. self.object_name_base = object_name_base
  65. self.simple_db_object_name = simple_db_object_name
  66. self.sql_adjuster = sql_adjuster
  67. self.additional_setup_info: str | None = None
  68. def generate_sources(
  69. self,
  70. types_input: ConsistencyTestTypesInput,
  71. vertical_join_tables: int,
  72. ) -> list[str]:
  73. statements = []
  74. statements.extend(
  75. self.generate_source_for_storage_layout(
  76. types_input,
  77. ValueStorageLayout.HORIZONTAL,
  78. ALL_ROWS_SELECTION,
  79. ALL_TABLE_COLUMNS_BY_NAME_SELECTION,
  80. data_source=DataSource(table_index=None),
  81. )
  82. )
  83. for table_index in range(0, vertical_join_tables):
  84. statements.extend(
  85. self.generate_source_for_storage_layout(
  86. types_input,
  87. ValueStorageLayout.VERTICAL,
  88. ALL_ROWS_SELECTION,
  89. ALL_TABLE_COLUMNS_BY_NAME_SELECTION,
  90. data_source=DataSource(table_index=table_index),
  91. )
  92. )
  93. return statements
  94. def generate_source_for_storage_layout(
  95. self,
  96. types_input: ConsistencyTestTypesInput,
  97. storage_layout: ValueStorageLayout,
  98. row_selection: DataRowSelection,
  99. table_column_selection: TableColumnByNameSelection,
  100. data_source: DataSource,
  101. override_base_name: str | None = None,
  102. ) -> list[str]:
  103. raise NotImplementedError
  104. def get_db_object_name(
  105. self,
  106. storage_layout: ValueStorageLayout,
  107. data_source: DataSource,
  108. override_base_name: str | None = None,
  109. ) -> str:
  110. if storage_layout == ValueStorageLayout.ANY:
  111. raise RuntimeError(f"{storage_layout} has not been resolved to a real one")
  112. if override_base_name is None:
  113. storage_suffix = (
  114. "horiz" if storage_layout == ValueStorageLayout.HORIZONTAL else "vert"
  115. )
  116. base_name = f"{self.object_name_base}_{storage_suffix}"
  117. else:
  118. base_name = override_base_name
  119. return data_source.get_db_object_name(base_name=base_name)
  120. def __str__(self) -> str:
  121. return self.name
  122. def _create_column_specs(
  123. self,
  124. types_input: ConsistencyTestTypesInput,
  125. storage_layout: ValueStorageLayout,
  126. table_index: int | None,
  127. include_type: bool,
  128. table_column_selection: TableColumnByNameSelection,
  129. ) -> list[str]:
  130. column_specs = []
  131. # row index as first column (also for horizontal layout helpful to simplify aggregate functions with order spec)
  132. int_type_name = self.sql_adjuster.adjust_type("INT")
  133. type_info = f" {int_type_name}" if include_type else ""
  134. column_specs.append(f"{ROW_INDEX_COL_NAME}{type_info}")
  135. for type_with_values in types_input.all_data_types_with_values:
  136. type_name = self.sql_adjuster.adjust_type(
  137. type_with_values.data_type.type_name
  138. )
  139. type_info = f" {type_name}" if include_type else ""
  140. if storage_layout == ValueStorageLayout.HORIZONTAL:
  141. for data_value in type_with_values.raw_values:
  142. if table_column_selection.is_included(
  143. data_value.get_source_column_identifier()
  144. ):
  145. column_specs.append(f"{data_value.column_name}{type_info}")
  146. elif storage_layout == ValueStorageLayout.VERTICAL:
  147. column = type_with_values.create_assigned_vertical_storage_column(
  148. DataSource(table_index)
  149. )
  150. if table_column_selection.is_included(
  151. column.get_source_column_identifier()
  152. ):
  153. column_specs.append(f"{column.column_name}{type_info}")
  154. else:
  155. raise RuntimeError(f"Unsupported storage layout: {storage_layout}")
  156. return column_specs
  157. def _adjust_type_name(self, type_name: str) -> str:
  158. return type_name
  159. def _create_value_rows(
  160. self,
  161. types_input: ConsistencyTestTypesInput,
  162. storage_layout: ValueStorageLayout,
  163. row_selection: DataRowSelection,
  164. table_column_selection: TableColumnByNameSelection,
  165. data_source: DataSource,
  166. ) -> list[str]:
  167. if storage_layout == ValueStorageLayout.HORIZONTAL:
  168. assert (
  169. data_source.table_index is None
  170. ), "Table index is not supported for horizontal storage"
  171. return [
  172. self.__create_horizontal_value_row(
  173. types_input.all_data_types_with_values, table_column_selection
  174. )
  175. ]
  176. elif storage_layout == ValueStorageLayout.VERTICAL:
  177. return self.__create_vertical_value_rows(
  178. types_input.all_data_types_with_values,
  179. types_input.get_max_value_count_of_all_types(
  180. table_index=data_source.table_index
  181. ),
  182. row_selection,
  183. table_column_selection,
  184. data_source,
  185. )
  186. else:
  187. raise RuntimeError(f"Unsupported storage layout: {storage_layout}")
  188. def __create_horizontal_value_row(
  189. self,
  190. data_type_with_values: list[DataTypeWithValues],
  191. table_column_selection: TableColumnByNameSelection,
  192. ) -> str:
  193. row_values = []
  194. # row index
  195. row_values.append("0")
  196. for type_with_values in data_type_with_values:
  197. for data_value in type_with_values.raw_values:
  198. if table_column_selection.is_included(
  199. data_value.get_source_column_identifier()
  200. ):
  201. row_values.append(data_value.to_sql_as_value(self.sql_adjuster))
  202. return f"{', '.join(row_values)}"
  203. def __create_vertical_value_rows(
  204. self,
  205. data_type_with_values: list[DataTypeWithValues],
  206. row_count: int,
  207. row_selection: DataRowSelection,
  208. table_column_selection: TableColumnByNameSelection,
  209. data_source: DataSource,
  210. ) -> list[str]:
  211. """Creates table rows with the values of each type in a column. For types with fewer values, values are repeated."""
  212. rows = []
  213. for row_index in range(0, row_count):
  214. # the first column holds the row index
  215. row_values = [str(row_index)]
  216. for type_with_values in data_type_with_values:
  217. data_column = type_with_values.create_assigned_vertical_storage_column(
  218. data_source
  219. )
  220. if not table_column_selection.is_included(
  221. data_column.get_source_column_identifier()
  222. ):
  223. continue
  224. data_value = data_column.get_value_at_row(
  225. row_index, data_source.table_index
  226. )
  227. row_values.append(data_value.to_sql_as_value(self.sql_adjuster))
  228. if row_selection.is_included_in_source(data_source, row_index):
  229. rows.append(f"{', '.join(row_values)}")
  230. return rows
  231. class DummyEvaluation(EvaluationStrategy):
  232. def __init__(self) -> None:
  233. super().__init__(EvaluationStrategyKey.DUMMY, "Dummy", "<source>", "dummy")
  234. def generate_sources(
  235. self,
  236. types_input: ConsistencyTestTypesInput,
  237. vertical_join_tables: int,
  238. ) -> list[str]:
  239. return []
  240. class DataFlowRenderingEvaluation(EvaluationStrategy):
  241. def __init__(self) -> None:
  242. super().__init__(
  243. EvaluationStrategyKey.MZ_DATAFLOW_RENDERING,
  244. "Dataflow rendering",
  245. "t_dfr",
  246. "dataflow_rendering",
  247. )
  248. def generate_source_for_storage_layout(
  249. self,
  250. types_input: ConsistencyTestTypesInput,
  251. storage_layout: ValueStorageLayout,
  252. row_selection: DataRowSelection,
  253. table_column_selection: TableColumnByNameSelection,
  254. data_source: DataSource,
  255. override_base_name: str | None = None,
  256. ) -> list[str]:
  257. db_object_name = self.get_db_object_name(
  258. storage_layout,
  259. data_source,
  260. override_base_name=override_base_name,
  261. )
  262. statements = []
  263. column_specs = self._create_column_specs(
  264. types_input,
  265. storage_layout,
  266. data_source.table_index,
  267. True,
  268. table_column_selection,
  269. )
  270. statements.append(f"DROP TABLE IF EXISTS {db_object_name};")
  271. statements.append(f"CREATE TABLE {db_object_name} ({', '.join(column_specs)});")
  272. value_rows = self._create_value_rows(
  273. types_input,
  274. storage_layout,
  275. row_selection,
  276. table_column_selection,
  277. data_source,
  278. )
  279. for value_row in value_rows:
  280. statements.append(f"INSERT INTO {db_object_name} VALUES ({value_row});")
  281. return statements
  282. class ConstantFoldingEvaluation(EvaluationStrategy):
  283. def __init__(self) -> None:
  284. super().__init__(
  285. EvaluationStrategyKey.MZ_CONSTANT_FOLDING,
  286. "Constant folding",
  287. "v_ctf",
  288. "constant_folding",
  289. )
  290. def generate_source_for_storage_layout(
  291. self,
  292. types_input: ConsistencyTestTypesInput,
  293. storage_layout: ValueStorageLayout,
  294. row_selection: DataRowSelection,
  295. table_column_selection: TableColumnByNameSelection,
  296. data_source: DataSource,
  297. override_base_name: str | None = None,
  298. ) -> list[str]:
  299. db_object_name = self.get_db_object_name(
  300. storage_layout,
  301. data_source,
  302. override_base_name=override_base_name,
  303. )
  304. column_specs = self._create_column_specs(
  305. types_input,
  306. storage_layout,
  307. data_source.table_index,
  308. False,
  309. table_column_selection,
  310. )
  311. value_rows = self._create_value_rows(
  312. types_input,
  313. storage_layout,
  314. row_selection,
  315. table_column_selection,
  316. data_source,
  317. )
  318. value_specification = "\n UNION SELECT ".join(value_rows)
  319. create_view_statement = (
  320. f"CREATE OR REPLACE VIEW {db_object_name} ({', '.join(column_specs)})\n"
  321. f" AS SELECT {value_specification};"
  322. )
  323. return [create_view_statement]
  324. def create_internal_evaluation_strategy_twice(
  325. evaluation_strategy_name: str,
  326. ) -> list[EvaluationStrategy]:
  327. strategies: list[EvaluationStrategy]
  328. if evaluation_strategy_name == EVALUATION_STRATEGY_NAME_DFR:
  329. strategies = [DataFlowRenderingEvaluation(), DataFlowRenderingEvaluation()]
  330. strategies[1].identifier = EvaluationStrategyKey.MZ_DATAFLOW_RENDERING_OTHER_DB
  331. return strategies
  332. if evaluation_strategy_name == EVALUATION_STRATEGY_NAME_CTF:
  333. strategies = [ConstantFoldingEvaluation(), ConstantFoldingEvaluation()]
  334. strategies[1].identifier = EvaluationStrategyKey.MZ_CONSTANT_FOLDING_OTHER_DB
  335. return strategies
  336. raise RuntimeError(f"Unexpected strategy name: { evaluation_strategy_name}")
  337. def is_other_db_evaluation_strategy(evaluation_key: EvaluationStrategyKey) -> bool:
  338. return evaluation_key in {
  339. EvaluationStrategyKey.MZ_DATAFLOW_RENDERING_OTHER_DB,
  340. EvaluationStrategyKey.MZ_CONSTANT_FOLDING_OTHER_DB,
  341. }
  342. def is_data_flow_rendering(evaluation_key: EvaluationStrategyKey) -> bool:
  343. return evaluation_key in {
  344. EvaluationStrategyKey.MZ_DATAFLOW_RENDERING,
  345. EvaluationStrategyKey.MZ_DATAFLOW_RENDERING_OTHER_DB,
  346. }
  347. def is_constant_folding(evaluation_key: EvaluationStrategyKey) -> bool:
  348. return evaluation_key in {
  349. EvaluationStrategyKey.MZ_CONSTANT_FOLDING,
  350. EvaluationStrategyKey.MZ_CONSTANT_FOLDING_OTHER_DB,
  351. }