Merge branch 'pytorch:main' into patch-23
mikekgfb authored Jan 18, 2025
2 parents 1dd12fd + 2fc98f7 commit a42ece0
Showing 7 changed files with 56 additions and 21 deletions.
36 changes: 23 additions & 13 deletions install/install_requirements.sh
@@ -59,12 +59,6 @@ VISION_NIGHTLY_VERSION=dev20241218
 # Nightly version for torchtune
 TUNE_NIGHTLY_VERSION=dev20241218
 
-# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
-(
-  set -x
-  $PIP_EXECUTABLE uninstall -y triton
-)
-
 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
 # with cuda for faster execution on cuda GPUs.
@@ -74,16 +68,28 @@ then
 elif [[ -x "$(command -v rocminfo)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
+elif [[ -x "$(command -v xpu-smi)" ]];
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi
 
 # pip packages needed by exir.
-REQUIREMENTS_TO_INSTALL=(
-  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
-  torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-  torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
-)
+if [[ -x "$(command -v xpu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.5.0"
+  )
+else
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
+  )
+fi
 
 #
 # First install requirements in install/requirements.txt. Older torch may be
@@ -95,6 +101,12 @@ REQUIREMENTS_TO_INSTALL=(
   $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url "${TORCH_NIGHTLY_URL}"
 )
 
+# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
+(
+  set -x
+  $PIP_EXECUTABLE uninstall -y triton
+)
+
 # Install the requirements. --extra-index-url tells pip to look for package
 # versions on the provided URL if they aren't available on the default URL.
 (
@@ -116,8 +128,6 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then
     $PYTHON_EXECUTABLE torchchat/utils/scripts/patch_triton.py
   )
 fi
-
-
 (
   set -x
   $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0"
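Taken together, the index selection above reduces to probing for vendor CLIs in a fixed priority order. A minimal Python sketch of that same order (illustrative only: nightly_index_url and cuda_suffix are made-up names, and the CUDA index URL sits in a collapsed part of this diff, so it is passed in rather than hard-coded):

import shutil

NIGHTLY_BASE = "https://download.pytorch.org/whl/nightly/"

def nightly_index_url(cuda_suffix: str) -> str:
    # Probe order mirrors the script: nvidia-smi (CUDA), rocminfo (ROCm),
    # xpu-smi (Intel XPU), then the CPU wheel index as the fallback.
    if shutil.which("nvidia-smi"):
        return NIGHTLY_BASE + cuda_suffix
    if shutil.which("rocminfo"):
        return NIGHTLY_BASE + "rocm6.2"
    if shutil.which("xpu-smi"):
        return NIGHTLY_BASE + "xpu"
    return NIGHTLY_BASE + "cpu"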
7 changes: 6 additions & 1 deletion torchchat/cli/builder.py
@@ -72,7 +72,12 @@ class BuilderArgs:
 
     def __post_init__(self):
        if self.device is None:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.xpu.is_available():
+                self.device = "xpu"
+            else:
+                self.device = "cpu"
 
        if not (
            (self.checkpoint_path and self.checkpoint_path.is_file())
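The new fallback is straightforward to exercise outside torchchat. A sketch of the same precedence, CUDA first, then XPU, then CPU (default_device is a hypothetical helper, not torchchat code):

import torch

def default_device() -> str:
    # Same order as BuilderArgs.__post_init__ above. torch.xpu ships with
    # recent PyTorch builds; guard with hasattr on older ones.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    return "cpu"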
4 changes: 2 additions & 2 deletions torchchat/cli/cli.py
Expand Up @@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
"--device",
type=str,
default=None,
choices=["fast", "cpu", "cuda", "mps"],
help="Hardware device to use. Options: fast, cpu, cuda, mps",
choices=["fast", "cpu", "cuda", "mps", "xpu"],
help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
)


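With "xpu" accepted by --device, an Intel GPU can be requested explicitly, e.g. something like `python3 torchchat.py generate llama3.1 --device xpu` (entry point and model alias shown as a typical invocation, not taken from this diff).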
9 changes: 8 additions & 1 deletion torchchat/generate.py
@@ -1203,8 +1203,10 @@ def callback(x, *, done_generating=False):
         if hasattr(prof, "export_chrome_trace"):
             if self.builder_args.device == "cpu":
                 print(prof.key_averages().table(sort_by="self_cpu_time_total"))
-            else:
+            elif self.builder_args.device == "cuda":
                 print(prof.key_averages().table(sort_by="self_cuda_time_total"))
+            else:
+                print(prof.key_averages().table(sort_by="self_xpu_time_total"))
             prof.export_chrome_trace(f"{self.profile}.json")
 
         if start_pos >= max_seq_length:
@@ -1289,6 +1291,9 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        if torch.xpu.is_available():
+            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+
 
 
 class DistributedGenerator(LocalGenerator):
@@ -1615,6 +1620,8 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
+    if torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()
 
     for _ in gen.chat(generator_args):
         pass
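The profiler branch above selects a device-specific sort key for the summary table. A table-driven sketch of the same choice (SORT_KEYS and print_profile_table are illustrative names; the xpu key only carries data when XPU activity was actually profiled):

SORT_KEYS = {
    "cpu": "self_cpu_time_total",
    "cuda": "self_cuda_time_total",
    "xpu": "self_xpu_time_total",
}

def print_profile_table(prof, device: str) -> None:
    # Fall back to the CPU key for unrecognized device strings
    # instead of raising a KeyError.
    key = SORT_KEYS.get(device, "self_cpu_time_total")
    print(prof.key_averages().table(sort_by=key))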
8 changes: 6 additions & 2 deletions torchchat/utils/build_utils.py
@@ -231,6 +231,8 @@ def find_multiple(n: int, k: int) -> int:
 def device_sync(device="cpu"):
     if "cuda" in device:
         torch.cuda.synchronize(device)
+    elif "xpu" in device:
+        torch.xpu.synchronize(device)
     elif ("cpu" in device) or ("mps" in device):
         pass
     else:
@@ -279,7 +281,8 @@ def get_device_str(device) -> str:
         device = (
             "cuda"
             if torch.cuda.is_available()
-            else "mps" if is_mps_available() else "cpu"
+            else "mps" if is_mps_available()
+            else "xpu" if torch.xpu.is_available() else "cpu"
         )
         return device
     else:
@@ -291,7 +294,8 @@ def get_device(device) -> str:
         device = (
             "cuda"
             if torch.cuda.is_available()
-            else "mps" if is_mps_available() else "cpu"
+            else "mps" if is_mps_available()
+            else "xpu" if torch.xpu.is_available() else "cpu"
         )
         return torch.device(device)
 
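device_sync matters for timing: CUDA and XPU kernels launch asynchronously, so a wall-clock timer that skips the synchronize measures launch overhead, not execution. A sketch of the usual pattern (timed_matmul is a hypothetical example; device_sync is the function defined in build_utils above):

import time

import torch
from torchchat.utils.build_utils import device_sync

def timed_matmul(device: str) -> float:
    a = torch.randn(1024, 1024, device=device)
    # Drain already-queued work, time the matmul, then drain again so
    # the measured window contains exactly the kernel of interest.
    device_sync(device)
    start = time.perf_counter()
    _ = a @ a
    device_sync(device)
    return time.perf_counter() - start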
11 changes: 10 additions & 1 deletion torchchat/utils/device_info.py
@@ -14,7 +14,7 @@ def get_device_info(device: str) -> str:
     """Returns a human-readable description of the hardware based on a torch.device.type
 
     Args:
-        device: A torch.device.type string: one of {"cpu", "cuda"}.
+        device: A torch.device.type string: one of {"cpu", "cuda", "xpu"}.
 
     Returns:
         str: A human-readable description of the hardware or an empty string if the device type is unhandled.
@@ -37,4 +37,13 @@ def get_device_info(device: str) -> str:
     )
     if device == "cuda":
         return torch.cuda.get_device_name(0)
+    if device == "xpu":
+        return (
+            check_output(
+                ["xpu-smi discovery |grep 'Device Name:'"], shell=True
+            )
+            .decode("utf-8")
+            .split("\n")[0]
+            .split("Device Name:")[1]
+        )
     return ""
2 changes: 1 addition & 1 deletion torchchat/utils/quantize.py
@@ -111,7 +111,7 @@ def quantize_model(
         raise RuntimeError(f"unknown quantizer {quantizer} specified")
     else:
         # Use tensor subclass API for int4 weight only.
-        if device == "cuda" and quantizer == "linear:int4":
+        if (device == "cuda" or device == "xpu") and quantizer == "linear:int4":
             quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
         if not support_tensor_subclass:
             unwrap_tensor_subclass(model)
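quantize_ and int4_weight_only come from torchao's tensor-subclass quantization API; the change above simply lets that int4 weight-only path run on "xpu" as well as "cuda". A minimal standalone sketch (assumes torchao is installed; the int4 kernel expects bfloat16 weights and a group size that divides the linear in_features):

import torch
from torchao.quantization import quantize_, int4_weight_only

# Toy model; group_size=128 divides in_features=256.
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(
    device="cuda", dtype=torch.bfloat16
)
# Swaps eligible Linear weights for int4 weight-only quantized tensor
# subclasses, in place (mirrors q_kwargs["groupsize"] in the diff).
quantize_(model, int4_weight_only(group_size=128))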
