# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

from .generation.configuration_utils import GenerationConfig
from .training_args import TrainingArguments
from .utils import add_start_docstrings


logger = logging.getLogger(__name__)


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class Seq2SeqTrainingArguments(TrainingArguments):
- """
- Args:
        sortish_sampler (`bool`, *optional*, defaults to `False`):
            Whether to use a *sortish sampler* or not. Currently only possible if the underlying datasets are
            *Seq2SeqDataset*, but it will become generally available in the near future. It sorts the inputs
            according to their lengths in order to minimize the padding size, with a bit of randomness for the
            training set.
        predict_with_generate (`bool`, *optional*, defaults to `False`):
            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
        generation_max_length (`int`, *optional*):
            The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
            `max_length` value of the model configuration.
        generation_num_beams (`int`, *optional*):
            The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
            `num_beams` value of the model configuration.
        generation_config (`str` or `Path` or [`~generation.GenerationConfig`], *optional*):
            Allows loading a [`~generation.GenerationConfig`] from the `from_pretrained` method. This can be
            either:

            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
              huggingface.co.
            - a path to a *directory* containing a configuration file saved using the
              [`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
            - a [`~generation.GenerationConfig`] object.
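
    Example (an illustrative sketch; the `output_dir` value and the generation settings below are placeholders):

    ```python
    from transformers import Seq2SeqTrainingArguments

    args = Seq2SeqTrainingArguments(
        output_dir="./out",  # placeholder output directory
        predict_with_generate=True,  # decode with `generate` during evaluation
        generation_max_length=128,
        generation_num_beams=4,
    )
    ```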
- """
    sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
    predict_with_generate: bool = field(
        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
    )
    generation_max_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
                "to the `max_length` value of the model configuration."
            )
        },
    )
    generation_num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
                "to the `num_beams` value of the model configuration."
            )
        },
    )
    generation_config: Optional[Union[str, Path, GenerationConfig]] = field(
        default=None,
        metadata={
            "help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction."
        },
    )

    def to_dict(self):
        """
        Serializes this instance while replacing `Enum` members by their values and `GenerationConfig` by
        dictionaries (for JSON serialization support). It obfuscates the token values by removing them.
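
        Example (an illustrative sketch; `output_dir` and the beam count are placeholders):

        ```python
        args = Seq2SeqTrainingArguments(
            output_dir="./out",  # placeholder output directory
            generation_config=GenerationConfig(num_beams=4),
        )
        d = args.to_dict()
        # the `GenerationConfig` value is now a plain dict, e.g. d["generation_config"]["num_beams"] == 4
        ```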
- """
        # filter out fields that are defined as field(init=False)
        d = super().to_dict()
        for k, v in d.items():
            # `GenerationConfig` is not JSON-serializable as-is, so convert it to a plain dict
            if isinstance(v, GenerationConfig):
                d[k] = v.to_dict()
        return d