Commit 35c32eb

Merge pull request #29 from risingsunomi/pr139-dev-oct24
Pr139 dev oct24

risingsunomi authored Oct 20, 2024
2 parents 69a8955 + d2302cc commit 35c32eb

Showing 7 changed files with 540 additions and 147 deletions.
140 changes: 50 additions & 90 deletions exo/inference/torch/model/hf.py
@@ -9,6 +9,7 @@
from exo.inference.shard import Shard
from exo.helpers import DEBUG
from exo.inference.torch.utils import extract_layers
from exo.inference.torch.model.hf_safe_tensor_shard import HFSafeTensorShard

from transformers import (
AutoModelForCausalLM,
@@ -52,124 +53,86 @@ def __init__(

# class vars
self.shard = shard
self.hidden_states = None
self.input_ids = None
self.inputs_embeds = None
self.attention_mask = None
self.position_embeddings = None
self.past_key_values = None
self.cache_position = None
self.position_ids = None
self.causal_mask = None
self.local_model_path = local_model_path
self.is_sharded_model = False

self.weight_map = weight_map
self.device = device
self.dtype = dtype
self.device_map = device_map
self.offload_buffers = offload_buffers
self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json"
self.safetensor_sharder = HFSafeTensorShard(
self.local_model_path,
self.shard
)
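# the sharder trims the local safetensors index to this shard's layers and
# restores the original files after loading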
# setup logit processors
self.logits_processor = LogitsProcessorList([
TopKLogitsWarper(top_k),
TemperatureLogitsWarper(temp),
TopPLogitsWarper(top_p)
])

self.device = device
self.dtype = dtype
self.device_map = device_map

self.offload_buffers = offload_buffers

self.model_safetensors_path = self.local_model_path/"model.safetensors.index.json"

# setup pytorch and transformer llm
# setup sharded llm
try:
if weight_map:
print("loading shard model")
self.llm_model = self.load_sharded_model(
shard,
weight_map,
offload_buffers=self.offload_buffers
)

self.is_sharded_model = True

# clear out edited safetensor json
# this is needed because the shard downloader just
# appends to and does not redownload the file
os.remove(self.model_safetensors_path)

self.llm_model = AutoModelForCausalLM.from_config(self.llm_model_config).to(self.device)
self.model = self.llm_model.model.to(self.device)
else:
print("loading full model")
self.llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
torch_dtype=self.dtype,
device_map=self.device_map,
offload_buffers=offload_buffers
).to(self.device)

self.llm_model = self.load_sharded_model()
self.model = self.llm_model.model.to(self.device)

# restore original safetensors for the next run, if any
self.safetensor_sharder.restore_backups()
except Exception as err:
print(f"error loading and splitting model: {err}")
print(f"error loading and sharding model: {err}")
raise

def load_sharded_model(
self,
shard: Shard,
weight_map: dict,
offload_buffers: bool
) -> AutoModelForCausalLM:
# forward variables
self.hidden_states = None
self.input_ids = None
self.inputs_embeds = None
self.attention_mask = None
self.position_embeddings = None
self.past_key_values = None
self.cache_position = None
self.position_ids = None
self.causal_mask = None

def load_sharded_model(self) -> AutoModelForCausalLM:
"""
Loads a sharded version of the model where only the weights
needed for this shard's layers are loaded

Returns:
  llm_model (AutoModelForCausalLM) - sharded llm model with only the needed layers loaded
"""
if DEBUG >= 4:
print("load_sharded_model called")
print(f"shard: {shard}")

# break out layers per shard range
layer_weight_map = extract_layers(
weight_map,
shard
)

# rewrite model.safetensors.index.json for only needed layers
try:
mst_json = {}
with open(self.model_safetensors_path, "r") as mst_file:
mst_json = json.load(mst_file)
mst_json["weight_map"] = layer_weight_map

if DEBUG >= 4:
print(f"rewritten safetensor index \n{json.dumps(mst_json, indent=4)}")

os.remove(self.model_safetensors_path)

with open(self.model_safetensors_path, "w") as mst_file:
json.dump(mst_json, mst_file, indent=4)
except Exception as err:
print(f"err: {err}")
raise
# modify safetensor
self.safetensor_sharder.modify_safetensor()
self.safetensor_sharder.create_safetensor_index()
self.safetensor_sharder.shard_safetensor_index(self.weight_map)
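# the index rewrite above limits from_pretrained below to this shard's layer weights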

# load model
try:
shard_num_hidden_layers = shard.end_layer - shard.start_layer
shard_num_hidden_layers = (self.shard.end_layer - self.shard.start_layer) + 1
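# (+1 because the shard range appears to treat end_layer as inclusive,
# e.g. start_layer=0, end_layer=15 -> 16 hidden layers)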
if DEBUG >= 4:
print(f"config with {shard_num_hidden_layers} layers")

llm_model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path=self.local_model_path,
device_map=self.device_map,
torch_dtype=self.dtype,
offload_buffers=offload_buffers,
offload_buffers=self.offload_buffers,
local_files_only=True,
num_hidden_layers=shard_num_hidden_layers
num_hidden_layers=shard_num_hidden_layers,
use_safetensors=True,
low_cpu_mem_usage=True
)

return llm_model.to(self.device)
# restore backup for next run
self.safetensor_sharder.restore_backups()

if self.device_map == "auto":
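# with device_map="auto", accelerate has already placed the weights across
# devices, so skip the explicit .to(self.device)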
return llm_model
else:
return llm_model.to(self.device)

except Exception as err:
print(f"err: {err}")
@@ -265,14 +228,11 @@ def forward(
self.cache_position = model_inputs["cache_position"]
self.past_key_values = model_inputs["past_key_values"]

if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")
if DEBUG >= 4:
print(f"model_inputs: {model_inputs}")

# run through decoder layers
if self.is_sharded_model:
layer_amt = range(self.shard.end_layer - self.shard.start_layer)
else:
layer_amt = range(self.shard.start_layer, self.shard.end_layer)
layer_amt = range(self.shard.end_layer - self.shard.start_layer)
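# only this shard's layers are loaded now, so decoder layers are indexed locally from 0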

if DEBUG >= 4:
print(f"hidden_states: {self.hidden_states}")
@@ -317,7 +277,7 @@ def forward(
# shard.is_last_layer() says true at the start and is not detecting the last layer correctly
if self.shard.is_last_layer():
self.hidden_states = self.model.norm(self.hidden_states)
if use_legacy_cache and self.next_decoder_cache is not None:
if use_legacy_cache:
self.past_key_values = self.next_decoder_cache.to_legacy_cache()
else:
self.past_key_values = self.next_decoder_cache
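A minimal sketch of how the reworked constructor path in hf.py might be exercised after this change. The wrapper class name, the Shard field names, and the snapshot path below are assumptions not shown in this diff; only the parameter names and the safetensors index handling are taken from the changes above.

# Illustrative only -- names not visible in the diff are assumptions.
import json
from pathlib import Path

import torch

from exo.inference.shard import Shard
from exo.inference.torch.model.hf import ShardedHuggingFaceModel  # assumed class name

# assumed local snapshot location and shard field names
model_path = Path.home()/".cache"/"exo"/"llama-3.1-8b"
shard = Shard(model_id="llama-3.1-8b", start_layer=0, end_layer=15, n_layers=32)

# weight_map comes from the safetensors index that HFSafeTensorShard later rewrites
with open(model_path/"model.safetensors.index.json") as f:
  weight_map = json.load(f)["weight_map"]

sharded_model = ShardedHuggingFaceModel(
  shard=shard,
  local_model_path=model_path,
  weight_map=weight_map,
  device=torch.device("cuda"),
  dtype=torch.float16,
  device_map="auto",
  offload_buffers=True
)
# __init__ now always goes through load_sharded_model(): HFSafeTensorShard trims
# model.safetensors.index.json to this shard's layers before from_pretrained,
# then restore_backups() puts the original files back for the next run.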