Commit 8323153

Address comments: fix docstring, remove multicase support
1 parent 8a8f940 commit 8323153

File tree: 4 files changed (+69, -82 lines)

keras_nlp/samplers/__init__.py

Lines changed: 10 additions & 8 deletions
@@ -26,8 +26,6 @@ def deserialize(config, custom_objects=None):
     all_classes = {
         "greedy": Greedy,
     }
-    if config["class_name"].lower() in all_classes:
-        config["class_name"] = config["class_name"].lower()
     return keras.utils.deserialize_keras_object(
         config,
         module_objects=all_classes,
@@ -55,23 +53,27 @@ def get(identifier):
     instance of the class by its constructor.

     Args:
-        identifier: String or dict that contains the sampler name or
-            configurations.
+        identifier: String or dict that contains the sampler name or
+            configurations.

     Returns:
-        Sampler instance base on the input identifier.
+        Sampler instance base on the input identifier.

     Raises:
-        ValueError: If the input identifier is not a supported type or in a bad
-            format.
+        ValueError: If the input identifier is not a supported type or in a bad
+            format.
     """

     if identifier is None:
         return None
     if isinstance(identifier, dict):
         return deserialize(identifier)
     elif isinstance(identifier, str):
-        identifier = {"class_name": str(identifier), "config": {}}
+        if not identifier.islower():
+            raise KeyError(
+                "`keras_nlp.samplers.get()` must take a lowercase string "
+                f"identifier, but received: {identifier}."
+            )
         return deserialize(identifier)
     elif callable(identifier):
         return identifier
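
The net effect on callers of `keras_nlp.samplers.get()` is that string identifiers are no longer case-normalized. A minimal sketch of the new behavior, inferred from the diff above (the exact error wording may differ in released versions):

```python
import keras_nlp

# Lowercase string identifiers resolve to a sampler instance, as before.
sampler = keras_nlp.samplers.get("greedy")

# Mixed-case identifiers are no longer lowercased on the caller's behalf;
# per the diff above, get() now raises a KeyError for them.
try:
    keras_nlp.samplers.get("Greedy")
except KeyError as err:
    print(err)  # `keras_nlp.samplers.get()` must take a lowercase string ...
```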

keras_nlp/samplers/greedy.py

Lines changed: 11 additions & 18 deletions
@@ -17,11 +17,15 @@
 from tensorflow import keras

 from keras_nlp.samplers.sampler import Sampler
-from keras_nlp.samplers.sampler import base_sampler_keyword_args
-from keras_nlp.samplers.sampler import call_keyword_docstring
-from keras_nlp.samplers.sampler import sample_keyword_docstring
+from keras_nlp.samplers.sampler import base_sampler_args_docstring
+from keras_nlp.samplers.sampler import call_args_docstring
+from keras_nlp.samplers.sampler import sample_args_docstring
+from keras_nlp.utils.python_utils import format_docstring


+@format_docstring(
+    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
+)
 @keras.utils.register_keras_serializable(package="keras_nlp")
 class Greedy(Sampler):
     """Greedy sampler class.
@@ -30,10 +34,10 @@ class Greedy(Sampler):
     token of the largest probability as the next token.

     Args:
-        {{base_sampler_keyword_args}}
+        {{base_sampler_args}}

     Call Args:
-        {{call_keyword_args}}
+        {{call_args}}

     Examples:
     ```python
@@ -73,13 +77,14 @@ def __init__(
     ):
         super().__init__(jit_compile)

+    @format_docstring(sample_args=sample_args_docstring)
     def sample(
         self, token_probability_fn, prompt, mask, num_steps, from_logits=True
     ):
         """Sampling logic implementation.

         Args:
-            {{sample_keyword_docstring}}
+            {{sample_args}}
         """
         batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
         max_length = tf.cast(max_length, num_steps.dtype)
@@ -88,7 +93,6 @@ def sample(
         current_index = max_length - num_steps

         def one_step(current_index, prompt, mask):
-
             probs = token_probability_fn(prompt, mask)
             next_token_prob = tf.gather(
                 probs,
@@ -143,14 +147,3 @@ def one_step(current_index, prompt, mask):
             loop_vars=(current_index, prompt, mask),
         )
         return prompt
-
-
-Greedy.__doc__ = Greedy.__doc__.replace(
-    "{{base_sampler_keyword_args}}", base_sampler_keyword_args
-)
-Greedy.__doc__ = Greedy.__doc__.replace(
-    "{{call_keyword_docstring}}", call_keyword_docstring
-)
-Greedy.sample.__doc__ = Greedy.sample.__doc__.replace(
-    "{{sample_keyword_docstring}}", sample_keyword_docstring
-)
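
The `format_docstring` decorator imported from `keras_nlp.utils.python_utils` replaces the manual `__doc__.replace(...)` calls deleted at the bottom of this file, but its implementation is not part of this commit. A hypothetical sketch of the kind of helper the new code relies on (illustrative only, not the actual keras_nlp implementation):

```python
def format_docstring(**replacements):
    """Sketch of a docstring-templating decorator (hypothetical).

    Substitutes `{{name}}` placeholders in the decorated object's docstring
    with the provided text, mirroring how the diff above uses
    `{{base_sampler_args}}`, `{{call_args}}`, and `{{sample_args}}`.
    """

    def decorator(obj):
        doc = obj.__doc__ or ""
        for name, text in replacements.items():
            doc = doc.replace("{{" + name + "}}", text)
        obj.__doc__ = doc
        return obj

    return decorator
```

Compared to patching `Greedy.__doc__` after the class body, the decorator keeps the substitution next to the definition it documents and applies equally to methods such as `Greedy.sample`.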

keras_nlp/samplers/sampler.py

Lines changed: 47 additions & 50 deletions
@@ -16,16 +16,55 @@
 import tensorflow as tf
 from tensorflow import keras

+from keras_nlp.utils.python_utils import format_docstring

+base_sampler_args_docstring = """
+    jit_compile: bool, defaults to True. If True, XLA compilation will be used.
+    """
+
+call_args_docstring = """
+    token_probability_fn: a function that generates the probability of
+        the next token over the whole vocabulary for each input token.
+    prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
+        initial tokens to append generated tokens.
+    max_length: int. The max length of generated sequence.
+    padding_mask: a tensor, defaults to None. The padding mask of the prompt.
+    end_token_id: int, defaults to None. The token marking the end of the
+        sequence, once encountered the generation is finished for the exact
+        sequence. If None, every sequence is generated up to `max_length`.
+        If set, all tokens after encountering `end_token_id` will be
+        replaced with `pad_token_id`.
+    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
+        returns logits. If False, `token_probability_fn` returns probability
+        distributions.
+    """
+
+sample_args_docstring = """
+    token_probability_fn: a function that generates the probability of
+        the next token over the whole vocabulary for each input token.
+    prompt: a dense int Tensor of shape [batch_size, max_length]. The
+        placeholder for generated sequence.
+    mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of
+        prompt.
+    num_steps: int. The remaining number of tokens to generate.
+    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
+        returns logits. If False, `token_probability_fn` returns probability
+        distributions.
+    """
+
+
+@format_docstring(
+    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
+)
 @keras.utils.register_keras_serializable(package="keras_nlp")
 class Sampler:
     """Base sampler class.

     Args:
-        {{base_optimizer_keyword_args}}
+        {{base_sampler_args}}

     Call Args:
-        {{call_keyword_docstring}}
+        {{call_args}}

     The inputs and outputs of Sampler class are both token ids.

@@ -39,7 +78,8 @@ class Sampler:
     START_ID = 1
     END_ID = 2

-    # Create a dummy model to predict the next token.
+    # Create a dummy model to predict the next token. Note that the output is
+    # random without training, here we jsut demo how `samplers` works.
     model = keras.Sequential(
         [
             keras.Input(shape=[None]),
@@ -178,7 +218,8 @@ def __call__(
         from_logits=True,
     ):
         prompt, padding_mask = self._validate_prompt_and_mask(
-            prompt, padding_mask
+            prompt,
+            padding_mask,
         )

         input_is_1d = prompt.shape.rank == 1
@@ -214,13 +255,14 @@ def __call__(

         return tf.squeeze(prompt, axis=0) if input_is_1d else prompt

+    @format_docstring(sample_args=sample_args_docstring)
     def sample(
         self, token_probability_fn, prompt, mask, num_steps, from_logits=True
     ):
         """Sampling logic implementation.

         Args:
-            {{sample_keyword_docstring}}
+            {{sample_args}}

         Returns:
             A dense int Tensor, representing the generated text in token id
@@ -232,48 +274,3 @@ def get_config(self):
         return {
             "jit_compile": self.jit_compile,
         }
-
-
-base_sampler_keyword_args = """
-    jit_compile: bool, defaults to True. If True, XLA compilation will be used.
-    """
-
-call_keyword_docstring = """
-    token_probability_fn: a function that generates the probability of
-        the next token over the whole vocabulary for each input token.
-    prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
-        initial tokens to append generated tokens.
-    max_length: int. The max length of generated sequence.
-    padding_mask: a tensor, defaults to None. The padding mask of the prompt.
-    end_token_id: int, defaults to None. The token marking the end of the
-        sequence, once encountered the generation is finished for the exact
-        sequence. If None, every sequence is generated up to `max_length`.
-        If set, all tokens after encountering `end_token_id` will be
-        replaced with `pad_token_id`.
-    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
-        returns logits. If False, `token_probability_fn` returns probability
-        distributions.
-    """
-
-sample_keyword_docstring = """
-    token_probability_fn: a function that generates the probability of
-        the next token over the whole vocabulary for each input token.
-    prompt: a dense int Tensor of shape [batch_size, max_length]. The
-        placeholder for generated sequence.
-    mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of
-        prompt.
-    num_steps: int. The remaining number of tokens to generate.
-    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
-        returns logits. If False, `token_probability_fn` returns probability
-        distributions.
-    """
-
-Sampler.__doc__ = Sampler.__doc__.replace(
-    "{{base_sampler_keyword_args}}", base_sampler_keyword_args
-)
-Sampler.__doc__ = Sampler.__doc__.replace(
-    "{{call_keyword_docstring}}", call_keyword_docstring
-)
-Sampler.sample.__doc__ = Sampler.sample.__doc__.replace(
-    "{{sample_keyword_docstring}}", sample_keyword_docstring
-)
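
As a usage illustration of the call arguments documented above, here is a hedged sketch of driving a sampler with a toy `token_probability_fn` (argument names and order are inferred from `call_args_docstring` in this diff and may not match the final released API):

```python
import tensorflow as tf

import keras_nlp

VOCAB_SIZE = 10


def token_probability_fn(prompt, mask):
    # Toy stand-in for a language model: for every input position, put all
    # of the probability mass on token id 3.
    return tf.one_hot(tf.fill(tf.shape(prompt), 3), depth=VOCAB_SIZE)


# Per the docstring, the prompt may be a 1D or 2D integer Tensor; the output
# is a Tensor of token ids of length `max_length`.
prompt = tf.constant([1, 2])
sampler = keras_nlp.samplers.Greedy()
generated = sampler(token_probability_fn, prompt, max_length=8)
print(generated)  # Expected: the prompt followed by token id 3 repeated.
```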

keras_nlp/samplers/sampler_test.py

Lines changed: 1 addition & 6 deletions
@@ -24,7 +24,7 @@ def test_serialization(self):
         sampler = keras_nlp.samplers.Greedy()
         config = keras_nlp.samplers.serialize(sampler)
         expected_config = {
-            "class_name": "Greedy",
+            "class_name": "keras_nlp>Greedy",
             "config": {
                 "jit_compile": True,
             },
@@ -37,11 +37,6 @@ def test_deserialization(self):
         sampler = keras_nlp.samplers.get(identifier)
         self.assertIsInstance(sampler, Greedy)

-        # Test string is not case-sensitive.
-        identifier = "Greedy"
-        sampler = keras_nlp.samplers.get(identifier)
-        self.assertIsInstance(sampler, Greedy)
-
         # Test dict identifier.
         original_sampler = keras_nlp.samplers.Greedy(jit_compile=False)
         config = keras_nlp.samplers.serialize(original_sampler)
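
For reference, the serialization round trip after this change, as a short sketch based on the updated test (the namespaced class name appears to come from registering the sampler with `package="keras_nlp"` as a Keras serializable):

```python
import keras_nlp

sampler = keras_nlp.samplers.Greedy(jit_compile=False)
config = keras_nlp.samplers.serialize(sampler)

# The serialized class name is now namespaced, matching the updated test.
print(config["class_name"])  # "keras_nlp>Greedy"

# Dict identifiers still round-trip through get(), as the remaining
# deserialization test exercises.
restored = keras_nlp.samplers.get(config)
assert isinstance(restored, keras_nlp.samplers.Greedy)
```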
