Dropout.py 915 B

1234567891011121314151617181920212223242526272829303132333435
  1. from __future__ import annotations
  2. import json
  3. import os
  4. from torch import Tensor, nn
  5. class Dropout(nn.Module):
  6. """Dropout layer.
  7. Args:
  8. dropout: Sets a dropout value for dense layer.
  9. """
  10. def __init__(self, dropout: float = 0.2):
  11. super().__init__()
  12. self.dropout = dropout
  13. self.dropout_layer = nn.Dropout(self.dropout)
  14. def forward(self, features: dict[str, Tensor]):
  15. features.update({"sentence_embedding": self.dropout_layer(features["sentence_embedding"])})
  16. return features
  17. def save(self, output_path):
  18. with open(os.path.join(output_path, "config.json"), "w") as fOut:
  19. json.dump({"dropout": self.dropout}, fOut)
  20. @staticmethod
  21. def load(input_path):
  22. with open(os.path.join(input_path, "config.json")) as fIn:
  23. config = json.load(fIn)
  24. model = Dropout(**config)
  25. return model