feats.py

# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Tuple, overload

import torch
import torch.types
from torch import nn

from . import residue_constants as rc
from .rigid_utils import Rigid, Rotation
from .tensor_utils import batched_gather


@overload
def pseudo_beta_fn(aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: None) -> torch.Tensor: ...


@overload
def pseudo_beta_fn(
    aatype: torch.Tensor, all_atom_positions: torch.Tensor, all_atom_masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ...


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
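    """Gather pseudo-beta atom positions: CB coordinates, or CA for glycine (which has no CB).

    aatype is [*, N]; all_atom_positions is [*, N, 37, 3] in atom37 order. Returns the
    [*, N, 3] pseudo-beta positions, plus a [*, N] pseudo-beta mask when all_atom_masks
    is provided.
    """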
    is_gly = aatype == rc.restype_order["G"]
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    pseudo_beta = torch.where(
        is_gly[..., None].expand(*((-1,) * len(is_gly.shape)), 3),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_masks is not None:
        pseudo_beta_mask = torch.where(
            is_gly,
            all_atom_masks[..., ca_idx],
            all_atom_masks[..., cb_idx],
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta


def atom14_to_atom37(atom14: torch.Tensor, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
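    """Map atom14-format coordinates to the padded atom37 representation.

    Atoms are gathered with the per-residue index map batch["residx_atom37_to_atom14"],
    and positions of atoms absent from a residue are zeroed via batch["atom37_atom_exists"].
    """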
    atom37_data = batched_gather(
        atom14,
        batch["residx_atom37_to_atom14"],
        dim=-2,
        no_batch_dims=len(atom14.shape[:-2]),
    )
    atom37_data = atom37_data * batch["atom37_atom_exists"][..., None]
    return atom37_data


def build_template_angle_feat(template_feats: Dict[str, torch.Tensor]) -> torch.Tensor:
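    """Concatenate per-residue template angle features along the last dimension.

    Combines the 22-class one-hot template aatype, the seven torsion angles as sin/cos
    pairs (flattened to 14 values), their alternate-frame counterparts, and the 7-channel
    torsion angle mask, giving 22 + 14 + 14 + 7 = 57 channels.
    """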
    template_aatype = template_feats["template_aatype"]
    torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
    alt_torsion_angles_sin_cos = template_feats["template_alt_torsion_angles_sin_cos"]
    torsion_angles_mask = template_feats["template_torsion_angles_mask"]
    template_angle_feat = torch.cat(
        [
            nn.functional.one_hot(template_aatype, 22),
            torsion_angles_sin_cos.reshape(*torsion_angles_sin_cos.shape[:-2], 14),
            alt_torsion_angles_sin_cos.reshape(*alt_torsion_angles_sin_cos.shape[:-2], 14),
            torsion_angles_mask,
        ],
        dim=-1,
    )
    return template_angle_feat


def build_template_pair_feat(
    batch: Dict[str, torch.Tensor],
    min_bin: torch.types.Number,
    max_bin: torch.types.Number,
    no_bins: int,
    use_unit_vector: bool = False,
    eps: float = 1e-20,
    inf: float = 1e8,
) -> torch.Tensor:
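    """Build pairwise template features.

    Concatenates, along the last dimension: a one-hot distogram of pseudo-beta distances
    with no_bins bins between min_bin and max_bin, the 2D pseudo-beta mask, one-hot aatypes
    broadcast over rows and columns, inter-residue unit vectors expressed in each residue's
    backbone frame (zeroed unless use_unit_vector is True), and a 2D backbone-atom mask.
    The result is multiplied by that backbone mask.
    """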
    template_mask = batch["template_pseudo_beta_mask"]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    # Compute distogram (this seems to differ slightly from Alg. 5)
    tpb = batch["template_pseudo_beta"]
    dgram = torch.sum((tpb[..., None, :] - tpb[..., None, :, :]) ** 2, dim=-1, keepdim=True)
    lower = torch.linspace(min_bin, max_bin, no_bins, device=tpb.device) ** 2
    upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
    dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

    to_concat = [dgram, template_mask_2d[..., None]]

    aatype_one_hot: torch.LongTensor = nn.functional.one_hot(
        batch["template_aatype"],
        rc.restype_num + 2,
    )

    n_res = batch["template_aatype"].shape[-1]
    to_concat.append(aatype_one_hot[..., None, :, :].expand(*aatype_one_hot.shape[:-2], n_res, -1, -1))
    to_concat.append(aatype_one_hot[..., None, :].expand(*aatype_one_hot.shape[:-2], -1, n_res, -1))

    n, ca, c = [rc.atom_order[a] for a in ["N", "CA", "C"]]
    rigids = Rigid.make_transform_from_reference(
        n_xyz=batch["template_all_atom_positions"][..., n, :],
        ca_xyz=batch["template_all_atom_positions"][..., ca, :],
        c_xyz=batch["template_all_atom_positions"][..., c, :],
        eps=eps,
    )
    points = rigids.get_trans()[..., None, :, :]
    rigid_vec = rigids[..., None].invert_apply(points)

    inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))

    t_aa_masks = batch["template_all_atom_mask"]
    template_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

    inv_distance_scalar = inv_distance_scalar * template_mask_2d
    unit_vector = rigid_vec * inv_distance_scalar[..., None]

    if not use_unit_vector:
        unit_vector = unit_vector * 0.0

    to_concat.extend(torch.unbind(unit_vector[..., None, :], dim=-1))
    to_concat.append(template_mask_2d[..., None])

    act = torch.cat(to_concat, dim=-1)
    act = act * template_mask_2d[..., None]

    return act


def build_extra_msa_feat(batch: Dict[str, torch.Tensor]) -> torch.Tensor:
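    """Assemble the extra-MSA feature.

    Concatenates the 23-class one-hot extra MSA, the has-deletion flag, and the deletion
    value along the last dimension (25 channels).
    """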
    msa_1hot: torch.LongTensor = nn.functional.one_hot(batch["extra_msa"], 23)
    msa_feat = [
        msa_1hot,
        batch["extra_has_deletion"].unsqueeze(-1),
        batch["extra_deletion_value"].unsqueeze(-1),
    ]
    return torch.cat(msa_feat, dim=-1)


def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
    rrgdf: torch.Tensor,
) -> Rigid:
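    """Convert torsion angles into global rigid frames for the 8 frame groups per residue.

    r holds the [*, N] backbone frames, alpha the [*, N, 7, 2] sin/cos torsion angles, and
    rrgdf the default (idealized) 4x4 frames per residue type. Returns [*, N, 8] rigid
    transformations mapping each group (backbone, pre-omega, phi, psi, chi1-4) to global
    coordinates.
    """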
    # [*, N, 8, 4, 4]
    default_4x4 = rrgdf[aatype, ...]

    # [*, N, 8] transformations, i.e.
    #   One [*, N, 8, 3, 3] rotation matrix and
    #   One [*, N, 8, 3] translation vector per frame
    default_r = r.from_tensor_4x4(default_4x4)

    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2)

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0  , 0  ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.
    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    all_frames = default_r.compose(Rigid(Rotation(rot_mats=all_rots), None))

    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global


def frames_and_literature_positions_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
    default_frames: torch.Tensor,
    group_idx: torch.Tensor,
    atom_mask: torch.Tensor,
    lit_positions: torch.Tensor,
) -> torch.Tensor:
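    """Place idealized (literature) atom positions into global coordinates.

    Each of the 14 atoms of a residue is assigned to one of the 8 rigid groups via group_idx,
    transformed by the corresponding global frame in r, and masked with atom_mask so that
    atoms absent from the residue are zeroed. Returns [*, N, 14, 3] positions.
    """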
    # [*, N, 14]
    group_mask = group_idx[aatype, ...]

    # [*, N, 14, 8]
    group_mask_one_hot: torch.LongTensor = nn.functional.one_hot(
        group_mask,
        num_classes=default_frames.shape[-3],
    )

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask_one_hot

    # [*, N, 14]
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(lambda x: torch.sum(x, dim=-1))

    # [*, N, 14, 1]
    atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

    # [*, N, 14, 3]
    lit_positions = lit_positions[aatype, ...]
    pred_positions = t_atoms_to_global.apply(lit_positions)
    pred_positions = pred_positions * atom_mask

    return pred_positions