tf_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Optional, Union
  15. import numpy as np
  16. import tensorflow as tf
  17. from .feature_extraction_utils import BatchFeature
  18. from .tokenization_utils_base import BatchEncoding
  19. from .utils import logging
  20. logger = logging.get_logger(__name__)
  def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
      """
      Deal with dynamic shape in tensorflow cleanly.

      Args:
          tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.

      Returns:
          `List[int]`: The shape of the tensor as a list. Dimensions that are statically known are returned as
          Python ints; dimensions that are only known at run time are returned as scalar slices of
          `tf.shape(tensor)`. If the static shape is completely unknown (unknown rank), the dynamic shape tensor
          itself is returned instead of a list.
      """
      # NumPy arrays always have a fully defined shape.
      if isinstance(tensor, np.ndarray):
          return list(tensor.shape)

      dynamic = tf.shape(tensor)

      # Unknown rank: there is no static information to merge in.
      if tensor.shape == tf.TensorShape(None):
          return dynamic

      static = tensor.shape.as_list()

      # Prefer static dimensions and fall back to the dynamic value only where the static one is unknown.
      return [dynamic[i] if s is None else s for i, s in enumerate(static)]
  def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
      """
      Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is
      meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be
      removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and relies on the fact that
      `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).

      Args:
          logits (`tf.Tensor`):
              Must be one of the following types: half, float32, float64.
          axis (`int`, *optional*):
              The dimension softmax would be performed on. The default is -1 which indicates the last dimension.
          name (`str`, *optional*):
              A name for the operation.

      Returns:
          `tf.Tensor`:
              A Tensor. Has the same type and shape as logits.
      """
      # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
      # it has the fix. After we drop the support for unfixed versions, remove this function.
      # The constant offset is the workaround; it leaves the softmax result mathematically unchanged.
      return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
  def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
      """
      Simplified functional layer normalization, designed to duplicate the functionality of PyTorch's
      `nn.functional.layer_norm` when this is needed to port models in Transformers.

      Args:
          inputs: The tensor to normalize.
          weight: Rank-1 scale tensor applied after normalization.
          bias: Rank-1 offset tensor applied after normalization.
          epsilon (`float`, *optional*, defaults to 1e-5): Value added to the variance for numerical stability.
          axis (`int`, *optional*, defaults to -1): The single axis of `inputs` to normalize over.

      Returns:
          The normalized tensor, as returned by `tf.nn.batch_normalization`.

      Raises:
          NotImplementedError: If `weight` or `bias` is not rank 1, or if `axis` is not a single `int`.
      """
      if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
          raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")

      # Get mean and variance on the axis to be normalized
      mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)

      if axis != -1:
          # Reshape weight and bias to have the same rank as inputs, but with size-1 dimensions
          # on every dimension except axis, so that they broadcast along the normalized axis
          shape = [1] * inputs.shape.rank
          shape[axis] = shape_list(inputs)[axis]
          weight = tf.reshape(weight, shape)
          bias = tf.reshape(bias, shape)

      # Compute layer normalization using the batch_normalization
      # function.
      outputs = tf.nn.batch_normalization(
          inputs,
          mean,
          variance,
          offset=bias,
          scale=weight,
          variance_epsilon=epsilon,
      )
      return outputs
  def scaled_dot_product_attention(
      query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: Optional[float] = None
  ):
      """
      TF equivalent for torch's nn.functional.scaled_dot_product_attention

      Args:
          query: Query tensor; its last two axes are used as (query positions, feature dim).
          key: Key tensor; its last two axes are used as (key positions, feature dim).
          value: Value tensor that is matrix-multiplied by the attention probabilities.
          attn_mask (*optional*): Either an integer/boolean mask (non-zero / `True` entries are kept, the others
              receive a bias of -1000.0) or a tensor of another dtype that is added to the logits as-is.
          dropout_p (`float`, *optional*, defaults to 0.0): Must be 0.0, dropout is not implemented.
          is_causal (`bool`, *optional*, defaults to `False`): Whether to build a lower-triangular mask. Cannot be
              combined with `attn_mask`.
          scale (`float`, *optional*): Factor applied to the logits. Defaults to `key_dim ** -0.5`.

      Returns:
          The attention output, `softmax(query @ key^T * scale + mask) @ value`.

      Raises:
          ValueError: If `dropout_p` is non-zero, or if both `attn_mask` and `is_causal` are given.
      """
      if dropout_p != 0.0:
          raise ValueError(
              "Dropout is not supported in this implementation - file an issue "
              "with Transformers and ping @Rocketknight1 if you need it for a port!"
          )
      if is_causal and attn_mask is not None:
          raise ValueError("You cannot specify an attn_mask and is_causal at the same time!")
      if is_causal:
          attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32)
          attn_mask = tf.experimental.numpy.tril(attn_mask, k=0)
      if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool):
          # Convert boolean mask to a negative logit bias. Boolean tensors are used directly because
          # the `>` comparison is only defined for numeric dtypes in TF.
          keep = attn_mask if attn_mask.dtype.is_bool else attn_mask > 0
          attn_mask = tf.where(keep, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype))
      logits = tf.einsum("...qd, ...kd -> ...qk", query, key)
      if scale is None:
          scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5
      logits *= scale  # scale by 1/sqrt(key_dim)
      if attn_mask is not None:
          logits += attn_mask
      probs = tf.nn.softmax(logits)
      return probs @ value
  def flatten(input, start_dim=0, end_dim=-1):
      """
      Replicates the behavior of `torch.flatten` in TF.

      Args:
          input: The tensor to flatten. Its static rank is used to resolve negative dimensions.
          start_dim (`int`, *optional*, defaults to 0): First dimension to flatten.
          end_dim (`int`, *optional*, defaults to -1): Last dimension to flatten (inclusive).

      Returns:
          A tensor in which dimensions `start_dim` through `end_dim` are merged into a single dimension. When
          `start_dim == end_dim` the input is returned unchanged. A rank-0 input is returned as a 1-element vector.

      Raises:
          ValueError: If `start_dim` comes after `end_dim`.
      """
      rank = input.shape.rank
      if rank == 0:
          # Like torch.flatten, a scalar becomes a one-element vector.
          return tf.reshape(input, [1])
      # If end_dim or start_dim is negative, count them from the end
      if end_dim < 0:
          end_dim += rank
      if start_dim < 0:
          start_dim += rank
      if start_dim > end_dim:
          # Without this check the slices below overlap and a spurious size-1 axis is silently inserted.
          raise ValueError("flatten() has invalid args: start_dim cannot come after end_dim")
      if start_dim == end_dim:
          return input
      in_shape = tf.shape(input)
      flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
      out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
      return tf.reshape(input, out_shape)
  def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
      """
      Invert an attention mask (e.g., switches 0. and 1.).

      Args:
          encoder_attention_mask (`tf.Tensor`): An attention mask of rank 2 or 3. Other inputs are first converted
              with `tf.convert_to_tensor`.

      Returns:
          `tf.Tensor`: The inverted attention mask, expanded to rank 4 and multiplied by the minimum value of its
          dtype.

      Raises:
          ValueError: If the mask is neither rank 2 nor rank 3.
      """
      if not isinstance(encoder_attention_mask, tf.Tensor):
          encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask)  # Catches stray NumPy inputs
      rank = encoder_attention_mask.shape.rank
      if rank == 3:
          encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
      elif rank == 2:
          encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
      else:
          # Fail with a clear message instead of an unbound-variable error further down.
          raise ValueError(f"encoder_attention_mask must have rank 2 or 3, but got rank {rank}.")
      # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
      # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
      # /transformer/transformer_layers.py#L270
      # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
      # encoder_extended_attention_mask.transpose(-1, -2))
      encoder_extended_attention_mask = (
          tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask
      ) * encoder_extended_attention_mask.dtype.min

      return encoder_extended_attention_mask
  def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_name: str = "input_ids") -> None:
      """
      `tf.gather`, on which TF embedding layers are based, won't check positive out of bound indices on GPU, returning
      zeros instead. This function adds a check against that dangerous silent behavior.

      Only the upper bound is checked: every value in `tensor` must be strictly smaller than `embed_dim`.

      Args:
          tensor (`tf.Tensor`): The tensor of indices to check.
          embed_dim (`int`): The input dimension of the embedding layer, i.e. the exclusive upper bound for indices.
          tensor_name (`str`, *optional*): The name of the tensor to use in the error message.
      """
      tf.debugging.assert_less(
          tensor,
          tf.cast(embed_dim, dtype=tensor.dtype),
          message=(
              f"The maximum value of {tensor_name} ({tf.math.reduce_max(tensor)}) must be smaller than the embedding "
              f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time."
          ),
      )
  def save_attributes_to_hdf5_group(group, name, data):
      """Saves attributes (data) of the specified name into the HDF5 group.

      This method deals with an inherent problem of HDF5 file which is not able to store data larger than
      HDF5_OBJECT_HEADER_LIMIT bytes. Data that fits is stored under `name`; larger data is split into chunks
      stored under `name0`, `name1`, ...

      Args:
          group: A pointer to a HDF5 group.
          name: A name of the attributes to save.
          data: Attributes data to store.

      Raises:
          RuntimeError: If any single attribute is too large to be saved.

      Copied from Keras to Transformers to avoid versioning issues.
      """
      HDF5_OBJECT_HEADER_LIMIT = 64512
      # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
      # because in that case even chunking the array would not make the saving
      # possible.
      bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]

      # Expecting this to never be true.
      if bad_attributes:
          raise RuntimeError(
              "The following attributes cannot be saved to HDF5 file because "
              f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
              f"bytes: {bad_attributes}"
          )

      data_npy = np.asarray(data)

      num_chunks = 1
      chunked_data = np.array_split(data_npy, num_chunks)

      while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
          # `len(x)` above counts items/characters while `nbytes` counts storage, so a single element can
          # still be too large (e.g. unicode strings use several bytes per character). Once every chunk holds
          # one element, further splitting cannot help: fail instead of looping forever.
          if num_chunks >= len(data_npy):
              raise RuntimeError(
                  "The following attributes cannot be saved to HDF5 file because "
                  f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
                  f"bytes: {[x for x in data_npy if x.nbytes > HDF5_OBJECT_HEADER_LIMIT]}"
              )
          num_chunks += 1
          chunked_data = np.array_split(data_npy, num_chunks)

      if num_chunks > 1:
          for chunk_id, chunk_data in enumerate(chunked_data):
              group.attrs[f"{name}{chunk_id}"] = chunk_data
      else:
          group.attrs[name] = data
  def load_attributes_from_hdf5_group(group, name):
      """Loads attributes of the specified name from the HDF5 group.

      This method deals with an inherent problem of HDF5 file which is not able to store data larger than
      HDF5_OBJECT_HEADER_LIMIT bytes. If `name` itself is not present, the chunked attributes `name0`, `name1`, ...
      are read in order and concatenated.

      Args:
          group: A pointer to a HDF5 group.
          name: A name of the attributes to load.

      Returns:
          data: Attributes data, as a list. Items that have a `decode` method are decoded as UTF-8. The list is
          empty when neither `name` nor any chunk is present.

      Copied from Keras to Transformers to avoid versioning issues.
      """

      def _decode_all(items):
          # Bytes-like entries are decoded as UTF-8; everything else is passed through.
          return [n.decode("utf8") if hasattr(n, "decode") else n for n in items]

      if name in group.attrs:
          return _decode_all(group.attrs[name])

      data = []
      chunk_id = 0
      while f"{name}{chunk_id}" in group.attrs:
          data.extend(_decode_all(group.attrs[f"{name}{chunk_id}"]))
          chunk_id += 1
      return data
  def expand_1d(data):
      """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.

      Every `tf.Tensor` of rank 1 found in the (possibly nested) structure `data` gets a trailing axis of size 1;
      all other leaves are returned unchanged.

      Copied from Keras to here to avoid versioning issues."""

      def _expand_single_1d_tensor(t):
          if isinstance(t, tf.Tensor) and t.shape.rank == 1:
              return tf.expand_dims(t, axis=-1)
          return t

      return tf.nest.map_structure(_expand_single_1d_tensor, data)
  def convert_batch_encoding(*args, **kwargs):
      """
      Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands.

      Only the first positional argument is inspected; if it is not a `BatchEncoding`/`BatchFeature`, the keyword
      argument `x` is inspected instead. At most one of the two is converted.

      Returns:
          A pair `(args, kwargs)`. `args` is a list when its first element was converted and the original tuple
          otherwise; `kwargs` is the (possibly updated) keyword dict.
      """
      if args and isinstance(args[0], (BatchEncoding, BatchFeature)):
          # Tuples are immutable, so switch to a list to replace the first element.
          args = list(args)
          args[0] = dict(args[0])
      elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)):
          kwargs["x"] = dict(kwargs["x"])
      return args, kwargs