Skip to content

Commit

Permalink
Enable sdpa backends for server export in export.py (#1478)
Browse files Browse the repository at this point in the history
* Enable sdpa backends for server export in export.py

FLASH worked for dso models, so try this with methodical tests

* Update more-tests.yml

Add tests for sdpa backends with server export (x86 cpu & cuda)

* Update export.py

Fix typo.

* Update more-tests.yml

Display test information to simplify debug

* Update more-tests.yml

fix typo

* Update more-tests.yml

Need to generate cmake-out

* Update more-tests.yml

Update MODEL_DIR definition so that aoti_run can find the tokenizer.model

* Update more-tests.yml

* Update more-tests.yml

---------

Co-authored-by: Jack-Khuu <[email protected]>
  • Loading branch information
mikekgfb and Jack-Khuu authored Feb 5, 2025
1 parent d607ecc commit 083fdaf
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 18 deletions.
69 changes: 67 additions & 2 deletions .github/workflows/more-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
gpu-arch-version: "12.4"
timeout: 60
script: |
set -xeou pipefail
echo "::group::Print machine info"
uname -a
echo "::endgroup::"
Expand All @@ -39,9 +40,10 @@ jobs:
echo "::endgroup::"
echo "::group::Run inference"
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_DIR=checkpoints/stories15M/
export MODEL_PATH=${MODEL_DIR}/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp
for DTYPE in bfloat16 float16 float32; do
###################################################################
Expand Down Expand Up @@ -83,3 +85,66 @@ jobs:
echo "tests complete"
echo "******************************************"
echo "::endgroup::"
test-sdpa-backends-export:
permissions:
id-token: write
contents: read
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
runner: linux.g5.4xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.4"
timeout: 60
script: |
set -xeou pipefail
echo "::group::Print machine info"
uname -a
echo "::endgroup::"
echo "::group::Download checkpoints"
# Install requirements
./install/install_requirements.sh cuda
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
echo "::endgroup::"
echo "::group::Download checkpoints"
mkdir -p checkpoints/stories15M
pushd checkpoints/stories15M
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
popd
echo "::endgroup::"
echo "::group::Run inference"
export MODEL_DIR=checkpoints/stories15M/
export MODEL_PATH=${MODEL_DIR}/stories15M.pt
export MODEL_NAME=stories15M
./torchchat/utils/scripts/build_native.sh aoti
for DEVICE in cpu cuda; do
# depending on how the parameter passing works, may only be able to do bfloat16 for aoti_run, similar to runner-cuda-dtype.yml
                  # (although the runner environment should not have an opinion on what we use in the artifact, and we might suitably abstract that)
for DTYPE in bfloat16 float16 float32; do
for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
echo "***************************************************************"
echo "*** $DEVICE $DTYPE $SDPA"
###################################################################
# Export DSO and run with Python
python torchchat.py export --output-dso dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
python torchchat.py generate --dso-path dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --prompt "Once upon a time"
###################################################################
# Export AOTI and run with aoti_run
python torchchat.py export --output-aoti /tmp/model.pt2 --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "Once upon a time"
###################################################################
done
done
done
echo "tests complete"
echo "******************************************"
echo "::endgroup::"
33 changes: 17 additions & 16 deletions torchchat/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,13 +490,14 @@ def main(args):
print(
"WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
)
export_for_server(
model_to_dso,
builder_args.device,
output_dso_path,
builder_args.dynamic_shapes,
package=False,
)
with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
export_for_server(
model_to_dso,
builder_args.device,
output_dso_path,
builder_args.dynamic_shapes,
package=False,
)

if output_aoti_package_path:
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
Expand All @@ -512,14 +513,15 @@ def main(args):
print(
"Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
)
export_for_server(
model_to_aoti_package,
builder_args.device,
output_aoti_package_path,
builder_args.dynamic_shapes,
package=True,
metadata=metadata,
)
with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
export_for_server(
model_to_aoti_package,
builder_args.device,
output_aoti_package_path,
builder_args.dynamic_shapes,
package=True,
metadata=metadata,
)

if output_snapshot_path:
output_snapshot_path = str(os.path.abspath(output_snapshot_path))
Expand All @@ -529,4 +531,3 @@ def main(args):
builder_args.device,
output_snapshot_path,
)

0 comments on commit 083fdaf

Please sign in to comment.