- # coding=utf-8
- # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import gc
- import json
- import os
- import re
- import warnings
- from functools import partial
- from pickle import UnpicklingError
- from typing import Any, Dict, Optional, Set, Tuple, Union
- import flax.linen as nn
- import jax
- import jax.numpy as jnp
- import msgpack.exceptions
- from flax.core.frozen_dict import FrozenDict, unfreeze
- from flax.serialization import from_bytes, to_bytes
- from flax.traverse_util import flatten_dict, unflatten_dict
- from jax.random import PRNGKey
- from .configuration_utils import PretrainedConfig
- from .dynamic_module_utils import custom_object_save
- from .generation import FlaxGenerationMixin, GenerationConfig
- from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
- from .utils import (
- FLAX_WEIGHTS_INDEX_NAME,
- FLAX_WEIGHTS_NAME,
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- WEIGHTS_INDEX_NAME,
- WEIGHTS_NAME,
- PushToHubMixin,
- add_code_sample_docstrings,
- add_start_docstrings_to_model_forward,
- cached_file,
- copy_func,
- download_url,
- has_file,
- is_offline_mode,
- is_remote_url,
- logging,
- replace_return_docstrings,
- )
- from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
- from .utils.import_utils import is_safetensors_available
- if is_safetensors_available():
- from safetensors import safe_open
- from safetensors.flax import load_file as safe_load_file
- from safetensors.flax import save_file as safe_save_file
- logger = logging.get_logger(__name__)
- def quick_gelu(x):
- return x * jax.nn.sigmoid(1.702 * x)
- ACT2FN = {
- "gelu": partial(nn.gelu, approximate=False),
- "relu": nn.relu,
- "silu": nn.swish,
- "swish": nn.swish,
- "gelu_new": partial(nn.gelu, approximate=True),
- "quick_gelu": quick_gelu,
- "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
- }
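- # A minimal usage sketch (assuming a config object exposing a `hidden_act` attribute,
- # as most Transformers configs do): a module resolves its activation once, e.g.
- # `self.act_fn = ACT2FN[config.hidden_act]`, and applies it as `self.act_fn(hidden_states)`.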
- def dtype_byte_size(dtype):
- """
- Returns the size (in bytes) occupied by one parameter of type `dtype`.
- Example:
- ```py
- >>> dtype_byte_size(np.float32)
- 4
- ```
- """
- if dtype == bool:  # use "==", not "is": np.dtype("bool") compares equal to bool but is a distinct object
- return 1 / 8
- bit_search = re.search(r"[^\d](\d+)$", dtype.name)
- if bit_search is None:
- raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
- bit_size = int(bit_search.groups()[0])
- return bit_size // 8
- def flax_shard_checkpoint(params, max_shard_size="10GB"):
- """
- Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
- given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
- there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
- example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
- [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
- <Tip warning={true}>
- If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
- have a size greater than `max_shard_size`.
- </Tip>
- Args:
- params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
- max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
- The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
- (like `"5MB"`).
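- Example (a minimal sketch, assuming `model` is an already-instantiated Flax model):
- ```py
- >>> shards, index = flax_shard_checkpoint(model.params, max_shard_size="5GB")
- >>> list(shards.keys())  # one or more msgpack shard file names
- >>> index is None  # True when the params fit in a single shard
- ```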
- """
- max_shard_size = convert_file_size_to_int(max_shard_size)
- sharded_state_dicts = []
- current_block = {}
- current_block_size = 0
- total_size = 0
- # flatten the weights to chunk
- weights = flatten_dict(params, sep="/")
- for item in weights:
- weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
- # If this weight would tip the current block over the maximum size, we split.
- if current_block_size + weight_size > max_shard_size:
- sharded_state_dicts.append(current_block)
- current_block = {}
- current_block_size = 0
- current_block[item] = weights[item]
- current_block_size += weight_size
- total_size += weight_size
- # Add the last block
- sharded_state_dicts.append(current_block)
- # If we only have one shard, we return it
- if len(sharded_state_dicts) == 1:
- return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
- # Otherwise, let's build the index
- weight_map = {}
- shards = {}
- for idx, shard in enumerate(sharded_state_dicts):
- shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
- shards[shard_file] = shard
- for weight_name in shard.keys():
- weight_map[weight_name] = shard_file
- # Add the metadata
- metadata = {"total_size": total_size}
- index = {"metadata": metadata, "weight_map": weight_map}
- return shards, index
- class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
- r"""
- Base class for all models.
- [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
- downloading and saving models.
- Class attributes (overridden by derived classes):
- - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
- for this model architecture.
- - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
- classes of the same architecture adding modules on top of the base model.
- - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
- models, `pixel_values` for vision models and `input_values` for speech models).
- """
- config_class = None
- base_model_prefix = ""
- main_input_name = "input_ids"
- _auto_class = None
- _missing_keys = set()
- def __init__(
- self,
- config: PretrainedConfig,
- module: nn.Module,
- input_shape: Tuple = (1, 1),
- seed: int = 0,
- dtype: jnp.dtype = jnp.float32,
- _do_init: bool = True,
- ):
- if config is None:
- raise ValueError("config cannot be None")
- if module is None:
- raise ValueError("module cannot be None")
- # Those are private to be exposed as typed property on derived classes.
- self._config = config
- self._module = module
- # Those are public as their type is generic to every derived classes.
- self.key = PRNGKey(seed)
- self.dtype = dtype
- self.input_shape = input_shape
- self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
- # To check if the model was initialized automatically.
- self._is_initialized = _do_init
- if _do_init:
- # randomly initialized parameters
- random_params = self.init_weights(self.key, input_shape)
- params_shape_tree = jax.eval_shape(lambda params: params, random_params)
- else:
- init_fn = partial(self.init_weights, input_shape=input_shape)
- params_shape_tree = jax.eval_shape(init_fn, self.key)
- logger.info(
- "Model weights are not initialized as `_do_init` is set to `False`. "
- f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
- )
- # get the shape of the parameters
- self._params_shape_tree = params_shape_tree
- # save required_params as set
- self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
- # initialize the parameters
- if _do_init:
- self.params = random_params
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
- raise NotImplementedError(f"init method has to be implemented for {self}")
- def enable_gradient_checkpointing(self):
- raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
- @classmethod
- def _from_config(cls, config, **kwargs):
- """
- All context managers that the model should be initialized under go here.
- """
- return cls(config, **kwargs)
- @property
- def framework(self) -> str:
- """
- :str: Identifies that this is a Flax model.
- """
- return "flax"
- @property
- def config(self) -> PretrainedConfig:
- return self._config
- @property
- def module(self) -> nn.Module:
- return self._module
- @property
- def params(self) -> Union[Dict, FrozenDict]:
- if not self._is_initialized:
- raise ValueError(
- "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
- "You must call `init_weights` manually and store the params outside of the model and "
- "pass it explicitly where needed."
- )
- return self._params
- @property
- def required_params(self) -> Set:
- return self._required_params
- @property
- def params_shape_tree(self) -> Dict:
- return self._params_shape_tree
- @params.setter
- def params(self, params: Union[Dict, FrozenDict]):
- # don't set params if the model is not initialized
- if not self._is_initialized:
- raise ValueError(
- "`params` cannot be set from model when the model is created with `_do_init=False`. "
- "You store the params outside of the model."
- )
- if isinstance(params, FrozenDict):
- params = unfreeze(params)
- param_keys = set(flatten_dict(params).keys())
- if len(self.required_params - param_keys) > 0:
- raise ValueError(
- "Some parameters are missing. Make sure that `params` include the following "
- f"parameters {self.required_params - param_keys}"
- )
- self._params = params
- def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
- """
- Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
- """
- # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
- def conditional_cast(param):
- if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
- param = param.astype(dtype)
- return param
- if mask is None:
- return jax.tree_util.tree_map(conditional_cast, params)
- flat_params = flatten_dict(params)
- flat_mask, _ = jax.tree_util.tree_flatten(mask)
- for masked, key in zip(flat_mask, sorted(flat_params.keys())):
- if masked:
- flat_params[key] = conditional_cast(flat_params[key])
- return unflatten_dict(flat_params)
- def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
- r"""
- Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
- the `params` in place.
- This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
- half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
- Arguments:
- params (`Union[Dict, FrozenDict]`):
- A `PyTree` of model parameters.
- mask (`Union[Dict, FrozenDict]`):
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
- you want to cast, and should be `False` for those you want to skip.
- Examples:
- ```python
- >>> from transformers import FlaxBertModel
- >>> # load model
- >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
- >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
- >>> model.params = model.to_bf16(model.params)
- >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
- >>> # then pass the mask as follows
- >>> from flax import traverse_util
- >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
- >>> flat_params = traverse_util.flatten_dict(model.params)
- >>> mask = {
- ... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
- ... for path in flat_params
- ... }
- >>> mask = traverse_util.unflatten_dict(mask)
- >>> model.params = model.to_bf16(model.params, mask)
- ```"""
- return self._cast_floating_to(params, jnp.bfloat16, mask)
- def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
- r"""
- Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
- model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
- Arguments:
- params (`Union[Dict, FrozenDict]`):
- A `PyTree` of model parameters.
- mask (`Union[Dict, FrozenDict]`):
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
- you want to cast, and should be `False` for those you want to skip
- Examples:
- ```python
- >>> from transformers import FlaxBertModel
- >>> # Download model and configuration from huggingface.co
- >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
- >>> # By default, the model params will be in fp32, to illustrate the use of this method,
- >>> # we'll first cast to fp16 and back to fp32
- >>> model.params = model.to_fp16(model.params)
- >>> # now cast back to fp32
- >>> model.params = model.to_fp32(model.params)
- ```"""
- return self._cast_floating_to(params, jnp.float32, mask)
- def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
- r"""
- Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
- `params` in place.
- This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
- half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
- Arguments:
- params (`Union[Dict, FrozenDict]`):
- A `PyTree` of model parameters.
- mask (`Union[Dict, FrozenDict]`):
- A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
- you want to cast, and should be `False` for those you want to skip
- Examples:
- ```python
- >>> from transformers import FlaxBertModel
- >>> # load model
- >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
- >>> # By default, the model params will be in fp32, to cast these to float16
- >>> model.params = model.to_fp16(model.params)
- >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
- >>> # then pass the mask as follows
- >>> from flax import traverse_util
- >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
- >>> flat_params = traverse_util.flatten_dict(model.params)
- >>> mask = {
- ... path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
- ... for path in flat_params
- ... }
- >>> mask = traverse_util.unflatten_dict(mask)
- >>> model.params = model.to_fp16(model.params, mask)
- ```"""
- return self._cast_floating_to(params, jnp.float16, mask)
- @classmethod
- def load_flax_weights(cls, resolved_archive_file):
- try:
- if resolved_archive_file.endswith(".safetensors"):
- state = safe_load_file(resolved_archive_file)
- state = unflatten_dict(state, sep=".")
- else:
- with open(resolved_archive_file, "rb") as state_f:
- state = from_bytes(cls, state_f.read())
- except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
- try:
- with open(resolved_archive_file) as f:
- if f.read().startswith("version"):
- raise OSError(
- "You seem to have cloned a repository without having git-lfs installed. Please"
- " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
- " folder you cloned."
- )
- else:
- raise ValueError from e
- except (UnicodeDecodeError, ValueError):
- raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")
- return state
- @classmethod
- def load_flax_sharded_weights(cls, shard_files):
- """
- This is the same as [`flax.serialization.from_bytes`]
- (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
- This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
- loaded in the model.
- Args:
- shard_files (`List[str]`):
- The list of shard files to load.
- Returns:
- `Dict`: A nested dictionary of the model parameters, in the expected format for flax models: `{'model':
- {'params': {'...'}}}`.
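- Example (a sketch; `shard_files` is assumed to be the list returned by
- `get_checkpoint_shard_files` for a sharded Flax checkpoint):
- ```py
- >>> from transformers import FlaxBertModel
- >>> state = FlaxBertModel.load_flax_sharded_weights(shard_files)
- >>> list(state.keys())  # top-level parameter groups
- ```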
- """
- # Load the index
- state_sharded_dict = {}
- for shard_file in shard_files:
- # load using msgpack utils
- try:
- with open(shard_file, "rb") as state_f:
- state = from_bytes(cls, state_f.read())
- except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
- with open(shard_file) as f:
- if f.read().startswith("version"):
- raise OSError(
- "You seem to have cloned a repository without having git-lfs installed. Please"
- " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
- " folder you cloned."
- )
- else:
- raise ValueError from e
- except (UnicodeDecodeError, ValueError):
- raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")
- state = flatten_dict(state, sep="/")
- state_sharded_dict.update(state)
- del state
- gc.collect()
- # the state dict is unflattened to match the format of model.params
- return unflatten_dict(state_sharded_dict, sep="/")
- @classmethod
- def can_generate(cls) -> bool:
- """
- Returns whether this model can generate sequences with `.generate()`.
- Returns:
- `bool`: Whether this model can generate sequences with `.generate()`.
- """
- # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
- # Alternatively, the model can also have a custom `generate` function.
- if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
- return False
- return True
- @classmethod
- def from_pretrained(
- cls,
- pretrained_model_name_or_path: Union[str, os.PathLike],
- dtype: jnp.dtype = jnp.float32,
- *model_args,
- config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
- cache_dir: Optional[Union[str, os.PathLike]] = None,
- ignore_mismatched_sizes: bool = False,
- force_download: bool = False,
- local_files_only: bool = False,
- token: Optional[Union[str, bool]] = None,
- revision: str = "main",
- **kwargs,
- ):
- r"""
- Instantiate a pretrained flax model from a pre-trained model configuration.
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
- task.
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
- weights are discarded.
- Parameters:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- Can be either:
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
- - A path to a *directory* containing model weights saved using
- [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- - A path or url to a *PyTorch index checkpoint file* (e.g., `./pt_model/pytorch_model.bin.index.json`). In this case,
- `from_pt` should be set to `True`.
- dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
- The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
- `jax.numpy.bfloat16` (on TPUs).
- This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
- specified all the computation will be performed with the given `dtype`.
- **Note that this only specifies the dtype of the computation and does not influence the dtype of model
- parameters.**
- If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
- [`~FlaxPreTrainedModel.to_bf16`].
- model_args (sequence of positional arguments, *optional*):
- All remaining positional arguments will be passed to the underlying model's `__init__` method.
- config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
- Can be either:
- - an instance of a class derived from [`PretrainedConfig`],
- - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
- Configuration for the model to use instead of an automatically loaded configuration. Configuration can
- be automatically loaded when:
- - The model is a model provided by the library (loaded with the *model id* string of a pretrained
- model).
- - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
- save directory.
- - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
- configuration JSON file named *config.json* is found in the directory.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
- standard cache should not be used.
- from_pt (`bool`, *optional*, defaults to `False`):
- Load the model weights from a PyTorch checkpoint save file (see docstring of
- `pretrained_model_name_or_path` argument).
- ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
- Whether or not to ignore checkpoint weights whose size does not match the size of the corresponding
- model weights (for instance, when instantiating a model with 10 labels from a checkpoint with 3
- labels). When `False`, such a mismatch raises an error; when `True`, the mismatched weights are left
- newly initialized.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible.
- Will be removed in v5 of Transformers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only(`bool`, *optional*, defaults to `False`):
- Whether or not to only look at local files (i.e., do not try to download the model).
- token (`str` or `bool`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- <Tip>
- To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
- </Tip>
- subfolder (`str`, *optional*, defaults to `""`):
- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
- specify the folder name here.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g.,
- `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
- automatically loaded:
- - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
- underlying model's `__init__` method (we assume all relevant updates to the configuration have
- already been done)
- - If a configuration is not provided, `kwargs` will be first passed to the configuration class
- initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
- corresponds to a configuration attribute will be used to override said attribute with the
- supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
- will be passed to the underlying model's `__init__` function.
- Examples:
- ```python
- >>> from transformers import BertConfig, FlaxBertModel
- >>> # Download model and configuration from huggingface.co and cache.
- >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
- >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
- >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
- >>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower, for example purposes, not runnable).
- >>> config = BertConfig.from_json_file("./pt_model/config.json")
- >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
- ```"""
- from_pt = kwargs.pop("from_pt", False)
- resume_download = kwargs.pop("resume_download", None)
- proxies = kwargs.pop("proxies", None)
- use_auth_token = kwargs.pop("use_auth_token", None)
- trust_remote_code = kwargs.pop("trust_remote_code", None)
- from_pipeline = kwargs.pop("_from_pipeline", None)
- from_auto_class = kwargs.pop("_from_auto", False)
- _do_init = kwargs.pop("_do_init", True)
- subfolder = kwargs.pop("subfolder", "")
- commit_hash = kwargs.pop("_commit_hash", None)
- # Not relevant for Flax Models
- _ = kwargs.pop("adapter_kwargs", None)
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if token is not None:
- raise ValueError(
- "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
- )
- token = use_auth_token
- if trust_remote_code is True:
- logger.warning(
- "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
- " ignored."
- )
- user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
- if from_pipeline is not None:
- user_agent["using_pipeline"] = from_pipeline
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
- # Load config if we don't provide a configuration
- if not isinstance(config, PretrainedConfig):
- config_path = config if config is not None else pretrained_model_name_or_path
- config, model_kwargs = cls.config_class.from_pretrained(
- config_path,
- cache_dir=cache_dir,
- return_unused_kwargs=True,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- _from_auto=from_auto_class,
- _from_pipeline=from_pipeline,
- _commit_hash=commit_hash,
- **kwargs,
- )
- else:
- model_kwargs = kwargs.copy()
- if commit_hash is None:
- commit_hash = getattr(config, "_commit_hash", None)
- # Add the dtype to model_kwargs
- model_kwargs["dtype"] = dtype
- # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
- # index of the files.
- is_sharded = False
- # Load model
- if pretrained_model_name_or_path is not None:
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- is_local = os.path.isdir(pretrained_model_name_or_path)
- if is_local:
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
- # Load from a Flax checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
- elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
- # Load from a sharded Flax checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
- is_sharded = True
- elif is_safetensors_available() and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
- ):
- # Load from a safetensors checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
- elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
- # Load from a PyTorch checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
- elif from_pt and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
- ):
- # Load from a sharded pytorch checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
- is_sharded = True
- # At this stage we have not found a usable weight file, so every remaining branch raises an error.
- elif is_safetensors_available() and os.path.isfile(
- os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
- ):
- # Load from a sharded safetensors checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
- is_sharded = True
- raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
- elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
- raise EnvironmentError(
- f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
- "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
- "weights."
- )
- else:
- raise EnvironmentError(
- f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
- f"{pretrained_model_name_or_path}."
- )
- elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
- archive_file = pretrained_model_name_or_path
- is_local = True
- elif is_remote_url(pretrained_model_name_or_path):
- filename = pretrained_model_name_or_path
- resolved_archive_file = download_url(pretrained_model_name_or_path)
- else:
- if from_pt:
- filename = WEIGHTS_NAME
- else:
- filename = FLAX_WEIGHTS_NAME
- try:
- # Load from URL or cache if already cached
- cached_file_kwargs = {
- "cache_dir": cache_dir,
- "force_download": force_download,
- "proxies": proxies,
- "resume_download": resume_download,
- "local_files_only": local_files_only,
- "token": token,
- "user_agent": user_agent,
- "revision": revision,
- "subfolder": subfolder,
- "_raise_exceptions_for_gated_repo": False,
- "_raise_exceptions_for_missing_entries": False,
- "_commit_hash": commit_hash,
- }
- resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
- # Maybe the checkpoint is sharded, we try to grab the index name in this case.
- if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
- resolved_archive_file = cached_file(
- pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
- )
- if resolved_archive_file is not None:
- is_sharded = True
- # Maybe the checkpoint is pytorch sharded, we try to grab the pytorch index name in this case.
- if resolved_archive_file is None and from_pt:
- resolved_archive_file = cached_file(
- pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
- )
- if resolved_archive_file is not None:
- is_sharded = True
- # If we still haven't found anything, look for `safetensors`.
- if resolved_archive_file is None:
- # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
- filename = SAFE_WEIGHTS_NAME
- resolved_archive_file = cached_file(
- pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
- )
- # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
- # result when internet is up, the repo and revision exist, but the file does not.
- if resolved_archive_file is None:
- # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
- # message.
- has_file_kwargs = {
- "revision": revision,
- "proxies": proxies,
- "token": token,
- "cache_dir": cache_dir,
- "local_files_only": local_files_only,
- }
- if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
- is_sharded = True
- raise NotImplementedError(
- "Support for sharded checkpoints using safetensors is coming soon!"
- )
- elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named"
- f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
- " load this model from those weights."
- )
- elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named"
- f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
- " `from_pt=True` to load this model from those weights."
- )
- else:
- raise EnvironmentError(
- f"{pretrained_model_name_or_path} does not appear to have a file named"
- f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
- )
- except EnvironmentError:
- # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
- # to the original exception.
- raise
- except Exception:
- # For any other exception, we throw a generic error.
- raise EnvironmentError(
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
- " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
- f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
- f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
- )
- if is_local:
- logger.info(f"loading weights file {archive_file}")
- resolved_archive_file = archive_file
- filename = resolved_archive_file.split(os.path.sep)[-1]
- else:
- logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
- else:
- resolved_archive_file = None
- # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
- if is_sharded:
- # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
- resolved_archive_file, _ = get_checkpoint_shard_files(
- pretrained_model_name_or_path,
- resolved_archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- resume_download=resume_download,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- _commit_hash=commit_hash,
- )
- safetensors_from_pt = False
- if filename == SAFE_WEIGHTS_NAME:
- with safe_open(resolved_archive_file, framework="flax") as f:
- safetensors_metadata = f.metadata()
- if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
- raise OSError(
- f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
- " Make sure you save your model with the `save_pretrained` method."
- )
- safetensors_from_pt = safetensors_metadata.get("format") == "pt"
- # instantiate the model (randomly initialized when `_do_init=True`)
- model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
- if from_pt or safetensors_from_pt:
- state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
- else:
- if is_sharded:
- state = cls.load_flax_sharded_weights(resolved_archive_file)
- else:
- state = cls.load_flax_weights(resolved_archive_file)
- # make sure all arrays are stored as jnp.arrays
- # NOTE: this works around a bug that will be fixed in Flax >= v0.3.4:
- # https://github.com/google/flax/issues/1261
- if _do_init:
- state = jax.tree_util.tree_map(jnp.array, state)
- else:
- # keep the params on CPU if we don't want to initialize
- state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
- if "batch_stats" in state: # if flax model contains batch norm layers
- # if model is base model only use model_prefix key
- if (
- cls.base_model_prefix not in dict(model.params_shape_tree["params"])
- and cls.base_model_prefix in state["params"]
- ):
- state["params"] = state["params"][cls.base_model_prefix]
- state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
- # if model is head model and we are loading weights from base model
- # we initialize new params dict with base_model_prefix
- if (
- cls.base_model_prefix in dict(model.params_shape_tree["params"])
- and cls.base_model_prefix not in state["params"]
- ):
- state = {
- "params": {cls.base_model_prefix: state["params"]},
- "batch_stats": {cls.base_model_prefix: state["batch_stats"]},
- }
- else:
- # if model is base model only use model_prefix key
- if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
- state = state[cls.base_model_prefix]
- # if model is head model and we are loading weights from base model
- # we initialize new params dict with base_model_prefix
- if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
- state = {cls.base_model_prefix: state}
- # flatten dicts
- state = flatten_dict(state)
- random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
- missing_keys = model.required_params - set(state.keys())
- unexpected_keys = set(state.keys()) - model.required_params
- # Disable the warning when porting PyTorch weights to Flax; Flax does not use num_batches_tracked
- for unexpected_key in unexpected_keys.copy():
- if "num_batches_tracked" in unexpected_key[-1]:
- unexpected_keys.remove(unexpected_key)
- if missing_keys and not _do_init:
- logger.warning(
- f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
- "Make sure to call model.init_weights to initialize the missing weights."
- )
- cls._missing_keys = missing_keys
- # Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
- # matching the weights in the model.
- mismatched_keys = []
- for key in state.keys():
- if key in random_state and state[key].shape != random_state[key].shape:
- if ignore_mismatched_sizes:
- mismatched_keys.append((key, state[key].shape, random_state[key].shape))
- state[key] = random_state[key]
- else:
- raise ValueError(
- f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
- f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
- "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
- "model."
- )
- # add missing keys as random parameters if we are initializing
- if missing_keys and _do_init:
- for missing_key in missing_keys:
- state[missing_key] = random_state[missing_key]
- # remove unexpected keys so they are not saved again
- for unexpected_key in unexpected_keys:
- del state[unexpected_key]
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
- " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
- " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
- )
- else:
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
- if len(missing_keys) > 0:
- logger.warning(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
- )
- elif len(mismatched_keys) == 0:
- logger.info(
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
- f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
- " training."
- )
- if len(mismatched_keys) > 0:
- mismatched_warning = "\n".join(
- [
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
- for key, shape1, shape2 in mismatched_keys
- ]
- )
- logger.warning(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
- " to use it for predictions and inference."
- )
- # dictionary of key: dtypes for the model params
- param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
- # extract keys of parameters not in jnp.float32
- fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
- bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
- # raise a warning if any of the parameters are not in jnp.float32
- if len(fp16_params) > 0:
- logger.warning(
- f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
- f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
- "You should probably UPCAST the model weights to float32 if this was not intended. "
- "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
- )
- if len(bf16_params) > 0:
- logger.warning(
- f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
- f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
- "You should probably UPCAST the model weights to float32 if this was not intended. "
- "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
- )
- # If it is a model with generation capabilities, attempt to load the generation config
- if model.can_generate():
- try:
- model.generation_config = GenerationConfig.from_pretrained(
- pretrained_model_name_or_path,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- revision=revision,
- subfolder=subfolder,
- _from_auto=from_auto_class,
- _from_pipeline=from_pipeline,
- **kwargs,
- )
- except OSError:
- logger.info(
- "Generation config file not found, using a generation config created from the model config."
- )
- pass
- if _do_init:
- # set correct parameters
- model.params = unflatten_dict(state)
- return model
- else:
- return model, unflatten_dict(state)
- def save_pretrained(
- self,
- save_directory: Union[str, os.PathLike],
- params=None,
- push_to_hub=False,
- max_shard_size="10GB",
- token: Optional[Union[str, bool]] = None,
- safe_serialization: bool = False,
- **kwargs,
- ):
- """
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
- [`~FlaxPreTrainedModel.from_pretrained`] class method.
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to which to save. Will be created if it doesn't exist.
- push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
- repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
- namespace).
- max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
- The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
- lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
- <Tip warning={true}>
- If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
- which will be bigger than `max_shard_size`.
- </Tip>
- token (`str` or `bool`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
- kwargs (`Dict[str, Any]`, *optional*):
- Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
- safe_serialization (`bool`, *optional*, defaults to `False`):
- Whether to save the model using `safetensors` or through msgpack.
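- Example (a minimal sketch, assuming `model` is a loaded Flax model; the repo id is hypothetical):
- ```python
- >>> model.save_pretrained("./my_model")  # writes config.json and flax_model.msgpack
- >>> model.save_pretrained("./my_model", max_shard_size="5GB")  # shards checkpoints above 5GB
- >>> model.save_pretrained("./my_model", push_to_hub=True, repo_id="my-username/my-model")
- ```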
- """
- use_auth_token = kwargs.pop("use_auth_token", None)
- if use_auth_token is not None:
- warnings.warn(
- "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
- FutureWarning,
- )
- if token is not None:
- raise ValueError(
- "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
- )
- token = use_auth_token
- if token is not None:
- kwargs["token"] = token
- if os.path.isfile(save_directory):
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
- return
- os.makedirs(save_directory, exist_ok=True)
- if push_to_hub:
- commit_message = kwargs.pop("commit_message", None)
- repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
- repo_id = self._create_repo(repo_id, **kwargs)
- files_timestamps = self._get_files_timestamps(save_directory)
- # get abs dir
- save_directory = os.path.abspath(save_directory)
- # save config as well
- self.config.architectures = [self.__class__.__name__[4:]]  # strip the "Flax" prefix
- # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
- # loaded from the Hub.
- if self._auto_class is not None:
- custom_object_save(self, save_directory, config=self.config)
- self.config.save_pretrained(save_directory)
- if self.can_generate():
- self.generation_config.save_pretrained(save_directory)
- # save model
- weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
- output_model_file = os.path.join(save_directory, weights_name)
- shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
- # Clean the folder from a previous save
- for filename in os.listdir(save_directory):
- full_filename = os.path.join(save_directory, filename)
- weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
- if (
- filename.startswith(weights_no_suffix)
- and os.path.isfile(full_filename)
- and filename not in shards.keys()
- ):
- os.remove(full_filename)
- if index is None:
- if safe_serialization:
- params = params if params is not None else self.params
- flat_dict = flatten_dict(params, sep=".")
- safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
- else:
- with open(output_model_file, "wb") as f:
- params = params if params is not None else self.params
- model_bytes = to_bytes(params)
- f.write(model_bytes)
- else:
- save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
- # Save the index as well
- with open(save_index_file, "w", encoding="utf-8") as f:
- content = json.dumps(index, indent=2, sort_keys=True) + "\n"
- f.write(content)
- logger.info(
- f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
- f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
- f"index located at {save_index_file}."
- )
- for shard_file, shard in shards.items():
- # the shard items are stored flattened; to save them we need to unflatten them again
- with open(os.path.join(save_directory, shard_file), mode="wb") as f:
- params = unflatten_dict(shard, sep="/")
- shard_bytes = to_bytes(params)
- f.write(shard_bytes)
- logger.info(f"Model weights saved in {output_model_file}")
- if push_to_hub:
- self._upload_modified_files(
- save_directory,
- repo_id,
- files_timestamps,
- commit_message=commit_message,
- token=token,
- )
- @classmethod
- def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
- """
- Register this class with a given auto class. This should only be used for custom models as the ones in the
- library are already mapped with an auto class.
- <Tip warning={true}>
- This API is experimental and may have some slight breaking changes in the next releases.
- </Tip>
- Args:
- auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
- The auto class to register this new model with.
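- Example (a sketch; `FlaxMyModel` is a hypothetical custom model class):
- ```python
- >>> FlaxMyModel.register_for_auto_class("FlaxAutoModel")
- ```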
- """
- if not isinstance(auto_class, str):
- auto_class = auto_class.__name__
- import transformers.models.auto as auto_module
- if not hasattr(auto_module, auto_class):
- raise ValueError(f"{auto_class} is not a valid auto class.")
- cls._auto_class = auto_class
- # To update the docstring, we need to copy the method, otherwise we change the original docstring.
- FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
- if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
- FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
- object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
- )
- def overwrite_call_docstring(model_class, docstring):
- # copy __call__ function to be sure docstring is changed only for this function
- model_class.__call__ = copy_func(model_class.__call__)
- # delete existing docstring
- model_class.__call__.__doc__ = None
- # set correct docstring
- model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
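- # Hypothetical usage sketch: a model file would typically call
- # `overwrite_call_docstring(FlaxMyModel, MY_INPUTS_DOCSTRING)` so that
- # `FlaxMyModel.__call__` carries the model-specific input documentation.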
- def append_call_sample_docstring(
- model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
- ):
- model_class.__call__ = copy_func(model_class.__call__)
- model_class.__call__ = add_code_sample_docstrings(
- checkpoint=checkpoint,
- output_type=output_type,
- config_class=config_class,
- model_cls=model_class.__name__,
- revision=revision,
- real_checkpoint=real_checkpoint,
- )(model_class.__call__)
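- # Hypothetical usage sketch (checkpoint name and output/config classes are placeholders):
- # append_call_sample_docstring(FlaxMyModel, "my-org/my-checkpoint", FlaxBaseModelOutput, MyConfig)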
- def append_replace_return_docstrings(model_class, output_type, config_class):
- model_class.__call__ = copy_func(model_class.__call__)
- model_class.__call__ = replace_return_docstrings(
- output_type=output_type,
- config_class=config_class,
- )(model_class.__call__)