termination.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
# Copyright Materialize, Inc. and contributors. All rights reserved.
#
# Use of this software is governed by the Business Source License
# included in the LICENSE file at the root of this repository.
#
# As of the Change Date specified in that file, in accordance with
# the Business Source License, use of this software will be governed
# by the Apache License, Version 2.0.
  9. import statistics
  10. import numpy as np
  11. from scipy import stats # type: ignore
  12. from materialize.feature_benchmark.measurement import Measurement
  13. class TerminationCondition:
  14. def __init__(self, threshold: float) -> None:
  15. self._threshold = threshold
  16. self._data: list[float] = []
  17. def terminate(self, measurement: Measurement) -> bool:
  18. raise NotImplementedError
  19. class NormalDistributionOverlap(TerminationCondition):
  20. """Signal termination if the overlap between the two distributions is above the threshold"""
  21. def __init__(self, threshold: float) -> None:
  22. self._last_fit: statistics.NormalDist | None = None
  23. super().__init__(threshold=threshold)
  24. def terminate(self, measurement: Measurement) -> bool:
  25. self._data.append(measurement.value)
  26. if len(self._data) > 10:
  27. (mu, sigma) = stats.norm.fit(self._data)
  28. current_fit = statistics.NormalDist(mu=mu, sigma=sigma)
  29. if self._last_fit:
  30. current_overlap = current_fit.overlap(other=self._last_fit)
  31. if current_overlap >= self._threshold:
  32. return True
  33. self._last_fit = current_fit
  34. return False
  35. class ProbForMin(TerminationCondition):
  36. """Signal termination if the probability that a new value will arrive that is smaller than all the previous values
  37. has dropped below the threshold
  38. """
  39. def terminate(self, measurement: Measurement) -> bool:
  40. self._data.append(measurement.value)
  41. if len(self._data) > 5:
  42. mean = np.mean(self._data)
  43. stdev = np.std(self._data)
  44. min_val = np.min(self._data)
  45. dist = stats.norm(loc=mean, scale=stdev)
  46. prob = dist.cdf(min_val)
  47. if prob < (1 - self._threshold):
  48. return True
  49. else:
  50. return False
  51. else:
  52. return False
  53. class RunAtMost(TerminationCondition):
  54. def terminate(self, measurement: Measurement) -> bool:
  55. self._data.append(measurement.value)
  56. return len(self._data) >= self._threshold