Skip to content

Commit 513121e

Browse files
more samplers
1 parent e6483a4 commit 513121e

File tree

3 files changed

+414
-0
lines changed

3 files changed

+414
-0
lines changed

keras_nlp/samplers/beam_sampler.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Beam Sampler."""
15+
16+
import tensorflow as tf
17+
from tensorflow import keras
18+
19+
from keras_nlp.samplers.sampler import Sampler
20+
from keras_nlp.samplers.sampler import base_sampler_keyword_args
21+
from keras_nlp.samplers.sampler import call_keyword_docstring
22+
from keras_nlp.samplers.sampler import sample_keyword_docstring
23+
24+
25+
class BeamSampler(Sampler):
    """Beam Sampler class.

    This sampler implements beam search algorithm: at every step it keeps the
    `num_beams` candidate sequences with the highest accumulated
    log-probability, and finally returns the best-scoring one per prompt.

    Args:
        num_beams: int. The number of beams to keep for each prompt.
        seed: int, defaults to None. Stored on the sampler; the search in
            this class does not read it.
        from_logits: bool, defaults to False. If True, a softmax is applied
            to the output of `token_probability_fn` before scoring.
        {{base_sampler_keyword_args}}

    Call Args:
        {{call_keyword_docstring}}
    """

    def __init__(
        self,
        num_beams,
        seed=None,
        from_logits=False,
        end_token_id=None,
        pad_token_id=0,
        jit_compile=True,
    ):
        self.num_beams = num_beams
        self.seed = seed
        self.from_logits = from_logits
        super().__init__(end_token_id, pad_token_id, jit_compile)

    def sample(self, token_probability_fn, prompt, mask, num_steps):
        """Sampler's logic implementation.

        Args:
            {{sample_keyword_docstring}}

        Returns:
            For each prompt, the beam with the highest accumulated
            log-probability, with the beam axis removed.
        """
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        max_length = tf.cast(max_length, num_steps.dtype)
        # Index of the first position to be generated.
        length = max_length - num_steps
        dummy_preds = self._validate_token_probability_fn(
            token_probability_fn, prompt, mask
        )
        vocab_size = dummy_preds.shape[-1]
        pred_dtype = dummy_preds.dtype

        num_beams = self.num_beams

        # Initialize beam with shape `(batch_size, num_beams, length)`.
        beams = tf.repeat(tf.expand_dims(prompt, axis=1), num_beams, axis=1)
        # Initialize `beams_prob` with shape `(batch_size, num_beams)`. Only
        # the first beam starts at log-probability 0; the others start at the
        # dtype minimum so the identical copies of the prompt are not all
        # selected in the first step.
        beams_prob = tf.zeros([batch_size, 1], dtype=pred_dtype)
        beams_prob = tf.concat(
            [beams_prob, tf.fill((batch_size, num_beams - 1), pred_dtype.min)],
            axis=-1,
        )

        def one_step(beams, beams_prob, length, mask):
            # Fold the beam axis into the batch axis so the model is called
            # on a rank-2 tensor; the mask is repeated in the same
            # (batch-major) order so row `i` of both tensors match.
            flattened_beams = tf.reshape(
                beams, shape=[batch_size * num_beams, -1]
            )
            repeated_mask = tf.repeat(mask, num_beams, axis=0)
            probs = token_probability_fn(flattened_beams, repeated_mask)
            # Keep only the prediction made at position `length - 1`, i.e.
            # the distribution for the token to be written at `length`.
            preds = tf.gather(
                probs,
                tf.repeat(length - 1, batch_size * num_beams),
                axis=1,
                batch_dims=1,
            )
            if self.from_logits:
                preds = keras.activations.softmax(preds, axis=-1)
            # Reshape `preds` to shape `(batch_size, num_beams * vocab_size)`.
            preds = tf.reshape(preds, shape=[batch_size, -1])

            # Accumulated log-probability of every (beam, token) extension.
            cum_probs = tf.math.log(preds) + tf.repeat(
                beams_prob, repeats=vocab_size, axis=1
            )

            candidate_prob, candidate_indexes = tf.math.top_k(
                cum_probs, k=num_beams, sorted=False
            )
            # A flat index encodes `beam_index * vocab_size + token_id`.
            candidate_beam_indexes = candidate_indexes // vocab_size
            next_token = candidate_indexes % vocab_size
            # `top_k` indices need not share the dtype of the prompt tokens.
            next_token = tf.cast(next_token, dtype=beams.dtype)

            beams = tf.gather(
                beams, candidate_beam_indexes, axis=1, batch_dims=1
            )

            # Build a new column of updates to scatter into the beam tensor.
            # Positions already marked in `mask` keep their existing token.
            next_token = tf.where(
                condition=mask[..., length, tf.newaxis],
                x=beams[..., length],
                y=next_token,
            )
            next_token = tf.reshape(next_token, shape=[-1])

            # Mark position `length` as filled for every sequence so later
            # calls to `token_probability_fn` see the generated token.
            mask = tf.tensor_scatter_nd_update(
                tensor=mask,
                indices=tf.stack(
                    (
                        tf.cast(tf.range(batch_size), dtype=length.dtype),
                        tf.repeat(length, batch_size),
                    ),
                    axis=1,
                ),
                updates=tf.repeat(True, batch_size),
            )

            # Generate `(batch_index, beam_index)` tuples for each beam.
            beam_indices = tf.where(tf.ones((batch_size, num_beams), tf.bool))
            beam_indices = tf.cast(beam_indices, dtype=length.dtype)
            # Build a tensor of repeated `length` values.
            length_indices = tf.fill((batch_size * num_beams, 1), length)
            # Concatenate to a triplet of `(batch_index, beam_index, length)`.
            indices = tf.concat([beam_indices, length_indices], axis=-1)

            # Update `beams[:, :, length]` with `next_token`.
            beams = tf.tensor_scatter_nd_update(
                tensor=beams,
                indices=indices,
                updates=next_token,
            )

            beams_prob = candidate_prob
            length = tf.add(length, 1)

            return beams, beams_prob, length, mask

        # Run a while loop till text of length `max_length` has been generated.
        beams, beams_prob, length, mask = tf.while_loop(
            cond=lambda beams, beams_prob, length, mask: tf.less(
                length, max_length
            ),
            body=one_step,
            loop_vars=(beams, beams_prob, length, mask),
            # The gather/scatter ops in the body may lose static shape
            # information (e.g. a known batch size of 1), so the loop
            # variables are declared with fully relaxed shapes.
            shape_invariants=(
                tf.TensorShape([None, None, None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([]),
                tf.TensorShape([None, None]),
            ),
        )

        # Get the beam with the maximum probability.
        max_indexes = tf.math.argmax(beams_prob, axis=-1)
        max_beams = tf.gather(
            beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
        )
        # Drop only the beam axis; an unqualified squeeze would also drop the
        # batch axis when `batch_size == 1`.
        prompt = tf.squeeze(max_beams, axis=1)

        return prompt
146+
147+
148+
# Fill the docstring templates with the argument descriptions shared by all
# samplers: two placeholders on the class, one on `sample`.
BeamSampler.__doc__ = BeamSampler.__doc__.replace(
    "{{base_sampler_keyword_args}}", base_sampler_keyword_args
).replace("{{call_keyword_docstring}}", call_keyword_docstring)
BeamSampler.sample.__doc__ = BeamSampler.sample.__doc__.replace(
    "{{sample_keyword_docstring}}", sample_keyword_docstring
)

keras_nlp/samplers/top_k_sampler

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Top-K Sampler."""
15+
16+
import tensorflow as tf
17+
18+
from keras_nlp.samplers.sampler import Sampler
19+
from keras_nlp.samplers.sampler import base_sampler_keyword_args
20+
from keras_nlp.samplers.sampler import call_keyword_docstring
21+
from keras_nlp.samplers.sampler import sample_keyword_docstring
22+
23+
24+
class TopKSampler(Sampler):
    """Top-K Sampler class.

    This sampler implements top-k search algorithm: at every step the next
    token is drawn at random from the `k` most probable tokens, weighted by
    their probabilities.

    Args:
        k: int. The number of highest-probability tokens to sample from.
        seed: int, defaults to None. The seed passed to the random draw.
        from_logits: bool, defaults to False. If True, a softmax is applied
            to the output of `token_probability_fn` before sampling.
        {{base_sampler_keyword_args}}

    Call Args:
        {{call_keyword_docstring}}
    """

    def __init__(
        self,
        k,
        seed=None,
        from_logits=False,
        end_token_id=None,
        pad_token_id=0,
        jit_compile=True,
    ):
        self.k = k
        self.seed = seed
        self.from_logits = from_logits
        super().__init__(end_token_id, pad_token_id, jit_compile)

    def sample(self, token_probability_fn, prompt, mask, num_steps):
        """Sampler's logic implementation.

        Args:
            {{sample_keyword_docstring}}

        Returns:
            `prompt` with the sampled tokens written into the positions from
            `max_length - num_steps` up to the end of the sequence.
        """
        batch_size, max_length = tf.shape(prompt)[0], tf.shape(prompt)[1]
        max_length = tf.cast(max_length, num_steps.dtype)
        # Index of the first position to be generated.
        length = max_length - num_steps

        def one_step(length, prompt, mask):
            probs = token_probability_fn(prompt, mask)
            # Keep only the prediction made at position `length - 1`, i.e.
            # the distribution for the token to be written at `length`.
            pred = tf.gather(
                probs, tf.repeat(length - 1, batch_size), axis=1, batch_dims=1
            )
            if self.from_logits:
                # `tf.nn.softmax` is used because this module only imports
                # `tensorflow as tf`; there is no `keras` name in scope.
                pred = tf.nn.softmax(pred, axis=-1)

            # Filter out top-k tokens.
            top_k_pred, top_k_indices = tf.math.top_k(
                pred, k=self.k, sorted=False
            )
            # Sample the next token from the probability distribution.
            next_token = tf.random.categorical(
                tf.math.log(top_k_pred), 1, seed=self.seed
            )

            # Rearrange to get the next token idx from the original order.
            next_token = tf.gather_nd(top_k_indices, next_token, batch_dims=1)
            next_token = tf.cast(next_token, dtype=prompt.dtype)
            # Positions already marked in `mask` keep their existing token.
            next_token = tf.where(
                mask[:, length], prompt[:, length], next_token
            )

            # `(batch_index, length)` pairs addressing column `length` of
            # every sequence; shared by the mask and the prompt updates.
            indices = tf.stack(
                (
                    tf.cast(tf.range(batch_size), dtype=length.dtype),
                    tf.repeat(length, batch_size),
                ),
                axis=1,
            )

            # Mark position `length` as filled for every sequence.
            mask = tf.tensor_scatter_nd_update(
                tensor=mask,
                indices=indices,
                updates=tf.repeat(True, batch_size),
            )

            # Append the next token to current sequence.
            prompt = tf.tensor_scatter_nd_update(
                tensor=prompt,
                indices=indices,
                updates=next_token,
            )

            length = tf.add(length, 1)
            return (length, prompt, mask)

        # Run a while loop till text of length `max_length` has been generated.
        length, prompt, mask = tf.while_loop(
            cond=lambda length, prompt, mask: tf.less(length, max_length),
            body=one_step,
            loop_vars=(length, prompt, mask),
        )

        return prompt
120+
121+
122+
# Fill the docstring templates with the argument descriptions shared by all
# samplers: two placeholders on the class, one on `sample`.
TopKSampler.__doc__ = TopKSampler.__doc__.replace(
    "{{base_sampler_keyword_args}}", base_sampler_keyword_args
).replace("{{call_keyword_docstring}}", call_keyword_docstring)
TopKSampler.sample.__doc__ = TopKSampler.sample.__doc__.replace(
    "{{sample_keyword_docstring}}", sample_keyword_docstring
)

0 commit comments

Comments
 (0)