chat.py (4 changes: 2 additions & 2 deletions)
@@ -5,7 +5,7 @@


 def chat():
-    device = 'cuda'
+    device = 'mps'
     model = AutoModel.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
     tokenizer = AutoTokenizer.from_pretrained('GSAI-ML/LLaDA-8B-Instruct', trust_remote_code=True)

@@ -29,7 +29,7 @@ def chat():
         else:
             prompt = torch.cat([prompt, input_ids[:, 1:]], dim=1)

-        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
+        out = generate(model, prompt, steps=steps, gen_length=gen_length, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence', is_mps=True)

         answer = tokenizer.batch_decode(out[:, prompt.shape[1]:], skip_special_tokens=True)[0]
         print(f"Bot's reply: {answer}")
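Note: hardcoding device = 'mps' makes the script Apple-only, just as 'cuda' made it NVIDIA-only. A minimal sketch (not part of this diff) that picks the backend at runtime instead:

import torch

# Sketch, not part of this PR: select the best available backend instead of
# hardcoding 'cuda' or 'mps'.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'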
generate.py (19 changes: 13 additions & 6 deletions)
@@ -5,14 +5,18 @@
 from transformers import AutoTokenizer, AutoModel


-def add_gumbel_noise(logits, temperature):
+def add_gumbel_noise(logits, temperature, is_mps):
     '''
     The Gumbel max is a method for sampling categorical distributions.
     According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
     Thus, we use float64.
     '''
-    logits = logits.to(torch.float64)
-    noise = torch.rand_like(logits, dtype=torch.float64)
+    if not is_mps:
+        logits = logits.to(torch.float64)
+        noise = torch.rand_like(logits, dtype=torch.float64)
+    else:
+        logits = logits.to(torch.float32)
+        noise = torch.rand_like(logits, dtype=torch.float32)
     gumbel_noise = (- torch.log(noise)) ** temperature
     return logits.exp() / gumbel_noise
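For context, add_gumbel_noise implements the Gumbel-max trick: taking the argmax of exp(logits) / (-log u) ** temperature with u ~ Uniform(0, 1) is equivalent to sampling from softmax(logits / temperature), and temperature=0 reduces to plain greedy argmax. The only MPS-specific decision here is the dtype, so a sketch that derives it from the tensor's device (the _sampling_dtype helper is hypothetical, not part of this PR) would avoid threading an is_mps flag through every caller:

import torch

def _sampling_dtype(device: torch.device) -> torch.dtype:
    # Hypothetical helper: MPS kernels do not support float64, so fall back
    # to float32 there; keep float64 elsewhere, as the original code does.
    return torch.float32 if device.type == 'mps' else torch.float64

def add_gumbel_noise(logits, temperature):
    dtype = _sampling_dtype(logits.device)
    logits = logits.to(dtype)
    noise = torch.rand_like(logits, dtype=dtype)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise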

@@ -40,7 +44,7 @@ def get_num_transfer_tokens(mask_index, steps):

 @ torch.no_grad()
 def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
-             cfg_scale=0., remasking='low_confidence', mask_id=126336):
+             cfg_scale=0., remasking='low_confidence', mask_id=126336, is_mps=False):
     '''
     Args:
         model: Mask predictor.
@@ -79,11 +83,14 @@ def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             else:
                 logits = model(x).logits

-            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+            logits_with_noise = add_gumbel_noise(logits, temperature=temperature, is_mps=is_mps)
             x0 = torch.argmax(logits_with_noise, dim=-1) # b, l

             if remasking == 'low_confidence':
-                p = F.softmax(logits.to(torch.float64), dim=-1)
+                if not is_mps:
+                    p = F.softmax(logits.to(torch.float64), dim=-1)
+                else:
+                    p = F.softmax(logits.to(torch.float32), dim=-1)
                 x0_p = torch.squeeze(
                     torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
             elif remasking == 'random':
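The low_confidence branch scores each sampled token by the probability the model assigned to it, so the least confident positions can be remasked and re-predicted on a later step. A standalone sketch of that scoring, with illustrative shapes (the dummy tensors are not from the model):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 16)      # (batch, length, vocab), dummy values
x0 = torch.argmax(logits, dim=-1)   # sampled token ids, (batch, length)

# Same dtype rule as the diff: float64 is unsupported on MPS.
dtype = torch.float32 if logits.device.type == 'mps' else torch.float64
p = F.softmax(logits.to(dtype), dim=-1)

# Probability assigned to each chosen token, (batch, length); the lowest
# values mark the positions most likely to be remasked.
x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)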