From 083fdafd5c77b35180cbb5aabce1ff297686f384 Mon Sep 17 00:00:00 2001
From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com>
Date: Wed, 5 Feb 2025 10:20:57 -0800
Subject: [PATCH] Enable sdpa backends for server export in export.py (#1478)

* Enable sdpa backends for server export in export.py

FLASH worked for dso models, so try this with methodical tests

* Update more-tests.yml

Add tests for sdpa backends with server export (x86 cpu & cuda)

* Update export.py

Fix typo.

* Update more-tests.yml

Display test information to simplify debugging

* Update more-tests.yml

fix typo

* Update more-tests.yml

Need to generate cmake-out

* Update more-tests.yml

Update MODEL_DIR definition so that aoti_run can find the tokenizer.model

* Update more-tests.yml

* Update more-tests.yml

---------

Co-authored-by: Jack-Khuu
---
 .github/workflows/more-tests.yml | 69 +++++++++++++++++++++++++++++++-
 torchchat/export.py              | 33 +++++++--------
 2 files changed, 84 insertions(+), 18 deletions(-)

diff --git a/.github/workflows/more-tests.yml b/.github/workflows/more-tests.yml
index f772382d1..dedbcc982 100644
--- a/.github/workflows/more-tests.yml
+++ b/.github/workflows/more-tests.yml
@@ -19,6 +19,7 @@ jobs:
       gpu-arch-version: "12.4"
       timeout: 60
       script: |
+        set -xeou pipefail
         echo "::group::Print machine info"
         uname -a
         echo "::endgroup::"
@@ -39,9 +40,10 @@ jobs:
         echo "::endgroup::"
 
         echo "::group::Run inference"
-        export MODEL_PATH=checkpoints/stories15M/stories15M.pt
+        export MODEL_DIR=checkpoints/stories15M/
+        export MODEL_PATH=${MODEL_DIR}/stories15M.pt
         export MODEL_NAME=stories15M
-        export MODEL_DIR=/tmp
+
         for DTYPE in bfloat16 float16 float32; do
 
           ###################################################################
@@ -83,3 +85,66 @@ jobs:
         echo "tests complete"
         echo "******************************************"
         echo "::endgroup::"
+
+
+  test-sdpa-backends-export:
+    permissions:
+      id-token: write
+      contents: read
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    with:
+      runner: linux.g5.4xlarge.nvidia.gpu
+      gpu-arch-type: cuda
+      gpu-arch-version: "12.4"
+      timeout: 60
+      script: |
+        set -xeou pipefail
+        echo "::group::Print machine info"
+        uname -a
+        echo "::endgroup::"
+
+        echo "::group::Install requirements"
+        # Install requirements
+        ./install/install_requirements.sh cuda
+        pip3 list
+        python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
+        echo "::endgroup::"
+
+        echo "::group::Download checkpoints"
+        mkdir -p checkpoints/stories15M
+        pushd checkpoints/stories15M
+        wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.pt
+        wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+        popd
+        echo "::endgroup::"
+
+        echo "::group::Run inference"
+        export MODEL_DIR=checkpoints/stories15M/
+        export MODEL_PATH=${MODEL_DIR}/stories15M.pt
+        export MODEL_NAME=stories15M
+
+        ./torchchat/utils/scripts/build_native.sh aoti
+
+        for DEVICE in cpu cuda; do
+          # depending on how the parameter passing works, we may only be able to do bfloat16 for aoti_run, similar to runner-cuda-dtype.yml
+          # (although the runner environment should not have an opinion about what we use in the artifact, and we might suitably abstract that)
+          for DTYPE in bfloat16 float16 float32; do
+            for SDPA in 'math' 'flash_attention' 'efficient_attention' 'cudnn_attention'; do
+              echo "***************************************************************"
+              echo "*** $DEVICE $DTYPE $SDPA"
+              ###################################################################
+              # Export DSO and run with Python
+              python torchchat.py export --output-dso dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
+              python torchchat.py generate --dso-path dso.so --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE} --temperature 0 --prompt "Once upon a time"
+              ###################################################################
+              # Export AOTI and run with aoti_run
+              python torchchat.py export --output-aoti /tmp/model.pt2 --checkpoint-path ${MODEL_PATH} --attention-backend ${SDPA} --device ${DEVICE} --dtype ${DTYPE}
+              ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "Once upon a time"
+              ###################################################################
+            done
+          done
+        done
+
+        echo "tests complete"
+        echo "******************************************"
+        echo "::endgroup::"
diff --git a/torchchat/export.py b/torchchat/export.py
index e7cb32309..997639ffe 100644
--- a/torchchat/export.py
+++ b/torchchat/export.py
@@ -490,13 +490,14 @@ def main(args):
         print(
             "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
         )
-        export_for_server(
-            model_to_dso,
-            builder_args.device,
-            output_dso_path,
-            builder_args.dynamic_shapes,
-            package=False,
-        )
+        with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
+            export_for_server(
+                model_to_dso,
+                builder_args.device,
+                output_dso_path,
+                builder_args.dynamic_shapes,
+                package=False,
+            )
 
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
@@ -512,14 +513,15 @@ def main(args):
         print(
             "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
         )
-        export_for_server(
-            model_to_aoti_package,
-            builder_args.device,
-            output_aoti_package_path,
-            builder_args.dynamic_shapes,
-            package=True,
-            metadata=metadata,
-        )
+        with torch.nn.attention.sdpa_kernel([builder_args.attention_backend]):
+            export_for_server(
+                model_to_aoti_package,
+                builder_args.device,
+                output_aoti_package_path,
+                builder_args.dynamic_shapes,
+                package=True,
+                metadata=metadata,
+            )
 
     if output_snapshot_path:
         output_snapshot_path = str(os.path.abspath(output_snapshot_path))
@@ -529,4 +531,3 @@ def main(args):
             builder_args.device,
             output_snapshot_path,
         )
-
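
Editorial note (not part of the patch): the change works because
torch.nn.attention.sdpa_kernel is a context manager that restricts
torch.nn.functional.scaled_dot_product_attention to the backend(s) it is
given, so wrapping export_for_server in it pins the attention backend that
is traced into the exported artifact. A minimal, self-contained sketch of
that mechanism follows; the tensor shape and the MATH backend are arbitrary
example choices, not values taken from the patch:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Arbitrary example shape: (batch, heads, sequence length, head dim).
    q = k = v = torch.randn(1, 8, 16, 64)

    # Only the listed backend(s) may be selected inside this block; the
    # patch adds the same wrapping around export_for_server in export.py.
    with sdpa_kernel([SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)

    print(out.shape)  # torch.Size([1, 8, 16, 64])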
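
A second sketch, under an assumption: since export.py passes
builder_args.attention_backend straight into sdpa_kernel, the
--attention-backend string (math, flash_attention, efficient_attention, or
cudnn_attention, the values the new workflow job loops over) must be
resolved to an SDPBackend value somewhere upstream. resolve_attention_backend
below is a hypothetical helper showing one plausible mapping; torchchat's
actual argument handling may differ:

    from torch.nn.attention import SDPBackend

    def resolve_attention_backend(name: str) -> SDPBackend:
        # Hypothetical helper: "flash_attention" -> SDPBackend.FLASH_ATTENTION, etc.
        backend = getattr(SDPBackend, name.upper(), None)
        if backend is None:
            raise ValueError(f"unknown attention backend: {name!r}")
        return backend

    print(resolve_attention_backend("cudnn_attention"))  # SDPBackend.CUDNN_ATTENTION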