 import tensorflow as tf
 from tensorflow import keras

+from keras_nlp.utils.python_utils import format_docstring

+base_sampler_args_docstring = """
+    jit_compile: bool, defaults to True. If True, XLA compilation will be used.
+    """
+
+call_args_docstring = """
+    token_probability_fn: a function that generates the probability of
+        the next token over the whole vocabulary for each input token.
+    prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
+        initial tokens to append generated tokens.
+    max_length: int. The max length of generated sequence.
+    padding_mask: a tensor, defaults to None. The padding mask of the prompt.
+    end_token_id: int, defaults to None. The token marking the end of the
+        sequence, once encountered the generation is finished for the exact
+        sequence. If None, every sequence is generated up to `max_length`.
+        If set, all tokens after encountering `end_token_id` will be
+        replaced with `pad_token_id`.
+    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
+        returns logits. If False, `token_probability_fn` returns probability
+        distributions.
+    """
+
+sample_args_docstring = """
+    token_probability_fn: a function that generates the probability of
+        the next token over the whole vocabulary for each input token.
+    prompt: a dense int Tensor of shape [batch_size, max_length]. The
+        placeholder for generated sequence.
+    mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of
+        prompt.
+    num_steps: int. The remaining number of tokens to generate.
+    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
+        returns logits. If False, `token_probability_fn` returns probability
+        distributions.
+    """
+
+
+@format_docstring(
+    base_sampler_args=base_sampler_args_docstring, call_args=call_args_docstring
+)
 @keras.utils.register_keras_serializable(package="keras_nlp")
 class Sampler:
     """Base sampler class.

     Args:
-        {{base_optimizer_keyword_args}}
+        {{base_sampler_args}}

     Call Args:
-        {{call_keyword_docstring}}
+        {{call_args}}

     The inputs and outputs of Sampler class are both token ids.

@@ -39,7 +78,8 @@ class Sampler:
     START_ID = 1
     END_ID = 2

-    # Create a dummy model to predict the next token.
+    # Create a dummy model to predict the next token. Note that the output is
+    # random without training; here we just demo how `samplers` works.
     model = keras.Sequential(
         [
             keras.Input(shape=[None]),
@@ -178,7 +218,8 @@ def __call__(
         from_logits=True,
     ):
         prompt, padding_mask = self._validate_prompt_and_mask(
-            prompt, padding_mask
+            prompt,
+            padding_mask,
         )

         input_is_1d = prompt.shape.rank == 1
@@ -214,13 +255,14 @@ def __call__(

         return tf.squeeze(prompt, axis=0) if input_is_1d else prompt

+    @format_docstring(sample_args=sample_args_docstring)
     def sample(
         self, token_probability_fn, prompt, mask, num_steps, from_logits=True
     ):
         """Sampling logic implementation.

         Args:
-            {{sample_keyword_docstring}}
+            {{sample_args}}

         Returns:
             A dense int Tensor, representing the generated text in token id
@@ -232,48 +274,3 @@ def get_config(self):
         return {
             "jit_compile": self.jit_compile,
         }
-
-
-base_sampler_keyword_args = """
-    jit_compile: bool, defaults to True. If True, XLA compilation will be used.
-    """
-
-call_keyword_docstring = """
-    token_probability_fn: a function that generates the probability of
-        the next token over the whole vocabulary for each input token.
-    prompt: a list of integers or an integer Tensor, can be 1D or 2D. The
-        initial tokens to append generated tokens.
-    max_length: int. The max length of generated sequence.
-    padding_mask: a tensor, defaults to None. The padding mask of the prompt.
-    end_token_id: int, defaults to None. The token marking the end of the
-        sequence, once encountered the generation is finished for the exact
-        sequence. If None, every sequence is generated up to `max_length`.
-        If set, all tokens after encountering `end_token_id` will be
-        replaced with `pad_token_id`.
-    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
-        returns logits. If False, `token_probability_fn` returns probability
-        distributions.
-    """
-
-sample_keyword_docstring = """
-    token_probability_fn: a function that generates the probability of
-        the next token over the whole vocabulary for each input token.
-    prompt: a dense int Tensor of shape [batch_size, max_length]. The
-        placeholder for generated sequence.
-    mask: a dense bool Tensor of shape [batch_size, max_length]. The mask of
-        prompt.
-    num_steps: int. The remaining number of tokens to generate.
-    from_logits: bool, defaults to True. Indicate if the `token_probability_fn`
-        returns logits. If False, `token_probability_fn` returns probability
-        distributions.
-    """
-
-Sampler.__doc__ = Sampler.__doc__.replace(
-    "{{base_sampler_keyword_args}}", base_sampler_keyword_args
-)
-Sampler.__doc__ = Sampler.__doc__.replace(
-    "{{call_keyword_docstring}}", call_keyword_docstring
-)
-Sampler.sample.__doc__ = Sampler.sample.__doc__.replace(
-    "{{sample_keyword_docstring}}", sample_keyword_docstring
-)
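
The change above swaps the module-level `__doc__.replace(...)` patching for a `format_docstring` decorator, so the template substitution happens right where the class and method are defined. A minimal sketch of that pattern, assuming the keras_nlp helper performs a plain `{{key}}` substitution like the manual `.replace` calls it supersedes (an illustrative stand-in, not the library's actual source):

# Illustrative only: assumes `format_docstring` does a simple template
# substitution over `{{key}}` placeholders. Not the keras_nlp implementation.
def format_docstring(**replacements):
    def decorator(obj):
        doc = obj.__doc__ or ""
        for key, value in replacements.items():
            # Swap each `{{key}}` placeholder for its docstring snippet.
            doc = doc.replace("{{" + key + "}}", value)
        obj.__doc__ = doc
        return obj

    return decorator


args_docstring = """
    x: int. The value to double."""


@format_docstring(args=args_docstring)
def double(x):
    """Double a value.

    Args:
        {{args}}
    """
    return 2 * x


print(double.__doc__)  # The Args section now contains the substituted text.

Applying the substitution in a decorator keeps each docstring and its placeholder definitions adjacent, instead of mutating `__doc__` at the bottom of the module, far from the class body.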