Merge pull request #26 from risingsunomi/pr139-dev-oct24
Pr139 dev oct24
risingsunomi authored Oct 14, 2024
2 parents c12526f + de23294 commit d5a02be
Showing 3 changed files with 80 additions and 101 deletions.
144 changes: 65 additions & 79 deletions exo/inference/torch/tests/test_inference_engine.py
@@ -8,7 +8,6 @@
from exo.inference.torch.inference import TorchDynamicShardInferenceEngine
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.inference_engine import InferenceEngine
from exo.inference.torch.utils import print_ram_stats

import numpy as np

@@ -20,37 +19,37 @@ async def test_inference_engine(

prompt = "In a single word only, what is the last name of the current president of the USA?"

# shard = Shard(
# model_id=model_id,
# start_layer=0,
# end_layer=n_layers-1,
# n_layers=n_layers
# )
#
# resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
# "A",
# shard=shard,
# prompt=prompt
# )
#
# print("\n------------resp_full---------------\n")
# print(resp_full)
# print("\n------------resp_full---------------\n")
#
# time.sleep(5)
#
# next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
# "A",
# shard=shard,
# input_data=resp_full,
# inference_state=inference_state_full,
# )
#
# print("\n------------next_resp_full---------------\n")
# print(next_resp_full)
# print("\n------------next_resp_full---------------\n")
#
# time.sleep(5)
shard = Shard(
model_id=model_id,
start_layer=0,
end_layer=n_layers-1,
n_layers=n_layers
)

resp_full, inference_state_full, _ = await inference_engine_1.infer_prompt(
"A",
shard=shard,
prompt=prompt
)

print("\n------------resp_full---------------\n")
print(resp_full)
print("\n------------resp_full---------------\n")

time.sleep(5)

next_resp_full, _next_inference_state_full, _ = await inference_engine_1.infer_tensor(
"A",
shard=shard,
input_data=resp_full,
inference_state=inference_state_full,
)

print("\n------------next_resp_full---------------\n")
print(next_resp_full)
print("\n------------next_resp_full---------------\n")

time.sleep(5)

half_layer = int(n_layers/2)

@@ -68,7 +67,6 @@ async def test_inference_engine(
n_layers=n_layers
)

print_ram_stats()
resp1, inference_state_1, _ = await inference_engine_1.infer_prompt(
"B",
shard=resp_shard,
@@ -79,7 +77,6 @@
print(resp1)
print("\n------------resp1---------------\n")

print_ram_stats()
time.sleep(5)

resp2, inference_state_2, _ = await inference_engine_2.infer_tensor(
@@ -93,63 +90,52 @@
print(resp2)
print("\n------------resp2---------------\n")

#resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
# "B",
# shard=resp_shard,
# input_data=resp2,
# inference_state=inference_state_2,
#)
resp3, inference_state_3, _ = await inference_engine_1.infer_tensor(
"B",
shard=resp_shard,
input_data=resp2,
inference_state=inference_state_2,
)

#print("\n------------resp3---------------\n")
#print(resp3)
#print("\n------------resp3---------------\n")
print("\n------------resp3---------------\n")
print(resp3)
print("\n------------resp3---------------\n")

#resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
# "B",
# shard=resp_shard2,
# input_data=resp3,
# inference_state=inference_state_3,
#)
resp4, _inference_state_4, _ = await inference_engine_2.infer_tensor(
"B",
shard=resp_shard2,
input_data=resp3,
inference_state=inference_state_3,
)

#print("\n------------resp4---------------\n")
#print(resp4)
#print("\n------------resp4---------------\n")
print("\n------------resp4---------------\n")
print(resp4)
print("\n------------resp4---------------\n")

#assert np.array_equal(resp_full, resp2)
#assert np.array_equal(next_resp_full, resp4)
assert np.array_equal(resp_full, resp2)
assert np.array_equal(next_resp_full, resp4)

if __name__ == '__main__':
try:
print("\n\n -------- TEST Qwen/Qwen2.5-3B-Instruct -------- \n\n")
asyncio.run(test_inference_engine(
TorchDynamicShardInferenceEngine(HFShardDownloader()),
TorchDynamicShardInferenceEngine(HFShardDownloader()),
"Qwen/Qwen2.5-3B-Instruct",
36
))
except Exception as err:
print(f"\n\n !!!!!!!!!!! QWEN2 TEST FAILED \n{err}\n")

#try:
# print("\n-------- Test unsloth/Llama-3.2-1B-Instruct ----------\n")
# print("\n\n -------- TEST Qwen/Qwen2.5-3B-Instruct -------- \n\n")
# asyncio.run(test_inference_engine(
# TorchDynamicShardInferenceEngine(HFShardDownloader()),
# TorchDynamicShardInferenceEngine(HFShardDownloader()),
# "unsloth/Llama-3.2-1B-Instruct",
# 24
# "Qwen/Qwen2.5-3B-Instruct",
# 36
# ))
#except Exception as err:
# print(f"\n\n !!!!!!!!!!! meta-llama/Llama-3.2-1B-Instruct TEST FAILED \n{err}\n")
# print(f"\n!!!! QWEN2 TEST FAILED \n{err}\n")

#try:
# print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n")
# asyncio.run(test_inference_engine(
# TorchDynamicShardInferenceEngine(HFShardDownloader()),
# TorchDynamicShardInferenceEngine(HFShardDownloader()),
# "unsloth/Meta-Llama-3.1-8B-Instruct",
# 32
# ))
#except Exception as err:
# print(f"\n\n !!!!!!!!!!! unsloth/Llama-3.1-8B-Instruct TEST FAILED \n{err}\n")
try:
print("\n-------- Test unsloth/Meta-Llama-3.1-8B-Instruct ----------\n")
asyncio.run(test_inference_engine(
TorchDynamicShardInferenceEngine(HFShardDownloader()),
TorchDynamicShardInferenceEngine(HFShardDownloader()),
"unsloth/Meta-Llama-3.1-8B-Instruct",
32
))
except Exception as err:
print(f"\n!!!! unsloth/Meta-Llama-3.1-8B-Instruct TEST FAILED \n{err}\n")


15 changes: 3 additions & 12 deletions exo/inference/torch/tests/test_split_model.py
@@ -15,20 +15,10 @@
from exo.download.hf.hf_helpers import get_weight_map
from exo.download.hf.hf_shard_download import HFShardDownloader
from exo.inference.shard import Shard
from exo.inference.torch.utils import print_cuda_vram_stats

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

def print_ram_stats():
if torch.cuda.is_available():
allocated_memory = torch.cuda.memory_allocated()
max_memory = torch.cuda.max_memory_allocated()
cached_memory = torch.cuda.memory_reserved()

print("Cuda stats")
print(f'Allocated memory: {allocated_memory / 1024**2} MB')
print(f'Max allocated memory: {max_memory / 1024**2} MB')
print(f'Cached memory: {cached_memory / 1024**2} MB')

def load_model(
repo_id: str,
shard: Shard,
@@ -118,7 +108,8 @@ def load_model(
local_files_only=True,
)

print_ram_stats()
if torch.cuda.is_available() and device == "cuda":
print_cuda_vram_stats()

prompt = "In a single word only, what color is a red apple?"

22 changes: 12 additions & 10 deletions exo/inference/torch/utils.py
@@ -50,13 +50,15 @@ def extract_layers(

return layer_weight_map

def print_ram_stats():
if torch.cuda.is_available():
allocated_memory = torch.cuda.memory_allocated()
max_memory = torch.cuda.max_memory_allocated()
cached_memory = torch.cuda.memory_reserved()

print("Cuda stats")
print(f'Allocated memory: {allocated_memory / 1024**2} MB')
print(f'Max allocated memory: {max_memory / 1024**2} MB')
print(f'Cached memory: {cached_memory / 1024**2} MB')
def print_cuda_vram_stats():
"""
Prints CUDA VRAM stats being used by pytorch
"""
allocated_memory = torch.cuda.memory_allocated()
max_memory = torch.cuda.max_memory_allocated()
cached_memory = torch.cuda.memory_reserved()

print("CUDA stats")
print(f'Allocated memory: {allocated_memory / 1024**2} MB')
print(f'Max allocated memory: {max_memory / 1024**2} MB')
print(f'Cached memory: {cached_memory / 1024**2} MB')
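Note: print_cuda_vram_stats reads the torch.cuda.* memory counters directly, so callers are expected to guard it the way the updated test_split_model.py does above. A minimal usage sketch follows — an illustration only, assuming the exo package is importable and a CUDA-enabled PyTorch build:

import torch

from exo.inference.torch.utils import print_cuda_vram_stats

# Guard the call: the helper assumes an active CUDA device, and the
# torch.cuda memory counters are only meaningful when one is present.
if torch.cuda.is_available():
    print_cuda_vram_stats()
else:
    print("CUDA not available; skipping VRAM stats")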
