Commit fe5a350

[Relax] add sample_indices in sampling (#16675)
1 parent 46aaf61 commit fe5a350

File tree

2 files changed: 163 additions & 92 deletions


python/tvm/relax/frontend/nn/op.py

Lines changed: 105 additions & 29 deletions
@@ -2057,7 +2057,12 @@ def cumsum(
     return wrap_nested(_op.cumsum(data._expr, axis, dtype, exclusive), name)


-def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str = "int64"):
+def multinomial_from_uniform(
+    prob: Tensor,
+    uniform_sample: Tensor,
+    sample_indices: Optional[Tensor] = None,
+    dtype: str = "int64",
+):
     """Returns a tensor where each row contains the index sampled from the multinomial
     probability distribution located in the corresponding row of tensor prob.

@@ -2075,57 +2080,97 @@ def multinomial_from_uniform(prob: Tensor, uniform_sample: Tensor, dtype: str =
         The sum of values in each row is 1, forming a valid distribution.

     uniform_sample : Tensor
-        The uniformly sampled 2-D tensor with the shape (batch, 1).
+        The uniformly sampled 2-D tensor with the shape (n, 1).
         Values range from 0 to 1, indicating probabilities sampled uniformly.

+    sample_indices : Optional[Tensor]
+        The 2-D tensor with the shape [n, 1], which indicates the specific
+        probability distribution to sample from. The value of sample_indices[i]
+        determines that the ith token should be sampled from the sample_indices[i]th
+        probability distribution. For instance, if there are 3 distinct probability
+        distributions and the requirement is to sample 2, 3, and 4 tokens from each,
+        then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].
+
+    dtype : str
+        The data type of output tensor.
+
+
     Returns
     -------
     result : Tensor
-        The computed tensor with shape (batch, 1).
+        The computed tensor with shape (n, 1).

     Examples
     --------
     .. code-block:: python

         prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
         usample = [[0.4], [0.9]]
+        sample_indices = [[0], [1]]

         multinomial_from_uniform(prob, usample)
         -> [[1], [2]]
+        multinomial_from_uniform(prob, usample, sample_indices)
+        -> [[1], [2]]
     """
     prob_dtype = prob.dtype
     sample_dtype = uniform_sample.dtype
-    batch = prob.shape[0]
+    out_batch = uniform_sample.shape[0]
+
+    if sample_indices is not None:
+        assert (
+            sample_indices.shape == uniform_sample.shape
+        ), "The shape of sample_indices must match the shape of uniform_sample."
+    else:
+        assert (
+            prob.shape[0] == uniform_sample.shape[0]
+        ), "Number of samples must match the number of probability distributions."
+        sample_indices = Tensor.from_const(np.arange(out_batch).reshape(out_batch, 1))
+
+    sample_indices_dtype = sample_indices.dtype

     @T.prim_func(private=True)
-    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle):
+    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
         batch, vocab_size = T.int64(), T.int64()
         prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
-        usample = T.match_buffer(B, (batch, 1), sample_dtype)
-        output_index = T.match_buffer(C, (batch, 1), dtype)
+        out_batch = T.int64()
+        usample = T.match_buffer(B, (out_batch, 1), sample_dtype)
+        sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype)
+        output_index = T.match_buffer(D, (out_batch, 1), dtype)

-        for ax0, ax1 in T.grid(batch, vocab_size):
+        for ax0, ax1 in T.grid(out_batch, vocab_size):
             with T.block("T_get_sample_index"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.writes(output_index[v_ax0, 0])
-                if usample[v_ax0, T.int64(0)] < prob[v_ax0, v_ax1] or v_ax1 + 1 == vocab_size:
+                if (
+                    usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1]
+                    or v_ax1 + 1 == vocab_size
+                ):
                     if v_ax1 == 0:
                         output_index[v_ax0, 0] = 0
-                    elif usample[v_ax0, T.int64(0)] >= prob[v_ax0, v_ax1 - 1]:
+                    elif (
+                        usample[v_ax0, T.int64(0)]
+                        >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
+                    ):
                         output_index[v_ax0, 0] = v_ax1

     cumsum_prob = cumsum(prob, axis=1, exclusive=False)

     return tensor_ir_op(
         _get_sample_index,
         "get_sample_index",
-        args=[cumsum_prob, uniform_sample],
-        out=Tensor.placeholder([batch, 1], dtype),
+        args=[cumsum_prob, uniform_sample, sample_indices],
+        out=Tensor.placeholder([out_batch, 1], dtype),
     )


 def sample_top_p_top_k_from_sorted_prob(
-    sorted_prob: Tensor, sorted_index: Tensor, top_p: Tensor, top_k: Tensor, uniform_sample: Tensor
+    sorted_prob: Tensor,
+    sorted_index: Tensor,
+    top_p: Tensor,
+    top_k: Tensor,
+    uniform_sample: Tensor,
+    sample_indices: Optional[Tensor] = None,
 ):
     """Samples indices from a sorted probability tensor based on top_p and top_k criteria.

@@ -2152,12 +2197,20 @@ def sample_top_p_top_k_from_sorted_prob(
         to consider for top-k sampling.

     uniform_sample : Tensor
-        Uniformly sampled values with shape (batch, 1) are used to select the output indices.
+        Uniformly sampled values with shape (n, 1) are used to select the output indices.
+
+    sample_indices : Optional[Tensor]
+        The 2-D tensor with the shape [n, 1], which indicates the specific
+        probability distribution to sample from. The value of sample_indices[i]
+        determines that the ith token should be sampled from the sample_indices[i]th
+        probability distribution. For instance, if there are 3 distinct probability
+        distributions and the requirement is to sample 2, 3, and 4 tokens from each,
+        then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].

     Returns
     -------
     result : Tensor
-        The selected indices with shape (batch, 1).
+        The selected indices with shape (n, 1).

     Examples
     --------
@@ -2172,15 +2225,31 @@ def sample_top_p_top_k_from_sorted_prob(
         top_p = [[0.6],[0.9]]
         top_k = [[3],[2]]
         uniform_sample = [[0.5], [0.6]]
+        sample_indices = [[0], [1]]

         sample_top_p_top_k_from_sorted_prob(
-            sorted_prob, sorted_index,top_p, top_k, uniform_sample)
+            sorted_prob, sorted_index,top_p, top_k, uniform_sample, sample_indices)
         -> [2, 0]

     """
     prob_dtype = sorted_prob.dtype
     index_dtype = sorted_index.dtype
-    batch = sorted_prob.shape[0]
+    prob_batch = sorted_prob.shape[0]
+    out_batch = uniform_sample.shape[0]
+
+    if sample_indices is not None:
+        assert (
+            sample_indices.shape == uniform_sample.shape
+        ), "The shape of sample_indices must match the shape of uniform_sample."
+    else:
+        assert (
+            sorted_prob.shape[0] == uniform_sample.shape[0]
+        ), "Number of samples must match the number of probability distributions."
+        sample_indices = Tensor.from_const(
+            np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
+        )
+    print("sample_indices: ", sample_indices)
+    sample_indices_dtype = sample_indices.dtype

     def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
         return _tir.all(cumsum_sorted[i, j] < top_p[i, 0], j + 1 < top_k[i, 0])
@@ -2204,27 +2273,34 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
                             renorm_prob[v_ax0, 0] = cumsum_sorted[v_ax0, v_ax1 + 1]

     @T.prim_func(private=True)
-    def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle):
+    def _get_index_from_sorted(
+        A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
+    ):
         batch, vocab_size = T.int64(), T.int64()
+        out_batch = T.int64()
         cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
-        renorm_prob = T.match_buffer(B, (batch, 1), prob_dtype)
-        usample = T.match_buffer(C, (batch, 1), prob_dtype)
-        indices = T.match_buffer(D, (batch, vocab_size), index_dtype)
-        output_index = T.match_buffer(E, (batch, 1), index_dtype)
+        indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
+        renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
+        usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
+        sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype)
+        output_index = T.match_buffer(F, (out_batch, 1), index_dtype)

-        for ax0, ax1 in T.grid(batch, vocab_size):
+        for ax0, ax1 in T.grid(out_batch, vocab_size):
             with T.block("T_get_index_from_sorted"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 T.writes(output_index[v_ax0, 0])
                 if (
-                    usample[v_ax0, T.int64(0)] < cumsum_sorted[v_ax0, v_ax1] / renorm_prob[v_ax0, 0]
+                    usample[v_ax0, T.int64(0)]
+                    < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
+                    / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                     or v_ax1 + 1 == vocab_size
                 ):
                     if v_ax1 == 0:
                         output_index[v_ax0, 0] = indices[v_ax0, 0]
                     elif (
                         usample[v_ax0, T.int64(0)]
-                        >= cumsum_sorted[v_ax0, v_ax1 - 1] / renorm_prob[v_ax0, 0]
+                        >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
+                        / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                     ):
                         output_index[v_ax0, 0] = indices[v_ax0, v_ax1]

@@ -2235,16 +2311,16 @@ def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E
         "get_renorm_prob",
         args=[cumsum_sorted, top_p, top_k],
         out=Tensor.placeholder(
-            [batch, 1],
+            [prob_batch, 1],
             prob_dtype,
         ),
     )

     out_index_in_sorted = tensor_ir_op(
         _get_index_from_sorted,
         "get_index_from_sorted",
-        args=[cumsum_sorted, renorm_prob, uniform_sample, sorted_index],
-        out=Tensor.placeholder([batch, 1], index_dtype),
+        args=[cumsum_sorted, sorted_index, renorm_prob, uniform_sample, sample_indices],
+        out=Tensor.placeholder([out_batch, 1], index_dtype),
     )
     return out_index_in_sorted

@@ -2293,7 +2369,7 @@ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.
         top_k = T.match_buffer(D, (batch, 1), top_k_dtype)
         cutoff = T.match_buffer(E, (batch, 1), prob_dtype)
         for ax0, ax1 in T.grid(batch, vocab_size):
-            with T.block("T_get_renorm_prob"):
+            with T.block("T_get_renorm_cutoff"):
                 v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                 if _cumsum_mask(cumsum_sorted, top_p, top_k, v_ax0, 0) == 0:
                     cutoff[v_ax0, 0] = sorted_prob[v_ax0, 0]
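
For readers skimming the diff, here is a minimal NumPy sketch of the semantics the new sample_indices argument gives multinomial_from_uniform: each output row gathers the cumulative distribution named by sample_indices and does an inverse-CDF lookup with its own uniform draw. This is a reference-only approximation of the TIR kernel, not code from the commit; the helper name multinomial_from_uniform_ref is hypothetical.

    import numpy as np

    def multinomial_from_uniform_ref(prob, uniform_sample, sample_indices=None):
        """Reference-only sketch of the gather-then-inverse-CDF sampling semantics."""
        prob = np.asarray(prob, dtype=np.float64)                # (batch, vocab)
        usample = np.asarray(uniform_sample, dtype=np.float64)   # (n, 1)
        n = usample.shape[0]
        if sample_indices is None:
            # Default behaviour: the i-th draw samples from the i-th distribution.
            sample_indices = np.arange(n).reshape(n, 1)
        sample_indices = np.asarray(sample_indices, dtype=np.int64)  # (n, 1)
        cumsum = np.cumsum(prob, axis=1)                          # inclusive cumsum per row
        out = np.empty((n, 1), dtype=np.int64)
        for i in range(n):
            row = sample_indices[i, 0]
            # First vocab index whose cumulative probability exceeds the uniform draw,
            # clamped to the last index much like the kernel's vocab_size edge case.
            j = int(np.searchsorted(cumsum[row], usample[i, 0], side="right"))
            out[i, 0] = min(j, prob.shape[1] - 1)
        return out

    prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
    usample = [[0.4], [0.9]]
    sample_indices = [[0], [1]]
    print(multinomial_from_uniform_ref(prob, usample, sample_indices))  # [[1] [2]]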
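And a usage-style sketch of the docstring's "sample 2, 3, and 4 tokens from 3 distributions" scenario, reusing the hypothetical helper above; the probability values are made up for illustration.

    import numpy as np

    # Three distinct distributions; draw 2, 3, and 4 tokens from rows 0, 1, and 2,
    # so sample_indices repeats each row index once per requested token.
    prob = [[0.2, 0.3, 0.5],
            [0.3, 0.4, 0.3],
            [0.1, 0.1, 0.8]]
    sample_indices = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2]).reshape(-1, 1)  # (9, 1)
    rng = np.random.default_rng(0)
    uniform_sample = rng.random((9, 1))                                    # one uniform draw per token

    tokens = multinomial_from_uniform_ref(prob, uniform_sample, sample_indices)
    print(tokens.shape)  # (9, 1): rows 0-1 from prob[0], rows 2-4 from prob[1], rows 5-8 from prob[2]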

0 commit comments
