diff --git a/atom/examples/simple_inference.py b/atom/examples/simple_inference.py
index 477d9ff42..fda0de7c6 100644
--- a/atom/examples/simple_inference.py
+++ b/atom/examples/simple_inference.py
@@ -45,7 +45,7 @@ def main():
         "1+2+3=?",
         "如何在一个月内增肌10公斤",
         "+".join([f"{i}-{i+1}" for i in range(1000)]) + "=? 最后结果是什么",
-        "+".join([f"{i}+{i+1}" for i in range(3000)]) + "=? 最后结果是什么",
+        "+".join([f"{i}+{i+1}" for i in range(1500)]) + "=? 最后结果是什么",
     ]
     args = parser.parse_args()
     # Generate power of 2 sizes for CUDA graph: [1, 2, 4, 8, ...]
diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py
index e1929b285..987b05e18 100644
--- a/atom/model_loader/loader.py
+++ b/atom/model_loader/loader.py
@@ -12,11 +12,29 @@
 from typing import Any
 
 import safetensors
+import safetensors.torch
 import torch
 from torch import nn
 from tqdm import tqdm
 from transformers import AutoConfig
 
+# safetensors<=0.7.0 ships a Python `_TYPES` dict missing the `F8_E8M0`
+# (MX scale) entry, even though both torch and the safetensors-rust binary
+# support it. The mmap'd `safe_open` path goes through Rust and works, but
+# the `safetensors.torch.load(bytes)` path used when `ATOM_DISABLE_MMAP=true`
+# raises `KeyError: 'F8_E8M0'` on DeepSeek-V4-Pro shards. Register the
+# missing dtype string so both paths behave identically.
+# `_TYPES` is a private safetensors table, so look it up defensively: if a
+# release renames or removes it, skip the shim instead of raising
+# AttributeError while this module is being imported.
+_SAFETENSORS_TYPES = getattr(safetensors.torch, "_TYPES", None)
+if (
+    _SAFETENSORS_TYPES is not None
+    and "F8_E8M0" not in _SAFETENSORS_TYPES
+    and hasattr(torch, "float8_e8m0fnu")
+):
+    _SAFETENSORS_TYPES["F8_E8M0"] = torch.float8_e8m0fnu
+
 from atom.utils import envs
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME