gemma : fix bfloat16 -> float16 conversion issue (ggerganov#5810)
kunal-vaishnavi authored and hazelnutcloud committed Mar 10, 2024
1 parent d134b79 commit e54be67
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions convert-hf-to-gguf.py
@@ -1811,16 +1811,15 @@ def write_tensors(self):
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
 
         for name, data_torch in self.get_tensors():
-            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
-            if name.endswith("norm.weight"):
-                data_torch = data_torch + 1
-
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
             if data_torch.dtype not in (torch.float16, torch.float32):
                 data_torch = data_torch.to(torch.float32)
 
+            # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+            if name.endswith("norm.weight"):
+                data_torch = data_torch + 1
             data = data_torch.squeeze().numpy()
 
             # map tensor names
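Editorial note, not part of the commit: the diff moves the Gemma norm-weight "+ 1" adjustment so that it runs after any bfloat16 tensor has been upcast to float32, rather than before. A plausible reading, consistent with the commit title, is that adding 1 while the tensor is still bfloat16 rounds the sum to bfloat16's coarse spacing around 1.0, and that rounding error then ends up in the converted GGUF file. The sketch below is a hypothetical illustration of the two orderings with a made-up weight value; it is not code from the repository.

import torch

# Hypothetical illustration (not from convert-hf-to-gguf.py): compare adding 1
# to a norm weight before vs. after the upcast to float32, starting from a
# bfloat16 value. 0.0791015625 is an arbitrary weight that bfloat16 represents
# exactly.
w = torch.tensor([0.0791015625], dtype=torch.bfloat16)

# Old ordering: the addition happens in bfloat16, so the sum is rounded to
# bfloat16's spacing near 1.0 (2**-7 = 0.0078125) before the upcast.
old_order = (w + 1).to(torch.float32)

# New ordering (this commit): upcast to float32 first, then add 1.
new_order = w.to(torch.float32) + 1

print(old_order.item())  # 1.078125 (detail below 2**-7 is lost)
print(new_order.item())  # 1.0791015625 (original bfloat16 weight preserved)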
