File tree Expand file tree Collapse file tree 2 files changed +5
-2
lines changed Expand file tree Collapse file tree 2 files changed +5
-2
lines changed Original file line number Diff line number Diff line change 1414"""Greedy Sampler."""
1515
1616import tensorflow as tf
17+ from tensorflow import keras
1718
1819from keras_nlp .samplers .sampler import Sampler
1920from keras_nlp .samplers .sampler import base_sampler_keyword_args
2021from keras_nlp .samplers .sampler import call_keyword_docstring
2122from keras_nlp .samplers .sampler import sample_keyword_docstring
2223
2324
25+ @keras .utils .register_keras_serializable (package = "keras_nlp" )
2426class Greedy (Sampler ):
2527 """Greedy sampler class.
2628
@@ -39,7 +41,6 @@ class Greedy(Sampler):
3941 VOCAB_SIZE = 10
4042 FEATURE_SIZE = 16
4143 START_ID = 1
42- END_ID = 2
4344
4445 # Create a dummy model to predict the next token.
4546 model = keras.Sequential(
@@ -62,7 +63,7 @@ def token_probability_fn(inputs, mask):
6263
6364 sampler = keras_nlp.samplers.Greedy()
6465 # Print the generated sequence (token ids).
65- print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID ))
66+ print(sampler(token_probability_fn, prompt, 10))
6667 ```
6768 """
6869
Original file line number Diff line number Diff line change 1414"""Base sampler class."""
1515
1616import tensorflow as tf
17+ from tensorflow import keras
1718
1819
20+ @keras .utils .register_keras_serializable (package = "keras_nlp" )
1921class Sampler :
2022 """Base sampler class.
2123
You can’t perform that action at this time.
0 commit comments