| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748 |
- from __future__ import annotations
- import json
- import os
- import torch
- from safetensors.torch import load_model as load_safetensors_model
- from safetensors.torch import save_model as save_safetensors_model
- from torch import Tensor, nn
class LayerNorm(nn.Module):
    """Apply layer normalization to the ``"sentence_embedding"`` entry of a features dict.

    The module wraps a ``torch.nn.LayerNorm`` built for ``dimension`` features
    and can persist / restore itself from a directory containing a
    ``config.json`` plus a weights file.

    Args:
        dimension: Size passed to ``nn.LayerNorm`` and reported by
            :meth:`get_sentence_embedding_dimension`.
    """

    def __init__(self, dimension: int):
        super().__init__()
        self.dimension = dimension
        self.norm = nn.LayerNorm(dimension)

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        """Normalize ``features["sentence_embedding"]`` and return the same dict.

        The dict is updated in place: the ``"sentence_embedding"`` entry is
        replaced by its normalized version; all other entries are untouched.

        Raises:
            KeyError: If ``"sentence_embedding"`` is missing from ``features``.
        """
        features["sentence_embedding"] = self.norm(features["sentence_embedding"])
        return features

    def get_sentence_embedding_dimension(self) -> int:
        """Return the ``dimension`` this module was constructed with."""
        return self.dimension

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        """Write the configuration and weights into ``output_path``.

        Args:
            output_path: Target directory. It is created (including parents)
                when it does not exist yet.
            safe_serialization: When true, store the weights as
                ``model.safetensors``; otherwise store the state dict with
                ``torch.save`` as ``pytorch_model.bin``.
        """
        # Without this, saving into a fresh directory fails on the first open().
        os.makedirs(output_path, exist_ok=True)

        config_path = os.path.join(output_path, "config.json")
        with open(config_path, "w", encoding="utf-8") as config_file:
            json.dump({"dimension": self.dimension}, config_file, indent=2)

        if safe_serialization:
            save_safetensors_model(self, os.path.join(output_path, "model.safetensors"))
        else:
            torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

    @staticmethod
    def load(input_path: str) -> LayerNorm:
        """Rebuild a ``LayerNorm`` from a directory written by :meth:`save`.

        ``config.json`` supplies the constructor arguments. Weights are read
        from ``model.safetensors`` when that file exists, otherwise from
        ``pytorch_model.bin``.

        Args:
            input_path: Directory containing ``config.json`` and a weights file.

        Returns:
            A new ``LayerNorm`` with the stored weights loaded.
        """
        with open(os.path.join(input_path, "config.json"), encoding="utf-8") as config_file:
            config = json.load(config_file)

        model = LayerNorm(**config)

        safetensors_path = os.path.join(input_path, "model.safetensors")
        if os.path.exists(safetensors_path):
            load_safetensors_model(model, safetensors_path)
        else:
            # weights_only=True restricts unpickling to tensor data; the
            # checkpoint is mapped to CPU regardless of where it was saved.
            state_dict = torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
                weights_only=True,
            )
            model.load_state_dict(state_dict)
        return model
|