Commit 4c003a6

move loading of model for inference into _load_inference_model (pytorch#159)

* move loading of model for inference into _load_inference_model

* type

* load_inference_model

* load_inference_model

* typo
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 02d657f commit 4c003a6
Showing 1 changed file with 70 additions and 45 deletions.
115 changes: 70 additions & 45 deletions generate.py
@@ -22,7 +22,7 @@ def device_sync(device):
elif ("cpu" in device) or ("mps" in device):
pass
else:
print(f"device={device} is not yet suppported")
print(f"device={ device } is not yet suppported")


torch._inductor.config.coordinate_descent_tuning = True
@@ -338,6 +338,69 @@ def _load_model(

B_INST, E_INST = "[INST]", "[/INST]"

def _load_inference_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
dso_path,
pte_path,
quantize,
device,
precision,
use_tp=False
):
print("Loading model ...")
t0 = time.time()
model_ = _load_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
device,
precision,
use_tp
)
device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

if dso_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
try:
model = model_
# Replace model forward with the AOT-compiled forward
# This is a hacky way to quickly demo AOTI's capability.
# model is still a Python object, and any mutation to its
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
except:
raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
elif pte_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
try:
from model_et import PTEModel
model = PTEModel(model_.config, pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {pte_path}")
else:
model = model_

if quantize:
t0q = time.time()
quantize_model(model, quantize)
device_sync(device=device) # MKG
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

model.to(dtype=precision)

return model


def _main(
prompt: str = "Hello, my name is",
@@ -394,56 +457,21 @@ def _main(
is_speculative = draft_checkpoint_path is not None
is_chat = "chat" in str(checkpoint_path)

print("Loading model ...")
t0 = time.time()
model_ = _load_model(
model = _load_inference_model(
checkpoint_path,
checkpoint_dir,
params_path,
params_table,
dso_path,
pte_path,
quantize,
device,
precision,
use_tp
)
if dso_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."
try:
model = model_
# Replace model forward with the AOT-compiled forward
# This is a hacky way to quickly demo AOTI's capability.
# model is still a Python object, and any mutation to its
# attributes will NOT be seen on by AOTI-compiled forward
# function, e.g. calling model.setup_cache will NOT touch
# AOTI compiled and maintained model buffers such as kv_cache.
model.forward = torch._export.aot_load(str(dso_path.absolute()), device)
except:
raise RuntimeError(f"Failed to load AOTI compiled {dso_path}")
elif pte_path:
# make sure user did not try to set dtype
# assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export."
assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export."
try:
from model_et import PTEModel
model = PTEModel(model_.config, pte_path)
except Exception as e:
raise RuntimeError(f"Failed to load ET compiled {pte_path}")
else:
model = model_

# Add new CLI arg
if quantize:
device_sync(device=device)
t0q = time.time()
quantize_model(model, quantize)
device_sync(device=device) # MKG
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

# dtype:
if model_dtype:
model.to(dtype=name_to_dtype(model_dtype))

# will add a version of _load_inference_model in future
# (need additional args)
if is_speculative:
draft_model = _load_model(
draft_checkpoint_path,
@@ -456,9 +484,6 @@ def _main(
else:
draft_model = None

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
print (encoded)
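
Illustrative usage (not part of the commit): a minimal sketch of how the new helper is meant to be called for a plain eager checkpoint, i.e. with no exported DSO or PTE artifact. The paths and argument values below are placeholders for illustration, not values taken from this repository.

    import torch
    from pathlib import Path

    model = _load_inference_model(
        checkpoint_path=Path("checkpoints/model.pth"),  # hypothetical checkpoint location
        checkpoint_dir=None,
        params_path=None,
        params_table=None,
        dso_path=None,      # no AOTI shared library: fall through to the eager branch
        pte_path=None,      # no ExecuTorch program
        quantize=None,      # skip post-load quantization
        device="cuda",
        precision=torch.bfloat16,
        use_tp=False,
    )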
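
A note on the dso_path branch: it keeps the eager Python model object but replaces its forward with the callable returned by torch._export.aot_load. The sketch below shows only that contract; it assumes a shared library was already produced by AOTInductor export, and torch._export.aot_load is a private PyTorch API that may change between releases.

    import torch

    # Hypothetical path to a previously exported AOTInductor shared library.
    compiled_forward = torch._export.aot_load("/tmp/model.so", "cuda")

    # The returned callable takes the same inputs as the eager forward, e.g.:
    # logits = compiled_forward(tokens, input_pos)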
