SequentialEvaluator.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, Iterable
  3. from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
  4. if TYPE_CHECKING:
  5. from sentence_transformers.SentenceTransformer import SentenceTransformer
  6. class SequentialEvaluator(SentenceEvaluator):
  7. """
  8. This evaluator allows that multiple sub-evaluators are passed. When the model is evaluated,
  9. the data is passed sequentially to all sub-evaluators.
  10. All scores are passed to 'main_score_function', which derives one final score value
  11. """
  12. def __init__(self, evaluators: Iterable[SentenceEvaluator], main_score_function=lambda scores: scores[-1]):
  13. """
  14. Initializes a SequentialEvaluator object.
  15. Args:
  16. evaluators (Iterable[SentenceEvaluator]): A collection of SentenceEvaluator objects.
  17. main_score_function (function, optional): A function that takes a list of scores and returns the main score.
  18. Defaults to selecting the last score in the list.
  19. Example:
  20. ::
  21. evaluator1 = BinaryClassificationEvaluator(...)
  22. evaluator2 = InformationRetrievalEvaluator(...)
  23. evaluator3 = MSEEvaluator(...)
  24. seq_evaluator = SequentialEvaluator([evaluator1, evaluator2, evaluator3])
  25. """
  26. super().__init__()
  27. self.evaluators = evaluators
  28. self.main_score_function = main_score_function
  29. def __call__(
  30. self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1
  31. ) -> dict[str, float]:
  32. evaluations = []
  33. scores = []
  34. for evaluator_idx, evaluator in enumerate(self.evaluators):
  35. evaluation = evaluator(model, output_path, epoch, steps)
  36. if not isinstance(evaluation, dict):
  37. scores.append(evaluation)
  38. evaluation = {f"evaluator_{evaluator_idx}": evaluation}
  39. else:
  40. if hasattr(evaluator, "primary_metric"):
  41. scores.append(evaluation[evaluator.primary_metric])
  42. else:
  43. scores.append(evaluation[list(evaluation.keys())[0]])
  44. evaluations.append(evaluation)
  45. self.primary_metric = "sequential_score"
  46. main_score = self.main_score_function(scores)
  47. results = {key: value for evaluation in evaluations for key, value in evaluation.items()}
  48. results["sequential_score"] = main_score
  49. return results