Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
789f2f4
TF generate start refactor
patrickvonplaten Feb 8, 2022
6702ed3
Add tf tests for sample generate
patrickvonplaten Feb 8, 2022
cf6de09
re-organize
patrickvonplaten Feb 9, 2022
7164d93
boom boom
patrickvonplaten Feb 9, 2022
0844c83
Apply suggestions from code review
patrickvonplaten Feb 9, 2022
b5ae041
re-add
patrickvonplaten Feb 9, 2022
5c3f3f5
Merge branch 'tf_generate_refactor' of https://github.com/patrickvonp…
patrickvonplaten Feb 9, 2022
0671b89
add all code
patrickvonplaten Feb 9, 2022
c23bff2
make random greedy pass
patrickvonplaten Feb 10, 2022
7786d18
make encoder-decoder random work
patrickvonplaten Feb 10, 2022
d04530f
further improvements
patrickvonplaten Feb 10, 2022
73090dd
delete bogus file
patrickvonplaten Feb 10, 2022
3db93ff
make gpt2 and t5 tests work
patrickvonplaten Feb 11, 2022
7a7b7ef
finish logits tests
patrickvonplaten Feb 11, 2022
7b1b2cc
correct logits processors
patrickvonplaten Feb 14, 2022
a8cf81e
Merge branch 'master' of https://github.com/huggingface/transformers …
patrickvonplaten Feb 14, 2022
1a9e870
correct past / encoder_outputs drama
patrickvonplaten Feb 14, 2022
385c24f
refactor some methods
patrickvonplaten Feb 14, 2022
bd750ff
another fix
patrickvonplaten Feb 14, 2022
49e33b0
refactor shape_list
patrickvonplaten Feb 14, 2022
4b2460d
fix more shape list
patrickvonplaten Feb 14, 2022
ed5f2ff
import shape
patrickvonplaten Feb 14, 2022
dd1c214
finish docs
patrickvonplaten Feb 14, 2022
0c7d049
fix imports
patrickvonplaten Feb 14, 2022
726355e
make style
patrickvonplaten Feb 14, 2022
6293862
correct tf utils
patrickvonplaten Feb 14, 2022
b2934ee
Fix TFRag as well
patrickvonplaten Feb 14, 2022
4f6d927
Apply Lysandre's and Sylvain's suggestions
patrickvonplaten Feb 15, 2022
39c0b65
Update tests/test_generation_tf_logits_process.py
patrickvonplaten Feb 15, 2022
920a991
Update src/transformers/tf_utils.py
patrickvonplaten Feb 15, 2022
4b7d994
remove cpu according to gante
patrickvonplaten Feb 15, 2022
3fbe55b
correct logit processor
patrickvonplaten Feb 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/source/internal/generation_utils.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,24 @@ generation.
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__

[[autodoc]] TFLogitsProcessor
- __call__

[[autodoc]] TFLogitsProcessorList
- __call__

[[autodoc]] TFMinLengthLogitsProcessor
- __call__

[[autodoc]] TFNoBadWordsLogitsProcessor
- __call__

[[autodoc]] TFNoRepeatNGramLogitsProcessor
- __call__

[[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__

[[autodoc]] FlaxLogitsProcessor
- __call__

Expand Down
17 changes: 17 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,14 @@
_import_structure["activations_tf"] = []
_import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
_import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
_import_structure["generation_tf_logits_process"] = [
"TFLogitsProcessor",
"TFLogitsProcessorList",
"TFMinLengthLogitsProcessor",
"TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor",
]
_import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
_import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
_import_structure["modeling_tf_outputs"] = []
Expand Down Expand Up @@ -2046,6 +2054,7 @@
]
)
_import_structure["optimization_tf"] = ["AdamWeightDecay", "GradientAccumulator", "WarmUp", "create_optimizer"]
_import_structure["tf_utils"] = []
_import_structure["trainer_tf"] = ["TFTrainer"]

else:
Expand Down Expand Up @@ -3572,6 +3581,14 @@

# Benchmarks
from .benchmark.benchmark_tf import TensorFlowBenchmark
from .generation_tf_logits_process import (
TFLogitsProcessor,
TFLogitsProcessorList,
TFMinLengthLogitsProcessor,
TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor,
)
from .generation_tf_utils import tf_top_k_top_p_filtering
from .keras_callbacks import KerasMetricCallback, PushToHubCallback
from .modeling_tf_layoutlm import (
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/generation_flax_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.

import inspect
from abc import ABC

import jax
import jax.lax as lax
Expand Down Expand Up @@ -48,7 +47,7 @@
"""


class FlaxLogitsProcessor(ABC):
class FlaxLogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand All @@ -59,7 +58,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
)


class FlaxLogitsWarper(ABC):
class FlaxLogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/generation_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import inspect
import math
from abc import ABC
from typing import Callable, Iterable, List, Optional

import numpy as np
Expand Down Expand Up @@ -49,7 +48,7 @@
"""


class LogitsProcessor(ABC):
class LogitsProcessor:
"""Abstract base class for all logit processors that can be applied during generation."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand All @@ -60,7 +59,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
)


class LogitsWarper(ABC):
class LogitsWarper:
"""Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
Expand Down
Loading