aggregation.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright Materialize, Inc. and contributors. All rights reserved.
  2. #
  3. # Use of this software is governed by the Business Source License
  4. # included in the LICENSE file at the root of this repository.
  5. #
  6. # As of the Change Date specified in that file, in accordance with
  7. # the Business Source License, use of this software will be governed
  8. # by the Apache License, Version 2.0.
  9. import statistics
  10. from collections.abc import Callable
  11. from typing import Any
  12. import numpy as np
  13. from materialize.feature_benchmark.measurement import (
  14. Measurement,
  15. MeasurementType,
  16. MeasurementUnit,
  17. )
  18. class Aggregation:
  19. def __init__(self) -> None:
  20. self.measurement_type: MeasurementType | None = None
  21. self._data: list[float] = []
  22. self._unit: MeasurementUnit = MeasurementUnit.UNKNOWN
  23. def append_measurement(self, measurement: Measurement) -> None:
  24. assert measurement.unit != MeasurementUnit.UNKNOWN, "Unknown unit"
  25. self.measurement_type = measurement.type
  26. self._unit = measurement.unit
  27. self._data.append(measurement.value)
  28. def aggregate(self) -> Any:
  29. if len(self._data) == 0:
  30. return None
  31. return self.func()([*self._data])
  32. def unit(self) -> MeasurementUnit:
  33. return self._unit
  34. def func(self) -> Callable:
  35. raise NotImplementedError
  36. def name(self) -> str:
  37. return self.__class__.__name__
  38. class MinAggregation(Aggregation):
  39. def func(self) -> Callable:
  40. return min
  41. class MeanAggregation(Aggregation):
  42. def func(self) -> Callable:
  43. return np.mean
  44. class StdDevAggregation(Aggregation):
  45. def __init__(self, num_stdevs: float) -> None:
  46. super().__init__()
  47. self._num_stdevs = num_stdevs
  48. def aggregate(self) -> float | None:
  49. if len(self._data) == 0:
  50. return None
  51. stdev: float = np.std(self._data, dtype=float)
  52. mean: float = np.mean(self._data, dtype=float)
  53. val = mean - (stdev * self._num_stdevs)
  54. return val
  55. class NormalDistributionAggregation(Aggregation):
  56. def aggregate(self) -> statistics.NormalDist | None:
  57. if len(self._data) == 0:
  58. return None
  59. return statistics.NormalDist(
  60. mu=np.mean(self._data, dtype=float), sigma=np.std(self._data, dtype=float)
  61. )
  62. class NoAggregation(Aggregation):
  63. def aggregate(self) -> Any:
  64. if len(self._data) == 0:
  65. return None
  66. return self._data[0]