Skip to content

Commit 8a8f940

Browse files
small changes
1 parent d4339ce commit 8a8f940

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

keras_nlp/samplers/greedy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
"""Greedy Sampler."""
1515

1616
import tensorflow as tf
17+
from tensorflow import keras
1718

1819
from keras_nlp.samplers.sampler import Sampler
1920
from keras_nlp.samplers.sampler import base_sampler_keyword_args
2021
from keras_nlp.samplers.sampler import call_keyword_docstring
2122
from keras_nlp.samplers.sampler import sample_keyword_docstring
2223

2324

25+
@keras.utils.register_keras_serializable(package="keras_nlp")
2426
class Greedy(Sampler):
2527
"""Greedy sampler class.
2628
@@ -39,7 +41,6 @@ class Greedy(Sampler):
3941
VOCAB_SIZE = 10
4042
FEATURE_SIZE = 16
4143
START_ID = 1
42-
END_ID = 2
4344
4445
# Create a dummy model to predict the next token.
4546
model = keras.Sequential(
@@ -62,7 +63,7 @@ def token_probability_fn(inputs, mask):
6263
6364
sampler = keras_nlp.samplers.Greedy()
6465
# Print the generated sequence (token ids).
65-
print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID))
66+
print(sampler(token_probability_fn, prompt, 10))
6667
```
6768
"""
6869

keras_nlp/samplers/sampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
"""Base sampler class."""
1515

1616
import tensorflow as tf
17+
from tensorflow import keras
1718

1819

20+
@keras.utils.register_keras_serializable(package="keras_nlp")
1921
class Sampler:
2022
"""Base sampler class.
2123

0 commit comments

Comments
 (0)