```python
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any, Callable

import torch

logger = logging.getLogger(__name__)


@dataclass
class SentenceTransformerDataCollator:
    """Collator for a SentenceTransformers model.

    This encodes the text columns to {column}_input_ids and {column}_attention_mask columns.
    It works with the two-text dataset that is used as the example in the training overview:
    https://www.sbert.net/docs/sentence_transformer/training_overview.html

    It is important that the columns are in the expected order. For example, if your dataset has the columns
    "answer", "question" in that order, then MultipleNegativesRankingLoss will consider
    "answer" as the anchor and "question" as the positive, and it will (unexpectedly) optimize for
    "given the answer, what is the question?".
    """

    tokenize_fn: Callable
    valid_label_columns: list[str] = field(default_factory=lambda: ["label", "score"])
    _warned_columns: set[tuple[str, ...]] = field(default_factory=set, init=False, repr=False)

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        column_names = list(features[0].keys())

        # We should always be able to return a loss, whether or not a label column is present:
        batch = {}

        if "dataset_name" in column_names:
            column_names.remove("dataset_name")
            batch["dataset_name"] = features[0]["dataset_name"]

        if tuple(column_names) not in self._warned_columns:
            self.maybe_warn_about_column_order(column_names)

        # Extract the label column if it exists
        for label_column in self.valid_label_columns:
            if label_column in column_names:
                batch["label"] = torch.tensor([row[label_column] for row in features])
                column_names.remove(label_column)
                break

        # Extract the feature columns
        for column_name in column_names:
            tokenized = self.tokenize_fn([row[column_name] for row in features])
            for key, value in tokenized.items():
                batch[f"{column_name}_{key}"] = value

        return batch

    def maybe_warn_about_column_order(self, column_names: list[str]) -> None:
        """Warn the user if the columns are likely not in the expected order."""
        # A mapping from common column names to the expected index in the dataset
        column_name_to_expected_idx = {
            "anchor": 0,
            "positive": 1,
            "negative": 2,
            "question": 0,
            "answer": 1,
            "query": 0,
            "response": 1,
            "hypothesis": 0,
            "entailment": 1,
            "contradiction": 2,
        }
        for column_name, expected_idx in column_name_to_expected_idx.items():
            if column_name in column_names and column_names.index(column_name) != expected_idx:
                if column_name in ("anchor", "positive", "negative"):
                    proposed_fix_columns = ["anchor", "positive", "negative"]
                elif column_name in ("question", "answer"):
                    proposed_fix_columns = ["question", "answer"]
                elif column_name in ("query", "response"):
                    proposed_fix_columns = ["query", "response"]
                elif column_name in ("hypothesis", "entailment", "contradiction"):
                    proposed_fix_columns = ["hypothesis", "entailment", "contradiction"]
                logger.warning(
                    f"Column {column_name!r} is at index {column_names.index(column_name)}, whereas "
                    f"a column with this name is usually expected at index {expected_idx}. Note that the column "
                    "order can be important for some losses, e.g. MultipleNegativesRankingLoss will always "
                    "consider the first column as the anchor and the second as the positive, regardless of "
                    "the dataset column names. Consider renaming the columns to match the expected order, e.g.:\n"
                    f"dataset = dataset.select_columns({proposed_fix_columns})"
                )
                # We only need to warn once per list of column names to prevent spamming the user
                break
        self._warned_columns.add(tuple(column_names))
```
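For context, a minimal usage sketch. It assumes the class above is in scope and uses a Hugging Face `AutoTokenizer` as a stand-in for `tokenize_fn` (during actual training, `SentenceTransformerTrainer` typically wires in the model's own tokenization method); the model name, column names, and example rows are illustrative only.

```python
from transformers import AutoTokenizer

# Hypothetical stand-in for the model's tokenization: any callable that maps a
# list of texts to a dict of batched tensors (input_ids, attention_mask, ...) works.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize_fn(texts):
    return tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

collator = SentenceTransformerDataCollator(tokenize_fn=tokenize_fn)

# Two illustrative rows with text columns in the expected (anchor, positive) order,
# plus a "label" column that the collator lifts out into batch["label"].
features = [
    {"anchor": "How many people live in Berlin?",
     "positive": "Berlin has roughly 3.7 million inhabitants.",
     "label": 1.0},
    {"anchor": "What is the capital of France?",
     "positive": "Paris is the capital of France.",
     "label": 1.0},
]

batch = collator(features)
# Produces keys like: label, anchor_input_ids, anchor_attention_mask,
# positive_input_ids, positive_attention_mask
print(sorted(batch.keys()))
```

Because "anchor" and "positive" sit at indices 0 and 1, no column-order warning is emitted here; a dataset with those columns reversed would trigger the `logger.warning` above, which suggests reordering via `dataset.select_columns(...)`.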