Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator, INCAcceleratorType
cur_accelerator = auto_detect_accelerator()

from neural_compressor.torch.utils import environ
from neural_compressor.common.utils import logger

def descale_fcn(x, scale):
    """Return ``x`` multiplied by ``scale`` using ``torch.mul``."""
    return torch.mul(x, scale)


def scale_fcn(x, scale):
    """Return ``x`` divided by ``scale`` using ``torch.div``."""
    return torch.div(x, scale)


def cast_fcn(x, dtype):
    """Return ``x`` converted to ``dtype`` via ``x.to(dtype=dtype)``."""
    return x.to(dtype=dtype)
Expand Down Expand Up @@ -106,6 +109,9 @@ def get_fp8_hw_alligned_scales(dtype, device):
}

def calc_maxabs_scale(xmaxabs, fullscale, backoff=1):
    """Compute a scale as ``xmaxabs / (fullscale * backoff)``.

    When ``environ.INC_FORCE_NAIVE_SCALING`` is truthy, the caller-supplied
    ``backoff`` is ignored and replaced by ``1.0``, and a warning is emitted
    through ``logger.warning_once``.

    Args:
        xmaxabs: Numerator of the scale (presumably a max-abs statistic of
            the tensor being quantized -- confirm against callers).
        fullscale: Multiplied with ``backoff`` to form the divisor.
        backoff: Extra factor applied to ``fullscale``; defaults to ``1``.

    Returns:
        The quotient ``xmaxabs / (fullscale * backoff)``.
    """
    if environ.INC_FORCE_NAIVE_SCALING:
        # Naive scaling overrides whatever backoff the caller requested.
        backoff = 1.0
        logger.warning_once(f"Enabled naive scaling, backoff is set to {backoff}")
    divisor = fullscale * backoff
    return xmaxabs / divisor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from .round_scales_function import *
from ..common import get_device_type_for_scales
from .scales_method import *

from neural_compressor.torch.utils import environ
from neural_compressor.common.utils import logger

class QuantTensorName(Enum):
INPUT = auto()
Expand All @@ -40,6 +41,9 @@ class ScaleValueType(Enum):

def parse_rounding_method(config, device_for_scales):
round_method = ScaleIdentity()
if environ.INC_FORCE_NAIVE_SCALING:
logger.warning_once("Enabled naive scaling")
return round_method
if "single" in config and "hw" in config:
round_method = ScaleHwAlignedFixed(device_for_scales)
elif "unit" in config:
Expand Down
4 changes: 4 additions & 0 deletions neural_compressor/torch/utils/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
# WORLD_SIZE environment variable parsed as an int; -1 when it is not set.
world_size = int(os.environ.get("WORLD_SIZE", "-1"))


# Boolean flag read once at import time from the INC_FORCE_NAIVE_SCALING
# environment variable: "1" or "true" (any letter case) enables it; any other
# value, or the variable being unset, leaves it disabled.
INC_FORCE_NAIVE_SCALING = os.environ.get("INC_FORCE_NAIVE_SCALING", "0").lower() in ("1", "true")


################ Check imported sys.module first to decide behavior #################
def is_ipex_imported() -> bool:
"""Check whether intel_extension_for_pytorch is imported."""
Expand Down