MSEEvaluatorFromDataFrame.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from __future__ import annotations
  2. import csv
  3. import logging
  4. import os
  5. from contextlib import nullcontext
  6. from typing import TYPE_CHECKING
  7. import numpy as np
  8. from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
  9. if TYPE_CHECKING:
  10. from sentence_transformers.SentenceTransformer import SentenceTransformer
# Module-level logger; the evaluator below reports teacher-embedding progress and per-pair MSE scores through it.
logger = logging.getLogger(__name__)
  12. class MSEEvaluatorFromDataFrame(SentenceEvaluator):
  13. """
  14. Computes the mean squared error (x100) between the computed sentence embedding and some target sentence embedding.
  15. Args:
  16. dataframe (List[Dict[str, str]]): It must have the following format. Rows contains different, parallel sentences.
  17. Columns are the respective language codes::
  18. [{'en': 'My sentence in English', 'es': 'Oración en español', 'fr': 'Phrase en français'...},
  19. {'en': 'My second sentence', ...}]
  20. teacher_model (SentenceTransformer): The teacher model used to compute the sentence embeddings.
  21. combinations (List[Tuple[str, str]]): Must be of the format ``[('en', 'es'), ('en', 'fr'), ...]``.
  22. First entry in a tuple is the source language. The sentence in the respective language will be fetched from
  23. the dataframe and passed to the teacher model. Second entry in a tuple the the target language. Sentence
  24. will be fetched from the dataframe and passed to the student model
  25. batch_size (int, optional): The batch size to compute sentence embeddings. Defaults to 8.
  26. name (str, optional): The name of the evaluator. Defaults to "".
  27. write_csv (bool, optional): Whether to write the results to a CSV file. Defaults to True.
  28. truncate_dim (Optional[int], optional): The dimension to truncate sentence embeddings to. If None, uses the model's
  29. current truncation dimension. Defaults to None.
  30. """
  31. def __init__(
  32. self,
  33. dataframe: list[dict[str, str]],
  34. teacher_model: SentenceTransformer,
  35. combinations: list[tuple[str, str]],
  36. batch_size: int = 8,
  37. name: str = "",
  38. write_csv: bool = True,
  39. truncate_dim: int | None = None,
  40. ):
  41. super().__init__()
  42. self.combinations = combinations
  43. self.name = name
  44. self.batch_size = batch_size
  45. if name:
  46. name = "_" + name
  47. self.csv_file = "mse_evaluation" + name + "_results.csv"
  48. self.csv_headers = ["epoch", "steps"]
  49. self.primary_metric = "negative_mse"
  50. self.write_csv = write_csv
  51. self.truncate_dim = truncate_dim
  52. self.data = {}
  53. logger.info("Compute teacher embeddings")
  54. all_source_sentences = set()
  55. for src_lang, trg_lang in self.combinations:
  56. src_sentences = []
  57. trg_sentences = []
  58. for row in dataframe:
  59. if row[src_lang].strip() != "" and row[trg_lang].strip() != "":
  60. all_source_sentences.add(row[src_lang])
  61. src_sentences.append(row[src_lang])
  62. trg_sentences.append(row[trg_lang])
  63. self.data[(src_lang, trg_lang)] = (src_sentences, trg_sentences)
  64. self.csv_headers.append(f"{src_lang}-{trg_lang}")
  65. all_source_sentences = list(all_source_sentences)
  66. with nullcontext() if self.truncate_dim is None else teacher_model.truncate_sentence_embeddings(
  67. self.truncate_dim
  68. ):
  69. all_src_embeddings = teacher_model.encode(all_source_sentences, batch_size=self.batch_size)
  70. self.teacher_embeddings = {sent: emb for sent, emb in zip(all_source_sentences, all_src_embeddings)}
  71. def __call__(
  72. self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1
  73. ) -> dict[str, float]:
  74. model.eval()
  75. mse_scores = []
  76. for src_lang, trg_lang in self.combinations:
  77. src_sentences, trg_sentences = self.data[(src_lang, trg_lang)]
  78. src_embeddings = np.asarray([self.teacher_embeddings[sent] for sent in src_sentences])
  79. with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
  80. trg_embeddings = np.asarray(model.encode(trg_sentences, batch_size=self.batch_size))
  81. mse = ((src_embeddings - trg_embeddings) ** 2).mean()
  82. mse *= 100
  83. mse_scores.append(mse)
  84. logger.info(f"MSE evaluation on {self.name} dataset - {src_lang}-{trg_lang}:")
  85. logger.info(f"MSE (*100):\t{mse:4f}")
  86. if output_path is not None and self.write_csv:
  87. csv_path = os.path.join(output_path, self.csv_file)
  88. output_file_exists = os.path.isfile(csv_path)
  89. with open(csv_path, newline="", mode="a" if output_file_exists else "w", encoding="utf-8") as f:
  90. writer = csv.writer(f)
  91. if not output_file_exists:
  92. writer.writerow(self.csv_headers)
  93. writer.writerow([epoch, steps] + mse_scores)
  94. # Return negative score as SentenceTransformers maximizes the performance
  95. metrics = {"negative_mse": -np.mean(mse_scores).item()}
  96. metrics = self.prefix_name_to_metrics(metrics, self.name)
  97. self.store_metrics_in_model_card_data(model, metrics)
  98. return metrics
  99. @property
  100. def description(self) -> str:
  101. return "Knowledge Distillation"