# CNN.py 3.0 KB

from __future__ import annotations

import json
import os

import torch
from safetensors.torch import load_model as load_safetensors_model
from safetensors.torch import save_model as save_safetensors_model
from torch import nn


  8. class CNN(nn.Module):
  9. """CNN-layer with multiple kernel-sizes over the word embeddings"""
  10. def __init__(
  11. self,
  12. in_word_embedding_dimension: int,
  13. out_channels: int = 256,
  14. kernel_sizes: list[int] = [1, 3, 5],
  15. stride_sizes: list[int] = None,
  16. ):
  17. nn.Module.__init__(self)
  18. self.config_keys = ["in_word_embedding_dimension", "out_channels", "kernel_sizes"]
  19. self.in_word_embedding_dimension = in_word_embedding_dimension
  20. self.out_channels = out_channels
  21. self.kernel_sizes = kernel_sizes
  22. self.embeddings_dimension = out_channels * len(kernel_sizes)
  23. self.convs = nn.ModuleList()
  24. in_channels = in_word_embedding_dimension
  25. if stride_sizes is None:
  26. stride_sizes = [1] * len(kernel_sizes)
  27. for kernel_size, stride in zip(kernel_sizes, stride_sizes):
  28. padding_size = int((kernel_size - 1) / 2)
  29. conv = nn.Conv1d(
  30. in_channels=in_channels,
  31. out_channels=out_channels,
  32. kernel_size=kernel_size,
  33. stride=stride,
  34. padding=padding_size,
  35. )
  36. self.convs.append(conv)
  37. def forward(self, features):
  38. token_embeddings = features["token_embeddings"]
  39. token_embeddings = token_embeddings.transpose(1, -1)
  40. vectors = [conv(token_embeddings) for conv in self.convs]
  41. out = torch.cat(vectors, 1).transpose(1, -1)
  42. features.update({"token_embeddings": out})
  43. return features
  44. def get_word_embedding_dimension(self) -> int:
  45. return self.embeddings_dimension
  46. def tokenize(self, text: str, **kwargs) -> list[int]:
  47. raise NotImplementedError()
  48. def save(self, output_path: str, safe_serialization: bool = True):
  49. with open(os.path.join(output_path, "cnn_config.json"), "w") as fOut:
  50. json.dump(self.get_config_dict(), fOut, indent=2)
  51. if safe_serialization:
  52. save_safetensors_model(self, os.path.join(output_path, "model.safetensors"))
  53. else:
  54. torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))
  55. def get_config_dict(self):
  56. return {key: self.__dict__[key] for key in self.config_keys}
  57. @staticmethod
  58. def load(input_path: str):
  59. with open(os.path.join(input_path, "cnn_config.json")) as fIn:
  60. config = json.load(fIn)
  61. model = CNN(**config)
  62. if os.path.exists(os.path.join(input_path, "model.safetensors")):
  63. load_safetensors_model(model, os.path.join(input_path, "model.safetensors"))
  64. else:
  65. model.load_state_dict(
  66. torch.load(
  67. os.path.join(input_path, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True
  68. )
  69. )
  70. return model