
Commit 1a09df5

Update Megatron GPT eval script for non-FP8 path
Signed-off-by: Yen-Shi Wang <[email protected]>
Yen-Shi Wang committed May 27, 2023
1 parent c526884 commit 1a09df5
Showing 1 changed file with 6 additions and 4 deletions.
examples/nlp/language_modeling/megatron_gpt_eval.py: 10 changes (6 additions & 4 deletions)
@@ -263,7 +263,8 @@ def main(cfg) -> None:
         "compute_logprob": cfg.inference.compute_logprob,
     }
 
-    if model.cfg.fp8 == True:
+    fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True)
+    if fp8_enabled:
         nb_paddings = 0
         while len(cfg.prompts) % 8 != 0:
             cfg.prompts.append("")
@@ -274,21 +275,22 @@ def main(cfg) -> None:
         inputs=OmegaConf.to_container(cfg.prompts), length_params=length_params, sampling_params=sampling_params
     )
 
-    if model.cfg.fp8 == True:
+    if fp8_enabled:
         response = remove_padded_prompts(response, nb_paddings)
     print("***************************")
     print(response)
     print("***************************")
 
     # Second method of running text generation, call trainer.predict
-    bs = 8 if model.cfg.fp8 == True else 2
+    bs = 8 if fp8_enabled else 2
     ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
     request_dl = DataLoader(dataset=ds, batch_size=bs)
     config = OmegaConf.to_container(cfg.inference)
     print(f"Config for trainer.predict = {config}")
     model.set_inference_config(config)
     response = trainer.predict(model, request_dl)
 
-    if model.cfg.fp8 == True:
+    if fp8_enabled:
         response[-1] = remove_padded_prompts(response[-1], nb_paddings)
     print("***************************")
     print(response)
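For context, the change computes an fp8_enabled flag up front instead of reading model.cfg.fp8 directly at each check, so the non-FP8 path (a model config with no fp8 field) no longer depends on that attribute being present. Below is a minimal, self-contained sketch of the same guard-and-pad pattern; pad_prompts_for_fp8 and the SimpleNamespace configs are illustrative stand-ins, not part of the NeMo script.

    # Minimal sketch of the guard introduced in this commit, using a plain
    # namespace instead of NeMo's actual model config (illustration only).
    from types import SimpleNamespace

    def pad_prompts_for_fp8(model_cfg, prompts):
        """Pad the prompt list to a multiple of 8 only when FP8 is enabled."""
        # hasattr() keeps the check safe when the config has no `fp8` field at all.
        fp8_enabled = hasattr(model_cfg, "fp8") and (model_cfg.fp8 == True)
        nb_paddings = 0
        if fp8_enabled:
            while len(prompts) % 8 != 0:
                prompts.append("")
                nb_paddings += 1
        return fp8_enabled, nb_paddings

    # Config without an `fp8` field: the guarded check simply evaluates to False.
    print(pad_prompts_for_fp8(SimpleNamespace(precision="bf16"), ["a", "b", "c"]))  # (False, 0)

    # Config with fp8=True: prompts are padded up to the next multiple of 8.
    print(pad_prompts_for_fp8(SimpleNamespace(fp8=True), ["a", "b", "c"]))          # (True, 5)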
