Skip to content

Commit caa6754

Browse files
Address comments: move token_probability_fn to the second place
1 parent 8323153 commit caa6754

File tree

4 files changed

+23
-26
lines changed

4 files changed

+23
-26
lines changed

keras_nlp/samplers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get(identifier):
4646
dict containing `class_name` and `config` as an identifier. Also note that
4747
the `class_name` must map to a `Sampler` class.
4848
49-
>>> cfg = {'class_name': 'Greedy', 'config': {}}
49+
>>> cfg = {'class_name': 'keras_nlp>Greedy', 'config': {}}
5050
>>> sampler = keras_nlp.samplers.get(cfg)
5151
5252
In the case that the `identifier` is a class, this method will return a new

keras_nlp/samplers/greedy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def token_probability_fn(inputs, mask):
6767
6868
sampler = keras_nlp.samplers.Greedy()
6969
# Print the generated sequence (token ids).
70-
print(sampler(token_probability_fn, prompt, 10))
70+
print(sampler(prompt, token_probability_fn, 10))
7171
```
7272
"""
7373

@@ -79,7 +79,7 @@ def __init__(
7979

8080
@format_docstring(sample_args=sample_args_docstring)
8181
def sample(
82-
self, token_probability_fn, prompt, mask, num_steps, from_logits=True
82+
self, prompt, token_probability_fn, mask, num_steps, from_logits=True
8383
):
8484
"""Sampling logic implementation.
8585

keras_nlp/samplers/greedy_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,17 @@ def token_probability_fn(inputs, mask):
4848

4949
def test_generate_with_1d_prompt(self):
5050
inputs = tf.constant([1])
51-
outputs = self.sampler(self.token_probability_fn, inputs, max_length=5)
51+
outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
5252
self.assertEqual(outputs.shape, [5])
5353

5454
def test_generate_with_2d_prompt(self):
5555
inputs = tf.constant([[1], [1]])
56-
outputs = self.sampler(self.token_probability_fn, inputs, max_length=5)
56+
outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
5757
self.assertEqual(outputs.shape, [2, 5])
5858

5959
def test_generate_with_list_prompt(self):
6060
inputs = [[1], [1]]
61-
outputs = self.sampler(self.token_probability_fn, inputs, max_length=5)
61+
outputs = self.sampler(inputs, self.token_probability_fn, max_length=5)
6262
self.assertEqual(outputs.shape, [2, 5])
6363

6464
def test_generate_with_ragged_prompt(self):
@@ -71,7 +71,7 @@ def token_probability_fn(inputs, mask):
7171
return tf.repeat(tf.repeat(prob, 2, axis=0), max_length, axis=1)
7272

7373
inputs = tf.ragged.constant([[1], [2, 1, 2]])
74-
outputs = self.sampler(token_probability_fn, inputs, max_length)
74+
outputs = self.sampler(inputs, token_probability_fn, max_length)
7575
self.assertEqual(outputs.shape, [2, 5])
7676

7777
def test_assert_generation_is_correct(self):
@@ -86,7 +86,7 @@ def token_probability_fn(inputs, mask):
8686

8787
inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
8888
outputs = self.sampler(
89-
token_probability_fn, inputs, max_length=max_length
89+
inputs, token_probability_fn, max_length=max_length
9090
)
9191
self.assertAllEqual(
9292
outputs, 3 * tf.ones(shape=[batch_size, max_length])
@@ -105,8 +105,8 @@ def token_probability_fn(inputs, mask):
105105
sampler = Greedy()
106106
inputs = tf.constant([[0, 1], [1, 2]])
107107
outputs = sampler(
108-
token_probability_fn,
109108
inputs,
109+
token_probability_fn,
110110
max_length=max_length,
111111
end_token_id=2,
112112
)
@@ -117,12 +117,12 @@ def test_compare_xla_noxla_results(self):
117117
inputs = [[1], [1]]
118118
xla_sampler = Greedy(jit_compile=True)
119119
outputs_xla = xla_sampler(
120-
self.token_probability_fn, inputs, max_length=5
120+
inputs, self.token_probability_fn, max_length=5
121121
)
122122

123123
xla_sampler = Greedy(jit_compile=False)
124124
outputs_no_xla = xla_sampler(
125-
self.token_probability_fn, inputs, max_length=5
125+
inputs, self.token_probability_fn, max_length=5
126126
)
127127

128128
self.assertAllEqual(outputs_xla, outputs_no_xla)

keras_nlp/samplers/sampler.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
"""
2424

2525
call_args_docstring = """
26-
token_probability_fn: a function that generates the probability of
27-
the next token over the whole vocabulary for each input token.
2826
prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
2927
initial tokens to which the generated tokens are appended.
28+
token_probability_fn: a function that generates the probability of
29+
the next token over the whole vocabulary for each input token.
3030
max_length: int. The max length of generated sequence.
31-
padding_mask: a tensor, defaults to None. The padding mask of the prompt.
31+
mask: a tensor, defaults to None. The padding mask of the prompt.
3232
end_token_id: int, defaults to None. The token marking the end of the
3333
sequence; once it is encountered, generation stops for that
3434
sequence. If None, every sequence is generated up to `max_length`.
@@ -40,10 +40,10 @@
4040
"""
4141

4242
sample_args_docstring = """
43-
token_probability_fn: a function that generates the probability of
44-
the next token over the whole vocabulary for each input token.
4543
prompt: a dense int Tensor of shape [batch_size, max_length]. The
4644
placeholder for the generated sequence.
45+
token_probability_fn: a function that generates the probability of
46+
the next token over the whole vocabulary for each input token.
4747
mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of
4848
prompt.
4949
num_steps: int. The remaining number of tokens to generate.
@@ -100,7 +100,7 @@ def token_probability_fn(inputs, mask):
100100
101101
sampler = keras_nlp.samplers.Greedy()
102102
# Print the generated sequence (token ids).
103-
print(sampler(token_probability_fn, prompt, 10, end_token_id=END_ID))
103+
print(sampler(prompt, token_probability_fn, 10, end_token_id=END_ID))
104104
```
105105
106106
Use with string inputs:
@@ -131,8 +131,8 @@ def token_probability_fn(inputs, mask):
131131
prompt = tokenizer("the quick brown fox")
132132
sampler = keras_nlp.samplers.Greedy()
133133
generated = sampler(
134-
token_probability_fn,
135134
prompt,
135+
token_probability_fn,
136136
10,
137137
end_token_id=tokenizer.token_to_id("[END]")
138138
)
@@ -210,17 +210,14 @@ def _mask_tokens_after_end_token(
210210

211211
def __call__(
212212
self,
213-
token_probability_fn,
214213
prompt,
214+
token_probability_fn,
215215
max_length,
216-
padding_mask=None,
216+
mask=None,
217217
end_token_id=None,
218218
from_logits=True,
219219
):
220-
prompt, padding_mask = self._validate_prompt_and_mask(
221-
prompt,
222-
padding_mask,
223-
)
220+
prompt, mask = self._validate_prompt_and_mask(prompt, mask)
224221

225222
input_is_1d = prompt.shape.rank == 1
226223
if input_is_1d:
@@ -238,8 +235,8 @@ def __call__(
238235
# `jit_compile` accordingly.
239236
sample = tf.function(self.sample, jit_compile=self.jit_compile)
240237
prompt = sample(
241-
token_probability_fn,
242238
prompt,
239+
token_probability_fn,
243240
mask,
244241
max_length - shortest_prompt_len,
245242
from_logits,
@@ -257,7 +254,7 @@ def __call__(
257254

258255
@format_docstring(sample_args=sample_args_docstring)
259256
def sample(
260-
self, token_probability_fn, prompt, mask, num_steps, from_logits=True
257+
self, prompt, token_probability_fn, mask, num_steps, from_logits=True
261258
):
262259
"""Sampling logic implementation.
263260

0 commit comments

Comments
 (0)