Normalize.py

from __future__ import annotations

import torch.nn.functional as F
from torch import Tensor, nn


class Normalize(nn.Module):
    """This layer normalizes embeddings to unit length."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        # L2-normalize each sentence embedding in the batch to unit length.
        features.update({"sentence_embedding": F.normalize(features["sentence_embedding"], p=2, dim=1)})
        return features

    def save(self, output_path: str) -> None:
        # Nothing to persist: this module has no weights or configuration.
        pass

    @staticmethod
    def load(input_path: str) -> Normalize:
        # Stateless module, so loading just constructs a fresh instance.
        return Normalize()
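
A minimal usage sketch, not part of the file: since forward takes and returns the usual features dictionary, the layer can be exercised directly with a random batch. The batch size 4 and embedding dimension 384 below are arbitrary choices for illustration.

import torch

# Hypothetical smoke test: 4 embeddings of dimension 384 (both values arbitrary).
features = {"sentence_embedding": torch.randn(4, 384)}
normalized = Normalize()(features)

# Every row now has unit L2 norm, so a dot product between two
# normalized embeddings equals their cosine similarity.
print(normalized["sentence_embedding"].norm(p=2, dim=1))  # tensor([1., 1., 1., 1.])

In a sentence-transformers pipeline this module is typically placed last, after the pooling layer, so that encode() returns unit-length vectors and similarity can be computed with a plain dot product.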