Commit ecc3c7d
Update PyTorch pin and enable MPS qops (#725)
* Update PyTorch pin and enable linter:int8 and linter:int4 acceleration on MPS
* Update run-readme-pr.yml
* Update install_requirements.sh
1 parent 18937ce commit ecc3c7d

File tree

3 files changed: +17 −9 lines changed

Diff for: .github/workflows/run-readme-pr.yml (+4 −4)

```diff
@@ -20,10 +20,10 @@ jobs:
           uname -a
           echo "::endgroup::"
 
-          # echo "::group::Install newer objcopy that supports --set-section-alignment"
-          # yum install -y devtoolset-10-binutils
-          # export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
-          # echo "::endgroup::"
+          echo "::group::Install newer objcopy that supports --set-section-alignment"
+          yum install -y devtoolset-10-binutils
+          export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+          echo "::endgroup::"
 
           echo "::group::Create script to run README"
           python3 scripts/updown.py --file README.md --replace 'llama3:stories15M,-l 3:-l 2,meta-llama/Meta-Llama-3-8B-Instruct:stories15M' --suppress huggingface-cli,HF_TOKEN > ./run-readme.sh
```
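The `--replace` flag passed to `scripts/updown.py` above takes a comma-separated list of `old:new` substitution pairs. As a hedged sketch of how such a spec can be parsed (a hypothetical re-implementation; the real `scripts/updown.py` may parse it differently):

```python
def parse_replace_spec(spec: str) -> dict:
    """Split a comma-separated list of old:new pairs into a mapping.

    Hypothetical helper for illustration only; only the first ':' in each
    pair separates old from new, so values may themselves contain ':'.
    """
    pairs = {}
    for item in spec.split(","):
        old, _, new = item.partition(":")
        pairs[old] = new
    return pairs

print(parse_replace_spec("llama3:stories15M,-l 3:-l 2"))
# → {'llama3': 'stories15M', '-l 3': '-l 2'}
```

Note that `-l 3:-l 2` shows the pairs are split on the first colon only, which is why flag-with-argument substitutions work.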

Diff for: install_requirements.sh (+3 −1)

```diff
@@ -39,14 +39,16 @@ $PIP_EXECUTABLE install -r requirements.txt --extra-index-url https://download.p
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION=dev20240422
+NIGHTLY_VERSION=dev20240507
 
 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
 # with cuda for faster execution on cuda GPUs.
 if [[ -x "$(command -v nvidia-smi)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
+  # Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
+  $PIP_EXECUTABLE uninstall -y triton
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi
```
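The device-detection logic in this script can be exercised on its own: the presence of an executable `nvidia-smi` on `PATH` selects the CUDA nightly wheel index, otherwise the CPU index is used. A minimal standalone sketch (URLs copied from the diff):

```shell
#!/usr/bin/env bash
# Pick a PyTorch nightly wheel index based on whether nvidia-smi is
# available, mirroring the branch in install_requirements.sh.
if [[ -x "$(command -v nvidia-smi)" ]]; then
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu121"
else
  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
fi
echo "$TORCH_NIGHTLY_URL"
```

`command -v` is POSIX and returns the resolved path (or nothing), so the `-x` test succeeds only when the binary actually exists and is executable.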

Diff for: qops.py (+10 −4)

```diff
@@ -15,7 +15,7 @@ def linear_int8_aoti(input, weight, scales):
     scales = scales.view(-1)
     if (
         torch.compiler.is_compiling()
-        or input.device.type != "cpu"
+        or input.device.type not in ["cpu", "mps"]
         or not hasattr(torch.ops.aten, "_weight_int8pack_mm")
     ):
         lin = F.linear(input, weight.to(dtype=input.dtype))
@@ -395,9 +395,15 @@ def _prepare_weight_and_scales_and_zeros(
         weight_int32, scales_and_zeros = group_quantize_tensor(
             weight_bf16, n_bit=4, groupsize=groupsize
         )
-        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            weight_int32, inner_k_tiles
-        )
+        if weight_bf16.device.type == "mps":
+            # There is still no MPS-accelerated conversion op
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                weight_int32.cpu(), inner_k_tiles
+            ).to("mps")
+        else:
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                weight_int32, inner_k_tiles
+            )
         return weight_int4pack, scales_and_zeros
 
     @classmethod
```
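The first hunk widens the fast-path device allow-list in `linear_int8_aoti` from CPU-only to CPU and MPS: the fused `_weight_int8pack_mm` path is skipped only when compiling, on an unsupported device, or when the op is missing. A minimal pure-Python sketch of that condition (the function name is hypothetical, introduced here just to make the branch testable in isolation):

```python
def use_fallback(device_type: str, compiling: bool, has_int8pack_mm: bool) -> bool:
    """Mirror the post-change condition in linear_int8_aoti: return True
    when the plain F.linear fallback should be used instead of the fused
    int8 matmul. Hypothetical stand-in for illustration, not qops.py API."""
    return (
        compiling
        or device_type not in ["cpu", "mps"]   # MPS newly allowed by this commit
        or not has_int8pack_mm
    )

assert use_fallback("mps", False, True) is False   # fast path now taken on MPS
assert use_fallback("cuda", False, True) is True   # other devices still fall back
```

The second hunk applies a common workaround for the same gap on the int4 side: since `_convert_weight_to_int4pack` has no MPS kernel, the weight is moved to CPU for the one-time pack conversion and the packed result is moved back to MPS, paying the round-trip only at load time rather than per inference call.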

0 commit comments