CLIPModel.py

from __future__ import annotations

import torch
import transformers
from PIL import Image
from torch import nn


class CLIPModel(nn.Module):
    save_in_root: bool = True

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32", processor_name=None) -> None:
        super().__init__()

        if processor_name is None:
            processor_name = model_name

        self.model = transformers.CLIPModel.from_pretrained(model_name)
        self.processor = transformers.CLIPProcessor.from_pretrained(processor_name)

    def __repr__(self) -> str:
        return "CLIPModel()"

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        image_embeds = []
        text_embeds = []

        if "pixel_values" in features:
            vision_outputs = self.model.vision_model(pixel_values=features["pixel_values"])
            # Pooled vision output (index 1), projected into the shared CLIP space
            image_embeds = self.model.visual_projection(vision_outputs[1])

        if "input_ids" in features:
            text_outputs = self.model.text_model(
                input_ids=features.get("input_ids"),
                attention_mask=features.get("attention_mask", None),
                position_ids=features.get("position_ids", None),
                output_attentions=features.get("output_attentions", None),
                output_hidden_states=features.get("output_hidden_states", None),
            )
            # Pooled text output (index 1), projected into the shared CLIP space
            text_embeds = self.model.text_projection(text_outputs[1])

        # Re-interleave image and text embeddings in the original batch order,
        # using image_text_info (0 = image, 1 = text) produced by tokenize()
        sentence_embedding = []
        image_features = iter(image_embeds)
        text_features = iter(text_embeds)

        for idx, input_type in enumerate(features["image_text_info"]):
            if input_type == 0:
                sentence_embedding.append(next(image_features))
            else:
                sentence_embedding.append(next(text_features))

        features["sentence_embedding"] = torch.stack(sentence_embedding).float()

        return features

    def tokenize(self, texts, padding: str | bool = True) -> dict[str, torch.Tensor]:
        images = []
        texts_values = []
        image_text_info = []

        # Split the mixed batch into images and texts, remembering the original order
        for idx, data in enumerate(texts):
            if isinstance(data, Image.Image):  # An image
                images.append(data)
                image_text_info.append(0)
            else:  # A text
                texts_values.append(data)
                image_text_info.append(1)

        encoding = {}
        if len(texts_values):
            encoding = self.processor.tokenizer(texts_values, return_tensors="pt", padding=padding)

        if len(images):
            image_features = self.processor.image_processor(images, return_tensors="pt")
            encoding["pixel_values"] = image_features.pixel_values

        encoding["image_text_info"] = image_text_info
        return dict(encoding)

    @property
    def tokenizer(self) -> transformers.CLIPProcessor:
        return self.processor

    def save(self, output_path: str) -> None:
        self.model.save_pretrained(output_path)
        self.processor.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> CLIPModel:
        return CLIPModel(model_name=input_path)
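

# --- Usage sketch (not part of the original module) -------------------------
# A minimal example of how this wrapper is typically driven: tokenize() splits
# a mixed batch of PIL images and strings, and forward() merges the projected
# CLIP embeddings back into the original batch order via image_text_info.
# The image path below is a hypothetical placeholder; the expected output
# shape assumes the default openai/clip-vit-base-patch32 checkpoint, whose
# projection dimension is 512.
if __name__ == "__main__":
    model = CLIPModel()  # downloads openai/clip-vit-base-patch32 by default
    batch = [
        Image.open("example.jpg"),  # placeholder path, assumed to exist
        "a photo of a cat",
    ]
    features = model.tokenize(batch)
    with torch.no_grad():
        out = model(features)
    # One embedding per input, images and texts interleaved in batch order.
    print(out["sentence_embedding"].shape)  # e.g. torch.Size([2, 512])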