# LayerNorm.py
from __future__ import annotations

import json
import os

import torch
from safetensors.torch import load_model as load_safetensors_model
from safetensors.torch import save_model as save_safetensors_model
from torch import Tensor, nn
  8. class LayerNorm(nn.Module):
  9. def __init__(self, dimension: int):
  10. super().__init__()
  11. self.dimension = dimension
  12. self.norm = nn.LayerNorm(dimension)
  13. def forward(self, features: dict[str, Tensor]):
  14. features["sentence_embedding"] = self.norm(features["sentence_embedding"])
  15. return features
  16. def get_sentence_embedding_dimension(self):
  17. return self.dimension
  18. def save(self, output_path, safe_serialization: bool = True) -> None:
  19. with open(os.path.join(output_path, "config.json"), "w") as fOut:
  20. json.dump({"dimension": self.dimension}, fOut, indent=2)
  21. if safe_serialization:
  22. save_safetensors_model(self, os.path.join(output_path, "model.safetensors"))
  23. else:
  24. torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
  25. @staticmethod
  26. def load(input_path):
  27. with open(os.path.join(input_path, "config.json")) as fIn:
  28. config = json.load(fIn)
  29. model = LayerNorm(**config)
  30. if os.path.exists(os.path.join(input_path, "model.safetensors")):
  31. load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
  32. else:
  33. model.load_state_dict(
  34. torch.load(
  35. os.path.join(input_path, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True
  36. )
  37. )
  38. return model