@@ -2057,7 +2057,12 @@ def cumsum(
20572057 return wrap_nested (_op .cumsum (data ._expr , axis , dtype , exclusive ), name )
20582058
20592059
def multinomial_from_uniform(
    prob: Tensor,
    uniform_sample: Tensor,
    sample_indices: Optional[Tensor] = None,
    dtype: str = "int64",
):
    """Returns a tensor where each row contains the index sampled from the multinomial
    probability distribution selected for that row out of tensor prob.

    Sampling is done by inverse-CDF lookup: the row-wise inclusive cumulative sum of
    ``prob`` is computed, and for every sample the first vocabulary position whose
    cumulative probability exceeds the uniform sample is returned.

    Parameters
    ----------
    prob : Tensor
        A 2-D tensor with the shape (batch, vocab_size), one probability
        distribution per row.
        The sum of values in each row is 1, forming a valid distribution.

    uniform_sample : Tensor
        The uniformly sampled 2-D tensor with the shape (n, 1).
        Values range from 0 to 1, indicating probabilities sampled uniformly.

    sample_indices : Optional[Tensor]
        The 2-D integer tensor with the shape (n, 1), which indicates the specific
        probability distribution to sample from. The value of sample_indices[i]
        determines that the ith token should be sampled from the sample_indices[i]th
        probability distribution. For instance, if there are 3 distinct probability
        distributions and the requirement is to sample 2, 3, and 4 tokens from each,
        then sample_indices would be [[0], [0], [1], [1], [1], [2], [2], [2], [2]].
        When omitted, n must equal batch and sample i is drawn from row i of prob.

    dtype : str
        The data type of output tensor. Pass it by keyword: ``sample_indices``
        occupies the third positional slot.

    Returns
    -------
    result : Tensor
        The computed tensor with shape (n, 1).

    Examples
    --------
    .. code-block:: python

        prob = [[0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]
        usample = [[0.4], [0.9]]
        sample_indices = [[0], [1]]

        multinomial_from_uniform(prob, usample)
        -> [[1], [2]]
        multinomial_from_uniform(prob, usample, sample_indices)
        -> [[1], [2]]
    """
    prob_dtype = prob.dtype
    sample_dtype = uniform_sample.dtype
    # The output has one row per uniform sample, not one per distribution.
    out_batch = uniform_sample.shape[0]

    if sample_indices is not None:
        assert (
            sample_indices.shape == uniform_sample.shape
        ), "The shape of sample_indices must match the shape of uniform_sample."
    else:
        assert (
            prob.shape[0] == uniform_sample.shape[0]
        ), "Number of samples must match the number of probability distributions."
        # Identity mapping: sample i reads distribution i. The dtype is pinned to
        # int64 so the index buffer does not depend on NumPy's platform default.
        # NOTE(review): np.arange presumably needs out_batch to be a concrete int
        # here -- confirm behavior for symbolic batch sizes.
        sample_indices = Tensor.from_const(
            np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
        )

    sample_indices_dtype = sample_indices.dtype

    @T.prim_func(private=True)
    def _get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
        # A: row-wise cumulative probabilities, (batch, vocab_size)
        # B: uniform samples, (out_batch, 1)
        # C: distribution row to use for each sample, (out_batch, 1)
        # D: sampled vocabulary index for each sample, (out_batch, 1)
        batch, vocab_size = T.int64(), T.int64()
        prob = T.match_buffer(A, (batch, vocab_size), prob_dtype)
        out_batch = T.int64()
        usample = T.match_buffer(B, (out_batch, 1), sample_dtype)
        sample_indices = T.match_buffer(C, (out_batch, 1), sample_indices_dtype)
        output_index = T.match_buffer(D, (out_batch, 1), dtype)

        for ax0, ax1 in T.grid(out_batch, vocab_size):
            with T.block("T_get_sample_index"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.writes(output_index[v_ax0, 0])
                # Candidate column: the sample falls below the cumulative value
                # here, or this is the last column (fallback when the cumulative
                # sum never exceeds the sample).
                if (
                    usample[v_ax0, T.int64(0)] < prob[sample_indices[v_ax0, T.int64(0)], v_ax1]
                    or v_ax1 + 1 == vocab_size
                ):
                    if v_ax1 == 0:
                        output_index[v_ax0, 0] = 0
                    # Only the first candidate column writes: the previous
                    # cumulative value must not already exceed the sample.
                    elif (
                        usample[v_ax0, T.int64(0)]
                        >= prob[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
                    ):
                        output_index[v_ax0, 0] = v_ax1

    # Inclusive cumulative sum along the vocabulary axis.
    cumsum_prob = cumsum(prob, axis=1, exclusive=False)

    return tensor_ir_op(
        _get_sample_index,
        "get_sample_index",
        args=[cumsum_prob, uniform_sample, sample_indices],
        out=Tensor.placeholder([out_batch, 1], dtype),
    )
21252165
21262166
21272167def sample_top_p_top_k_from_sorted_prob (
2128- sorted_prob : Tensor , sorted_index : Tensor , top_p : Tensor , top_k : Tensor , uniform_sample : Tensor
2168+ sorted_prob : Tensor ,
2169+ sorted_index : Tensor ,
2170+ top_p : Tensor ,
2171+ top_k : Tensor ,
2172+ uniform_sample : Tensor ,
2173+ sample_indices : Optional [Tensor ] = None ,
21292174):
21302175 """Samples indices from a sorted probability tensor based on top_p and top_k criteria.
21312176
@@ -2152,12 +2197,20 @@ def sample_top_p_top_k_from_sorted_prob(
21522197 to consider for top-k sampling.
21532198
21542199 uniform_sample : Tensor
2155- Uniformly sampled values with shape (batch, 1) are used to select the output indices.
2200+ Uniformly sampled values with shape (n, 1) are used to select the output indices.
2201+
2202+ sample_indices : Optional[Tensor]
2203+ The 2-D tensor with the shape [n, 1], which indicates the specific
2204+ probability distribution to sample from. The value of sample_indices[i]
2205+ determines that the ith token should be sampled from the sample_indices[i]th
2206+ probability distribution. For instance, if there are 3 distinct probability
2207+ distributions and the requirement is to sample 2, 3, and 4 tokens from each,
2208+ then sample_indices would be [0, 0, 1, 1, 1, 2, 2, 2, 2].
21562209
21572210 Returns
21582211 -------
21592212 result : Tensor
2160- The selected indices with shape (batch , 1).
2213+ The selected indices with shape (n , 1).
21612214
21622215 Examples
21632216 --------
@@ -2172,15 +2225,31 @@ def sample_top_p_top_k_from_sorted_prob(
21722225 top_p = [[0.6],[0.9]]
21732226 top_k = [[3],[2]]
21742227 uniform_sample = [[0.5], [0.6]]
2228+ sample_indices = [[0], [1]]
21752229
21762230 sample_top_p_top_k_from_sorted_prob(
2177- sorted_prob, sorted_index,top_p, top_k, uniform_sample)
2231+ sorted_prob, sorted_index,top_p, top_k, uniform_sample, sample_indices )
21782232 -> [2, 0]
21792233
21802234 """
prob_dtype = sorted_prob.dtype
index_dtype = sorted_index.dtype
# Number of distributions and number of requested samples may differ once
# sample_indices maps several samples onto the same distribution.
prob_batch = sorted_prob.shape[0]
out_batch = uniform_sample.shape[0]

if sample_indices is not None:
    assert (
        sample_indices.shape == uniform_sample.shape
    ), "The shape of sample_indices must match the shape of uniform_sample."
else:
    assert (
        sorted_prob.shape[0] == uniform_sample.shape[0]
    ), "Number of samples must match the number of probability distributions."
    # Identity mapping: sample i reads distribution i.
    sample_indices = Tensor.from_const(
        np.arange(out_batch).reshape(out_batch, 1).astype(np.int64)
    )
sample_indices_dtype = sample_indices.dtype
21842253
def _cumsum_mask(cumsum_sorted, top_p, top_k, i, j):
    """Build the condition that position ``j`` of row ``i`` is inside both cut-offs.

    The condition holds when the cumulative probability at ``(i, j)`` is still
    below ``top_p[i, 0]`` and ``j + 1`` is still below ``top_k[i, 0]``.
    """
    below_top_p = cumsum_sorted[i, j] < top_p[i, 0]
    below_top_k = j + 1 < top_k[i, 0]
    return _tir.all(below_top_p, below_top_k)
@@ -2204,27 +2273,34 @@ def _get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
22042273 renorm_prob [v_ax0 , 0 ] = cumsum_sorted [v_ax0 , v_ax1 + 1 ]
22052274
@T.prim_func(private=True)
def _get_index_from_sorted(
    A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.handle, F: T.handle
):
    # A: cumulative sorted probabilities, (batch, vocab_size)
    # B: original vocabulary indices of the sorted entries, (batch, vocab_size)
    # C: renormalization factor per distribution, (batch, 1)
    # D: uniform samples, (out_batch, 1)
    # E: distribution row to use for each sample, (out_batch, 1)
    # F: selected vocabulary index for each sample, (out_batch, 1)
    batch, vocab_size = T.int64(), T.int64()
    out_batch = T.int64()
    cumsum_sorted = T.match_buffer(A, (batch, vocab_size), prob_dtype)
    indices = T.match_buffer(B, (batch, vocab_size), index_dtype)
    renorm_prob = T.match_buffer(C, (batch, 1), prob_dtype)
    # The samples are matched with prob_dtype, so uniform_sample has to share
    # the dtype of the probability tensor.
    usample = T.match_buffer(D, (out_batch, 1), prob_dtype)
    sample_indices = T.match_buffer(E, (out_batch, 1), sample_indices_dtype)
    output_index = T.match_buffer(F, (out_batch, 1), index_dtype)

    for ax0, ax1 in T.grid(out_batch, vocab_size):
        with T.block("T_get_index_from_sorted"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            T.writes(output_index[v_ax0, 0])
            # Candidate column: the sample falls below the renormalized
            # cumulative value here, or this is the last column (fallback).
            if (
                usample[v_ax0, T.int64(0)]
                < cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1]
                / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                or v_ax1 + 1 == vocab_size
            ):
                # Every per-distribution buffer (cumsum_sorted, renorm_prob,
                # indices) is read at row sample_indices[v_ax0], never at v_ax0:
                # v_ax0 ranges over out_batch, those buffers have batch rows.
                if v_ax1 == 0:
                    output_index[v_ax0, 0] = indices[sample_indices[v_ax0, T.int64(0)], 0]
                # Only the first candidate column writes: the previous
                # renormalized cumulative value must not already exceed the sample.
                elif (
                    usample[v_ax0, T.int64(0)]
                    >= cumsum_sorted[sample_indices[v_ax0, T.int64(0)], v_ax1 - 1]
                    / renorm_prob[sample_indices[v_ax0, T.int64(0)], 0]
                ):
                    output_index[v_ax0, 0] = indices[
                        sample_indices[v_ax0, T.int64(0)], v_ax1
                    ]
22302306
@@ -2235,16 +2311,16 @@ def _get_index_from_sorted(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E
22352311 "get_renorm_prob" ,
22362312 args = [cumsum_sorted , top_p , top_k ],
22372313 out = Tensor .placeholder (
2238- [batch , 1 ],
2314+ [prob_batch , 1 ],
22392315 prob_dtype ,
22402316 ),
22412317 )
22422318
22432319 out_index_in_sorted = tensor_ir_op (
22442320 _get_index_from_sorted ,
22452321 "get_index_from_sorted" ,
2246- args = [cumsum_sorted , renorm_prob , uniform_sample , sorted_index ],
2247- out = Tensor .placeholder ([batch , 1 ], index_dtype ),
2322+ args = [cumsum_sorted , sorted_index , renorm_prob , uniform_sample , sample_indices ],
2323+ out = Tensor .placeholder ([out_batch , 1 ], index_dtype ),
22482324 )
22492325 return out_index_in_sorted
22502326
@@ -2293,7 +2369,7 @@ def _get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.
22932369 top_k = T .match_buffer (D , (batch , 1 ), top_k_dtype )
22942370 cutoff = T .match_buffer (E , (batch , 1 ), prob_dtype )
22952371 for ax0 , ax1 in T .grid (batch , vocab_size ):
2296- with T .block ("T_get_renorm_prob " ):
2372+ with T .block ("T_get_renorm_cutoff " ):
22972373 v_ax0 , v_ax1 = T .axis .remap ("SS" , [ax0 , ax1 ])
22982374 if _cumsum_mask (cumsum_sorted , top_p , top_k , v_ax0 , 0 ) == 0 :
22992375 cutoff [v_ax0 , 0 ] = sorted_prob [v_ax0 , 0 ]
0 commit comments