@@ -20,11 +20,12 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer):
             backend="torchxla_trace_once", fullgraph=True)

     def _generate_one_token(self, tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                            input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p):
+                            input_pos_tensor, output_pos_tensor, cache_kvs,
+                            temperature_tensor, top_p_tensor, with_temp):
         logits, cache_kvs = self.model(input_tokens, input_pos_tensor, output_pos_tensor, cache_kvs)
-        if temperature > 0:
-            probs = torch.softmax(logits / temperature, dim=-1)
-            next_token = sample_top_p(probs, top_p)
+        if with_temp:
+            probs = torch.softmax(logits / temperature_tensor, dim=-1)
+            next_token = sample_top_p(probs, top_p_tensor)
         else:
             next_token = torch.argmax(logits, dim=-1)
         next_token = next_token.reshape(-1)
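The sampling on/off decision now rides along as a plain Python bool (`with_temp`) instead of a float comparison inside the compiled function. With `fullgraph=True`, a data-dependent branch on a traced tensor would not compile, whereas a bool argument is simply specialized, so at most two graphs exist (sampling vs. greedy). A minimal sketch of the pattern, assuming standard `torch.compile` specialization semantics (the toy `pick` function below is hypothetical, not from this PR):

```python
import torch

# Illustrative only: the on/off flag is a Python bool, so it is baked into
# the trace at compile time. Dynamo specializes on bools, yielding at most
# two compiled graphs, while the temperature stays a tensor graph input.
@torch.compile(fullgraph=True)
def pick(logits, temperature_tensor, with_temp):
    if with_temp:  # Python bool: resolved during tracing, not at runtime
        probs = torch.softmax(logits / temperature_tensor, dim=-1)
        return torch.multinomial(probs, num_samples=1)
    return torch.argmax(logits, dim=-1, keepdim=True)

logits = torch.randn(2, 32)
t = torch.tensor(0.8)
pick(logits, t, True)    # graph 1: sampling path
pick(logits, t, False)   # graph 2: greedy path; no further recompiles
```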
@@ -71,6 +72,12 @@ def generate(
         tokens = tokens.to(device)
         input_text_mask = tokens != self.tokenizer.pad_id

+        # Pass tensors instead of floats into self._generate_one_token_fn so
+        # that different values do not trigger compilation of new graphs
+        temperature_tensor = torch.tensor(temperature).to(device)
+        top_p_tensor = torch.tensor(top_p).to(device)
+        with_temp = temperature > 0
+
         cache_kvs = self.model.cache_kvs
         xm.mark_step()

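The comment above is the core trick of this PR. A toy illustration of why it works, assuming Dynamo's default treatment of Python scalars as trace-time constants (`scale` is a hypothetical stand-in, not the PR's code):

```python
import torch

# Python floats are embedded in the graph as constants, so every distinct
# value forces a fresh compilation; a 0-dim tensor is a graph *input*, so
# one graph serves all values.
@torch.compile
def scale(logits, temperature):
    return logits / temperature

logits = torch.randn(4, 8)
scale(logits, 0.7)                    # compile #1, specialized on 0.7
scale(logits, 0.9)                    # new float constant -> compile #2

scale(logits, torch.tensor(0.7))      # compile #3, temperature as an input
scale(logits, torch.tensor(0.9))      # same signature -> graph reused
```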
@@ -92,7 +99,8 @@ def generate(
             tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
                 = self._generate_one_token_fn(
                     tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
                 )
             xm.mark_step()

@@ -103,7 +111,8 @@ def generate(
             tokens, input_tokens, cur_pos_tensor, input_pos_tensor, output_pos_tensor, cache_kvs \
                 = self._generate_one_token_fn(
                     tokens, input_tokens, input_text_mask, cur_pos_tensor,
-                    input_pos_tensor, output_pos_tensor, cache_kvs, temperature, top_p
+                    input_pos_tensor, output_pos_tensor, cache_kvs,
+                    temperature_tensor, top_p_tensor, with_temp
                 )
             xm.mark_step()
         self.model.cache_kvs = cache_kvs
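The `xm.mark_step()` calls bracket each decoding step: on XLA devices, tensor ops are recorded lazily, and `mark_step()` is what cuts the pending graph, compiles it (or hits the compilation cache), and executes it. Because every step now receives identically shaped tensor arguments, each cut produces the same graph and the cache is reused. A minimal sketch of the mechanism, independent of this PR:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.ones(2, 2, device=device)
b = a + 1        # recorded lazily; nothing has executed on the device yet
c = b * 3        # still just extending the pending graph
xm.mark_step()   # cut the graph here, compile (or hit the cache), execute
print(c)         # reading values back would also force execution
```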
@@ -133,7 +142,7 @@ def generate(
 def sample_top_p(probs, p):
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    mask = probs_sum - probs_sort > p
+    mask = (probs_sum - probs_sort) > p
     probs_sort = torch.where(mask, 0.0, probs_sort)
     probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
     next_token = torch.multinomial(probs_sort, num_samples=1)
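For context, `sample_top_p` implements nucleus (top-p) sampling: keep the smallest set of highest-probability tokens whose cumulative mass reaches `p`, renormalize, and sample from it. Subtracting `probs_sort` shifts the cumulative sum one slot right, so the token that first crosses the threshold is still kept. The added parentheses do not change behavior (`-` binds tighter than `>` in Python); they only make the precedence explicit. A tiny check of the mask arithmetic:

```python
import torch

# probs_sort is already sorted descending, as in sample_top_p above.
probs_sort = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
p = 0.7
probs_sum = torch.cumsum(probs_sort, dim=-1)   # [0.50, 0.80, 0.95, 1.00]
mask = (probs_sum - probs_sort) > p            # [F, F, T, T]: shifted cumsum
kept = torch.where(mask, 0.0, probs_sort)      # keeps 0.5 and 0.3
kept /= kept.sum(dim=-1, keepdim=True)         # renormalize before sampling
print(kept)                                    # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])
```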