
Commit 6732127

mikekgfb, shoumikhin, metascroy, malfet, lucylq committed
make --device fast the default (pytorch#515)
make --device fast the default (pytorch#515)

* make --device fast the default
* Update iOS.md (pytorch#517)
* Update iOS.md
* Update iOS.md
* Pip to pip3 (pytorch#504)
* remove macos-12 test
* pip to pip3
* break aoti CI jobs separately (pytorch#500)
* init
* fixes
* more fixes
* fixes
* fix
* fix
* bug fix
* add objcopy update
* suppress int8
* undefined variable

Co-authored-by: Michael Gschwind <[email protected]>

* Support llama3 in chat in run.cpp (pytorch#486)
* refactor chat runner in preparation for llama3
* add sketch for llama3 prompt template and move to returning tokens
* fix tiktoken
* fixes to chat
* add default llama_ver
* Add tests for quantize json, add cuda device specification and precision to cuda.json (pytorch#519)
* remove code for no KV Cache path (pytorch#527)
* Update ADVANCED-USERS.md (pytorch#529): update the Advanced Users description to reflect changes in the repo since the description was initially created.
* runner-aoti on cuda (pytorch#531)
* runner-aoti on cuda
* transfer results back to CPU
* transfer results back to CPU
* runner-aoti on cuda
* Update runner_build.md (pytorch#530): update the description of the runner and build process in runner_build.md
* clean up runner code a little (pytorch#532)
* clean up runner code a little
* update
* update
* pull out generate loop in chat
* updates
* edit docs
* typo
* move int8 linear class and function into qops.py (pytorch#534)
* add dtype tests for runner-aoti + runner-et (pytorch#539)
* add dtype tests for runner-aoti + runner-et
* typo
* Quantized embedding (pytorch#536)
* move int8 linear class and function into qops.py
* move Quantized Embedding to qops.py
* Move Linear int4 to qops (pytorch#537)
* move int8 linear class and function into qops.py
* move Quantized Embedding to qops.py
* move int4 linear to qops
* Revert "add dtype tests for runner-aoti + runner-et (pytorch#539)" (pytorch#548). This reverts commit a7a24577a65be67ac9ae4dc05452f35d9c49e5d1.
* fix generate for llama3 (pytorch#538)
* fix generate for llama3
* switch more things to C
* remove C++ header
* add delegation visualization instructions (pytorch#551)
* Add dtype runner aoti (pytorch#552)
* add dtype tests for runner-aoti + runner-et
* typo
* add dtype test runner-aoti
* test sdpa with fp16 (pytorch#553)
* test sdpa with fp16
* kv cache fp32
* typo
* update (pytorch#560)
* Only support newest versions of lm-eval (pytorch#556). Summary: remove support for lm-eval 0.3 to reduce the options we have. Test Plan: CI
* split cpu eval CI by dtype (pytorch#554)
* split cpu eval CI by dtype
* fix
* differentiate names with checks
* keep one name the same as old
* fix
* Removing duplicate HF issue message from README (pytorch#559). Co-authored-by: Michael Gschwind <[email protected]>
* doc updates (pytorch#567)
* Add VM-safe MPS check

Co-authored-by: Anthony Shoumikhin <[email protected]>
Co-authored-by: metascroy <[email protected]>
Co-authored-by: Nikita Shulga <[email protected]>
Co-authored-by: lucylq <[email protected]>
Co-authored-by: Jerry Zhang <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>
1 parent 62d4041 commit 6732127

File tree: 3 files changed, +18 −4 lines


build/utils.py (+16 −2)
@@ -156,13 +156,27 @@ def state_dict_device(d, device="cpu") -> Dict:
 #########################################################################
 ### move state dict to specified device ###
 
+def is_mps_available() -> bool:
+    if not torch.backends.mps.is_available():
+        return False
+
+    # Our system may say MPS is available, but it is not usable on VMs,
+    # so allocate a little memory and see whether that works:
+    try:
+        mps_tensor = torch.zeros(1024, dtype=torch.float16, device="mps")
+    except Exception:
+        return False
+
+    # MPS, is that you?
+    return True
+
 
 def get_device_str(device) -> str:
     if isinstance(device, str) and device == "fast":
         return (
             "cuda"
             if torch.cuda.is_available()
-            else "mps" if torch.backends.mps.is_available() else "cpu"
+            else "mps" if is_mps_available() else "cpu"
         )
     else:
         return str(device)
@@ -173,6 +187,6 @@ def get_device(device) -> str:
         device = (
             "cuda"
             if torch.cuda.is_available()
-            else "mps" if torch.backends.mps.is_available() else "cpu"
+            else "mps" if is_mps_available() else "cpu"
         )
     return torch.device(device)

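For context, here is a minimal, self-contained sketch of the same probe-then-fall-back idea. `probe_mps` and `resolve_fast_device` are illustrative names rather than identifiers from this commit; the only assumption is a working PyTorch install.

```python
# Illustrative sketch (not part of the commit): resolve a "fast" device string
# to a concrete backend, probing MPS with a tiny allocation so that VMs which
# report MPS support but cannot actually allocate on it fall back to CPU.
import torch


def probe_mps() -> bool:
    """Return True only if an MPS tensor can actually be allocated."""
    if not torch.backends.mps.is_available():
        return False
    try:
        torch.zeros(1024, dtype=torch.float16, device="mps")
    except Exception:
        return False
    return True


def resolve_fast_device() -> str:
    """Pick the fastest usable backend: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if probe_mps() else "cpu"


if __name__ == "__main__":
    print(f"'fast' resolves to: {resolve_fast_device()}")
```

The explicit allocation is the point of the change: per the commit, `torch.backends.mps.is_available()` can report True inside VMs where the Metal device cannot actually be used.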
cli.py (+1 −1)
@@ -12,7 +12,7 @@
 from build.utils import allowable_dtype_names, allowable_params_table, get_device_str
 from download import download_and_convert, is_model_downloaded
 
-default_device = "cpu"
+default_device = "fast"
 
 
 # Handle CLI arguments that are common to a majority of subcommands.

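A hedged sketch of how such a default is typically consumed by an argparse-based CLI follows; the parser wiring and its `choices` list are illustrative assumptions, not code taken from cli.py.

```python
# Hypothetical sketch of how the new default could flow through a CLI parser.
import argparse

default_device = "fast"  # mirrors the cli.py change above


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="example device selection")
    parser.add_argument(
        "--device",
        type=str,
        default=default_device,  # "fast" wins whenever --device is omitted
        choices=["fast", "cpu", "cuda", "mps"],
        help="hardware to run on; 'fast' picks cuda, then mps, then cpu",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args([])  # no flags: the "fast" default applies
    print(args.device)  # -> "fast", later resolved by get_device_str()
```

With this default, omitting `--device` defers the concrete choice to `get_device_str()` at runtime instead of pinning execution to the CPU.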
generate.py (+1 −1)
@@ -210,7 +210,7 @@ def decode_n_tokens(
 ):
     new_tokens, new_probs = [], []
     encountered_eos = False
-    for i in range(
+    for _i in range(
         num_new_tokens - 1
     ):  # -1 to save space to run an EoS if we don't generate it naturally
         # Actually better for Inductor to codegen attention here

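For illustration only, a small sketch of the reserve-one-slot-for-EOS pattern that the loop comment describes; the names below (`EOS_TOKEN`, `fake_sample`, `decode`) are hypothetical and not taken from generate.py.

```python
# Sketch: generate at most num_new_tokens tokens while keeping the final
# slot free so an end-of-sequence token can always fit.
EOS_TOKEN = 2  # hypothetical EOS id


def fake_sample(step: int) -> int:
    """Stand-in for the model's next-token sampler."""
    return 10 + step


def decode(num_new_tokens: int) -> list[int]:
    tokens = []
    for _ in range(num_new_tokens - 1):  # unused index, as in the renamed _i
        tok = fake_sample(len(tokens))
        tokens.append(tok)
        if tok == EOS_TOKEN:  # stop early if EOS shows up naturally
            return tokens
    tokens.append(EOS_TOKEN)  # otherwise spend the reserved slot on EOS
    return tokens


if __name__ == "__main__":
    print(decode(5))  # e.g. [10, 11, 12, 13, 2]
```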