@@ -271,21 +271,25 @@ def __init__(
 
     def load_model(self) -> None:
         with HabanaMemoryProfiler() as m:
-            self.model = get_model(
-                model_config=self.model_config,
-                device_config=self.device_config,
-                load_config=self.load_config,
-                lora_config=self.lora_config,
-                vision_language_config=self.vision_language_config,
-                parallel_config=self.parallel_config,
-                scheduler_config=self.scheduler_config,
-            )
-            # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged
-            self.model = _maybe_wrap_in_hpu_graph(self.model)
+            with HabanaMemoryProfiler() as m_getmodel:
+                self.model = get_model(
+                    model_config=self.model_config,
+                    device_config=self.device_config,
+                    load_config=self.load_config,
+                    lora_config=self.lora_config,
+                    vision_language_config=self.vision_language_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config,
+                )
+                logger.info(f"Pre-loading model weights on {next(self.model.parameters()).device} took {m_getmodel.get_summary_string()}")
 
-        self.model_memory_usage = m.consumed_memory
-        logger.info(f"Loading model weights took "
-                    f"{format_bytes(self.model_memory_usage)} ({format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)")
+            # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged
+            with HabanaMemoryProfiler() as m_wrap:
+                self.model = _maybe_wrap_in_hpu_graph(self.model)
+            logger.info(f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}")
+
+        self.model_memory_usage = m.consumed_device_memory
+        logger.info(f"Loading model weights took in total {m.get_summary_string()}")
 
         if self.lora_config:
             assert hasattr(self.model, "supported_lora_modules"
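Note (not part of the diff): the new logging relies on the profiler exposing consumed_device_memory and get_summary_string(). A minimal sketch of the shape this assumes; the htorch.hpu.memory_allocated() call and the format_bytes helper are assumptions here, not the fork's actual implementation.

```python
# Minimal sketch only, assuming a CUDA-like memory_allocated() exists for HPU;
# the real profiler in the fork may differ.
import habana_frameworks.torch as htorch


def format_bytes(n: float) -> str:
    # Tiny human-readable byte formatter used only by this sketch.
    for unit in ("B", "KiB", "MiB", "GiB", "TiB"):
        if abs(n) < 1024.0 or unit == "TiB":
            return f"{n:.2f} {unit}"
        n /= 1024.0


class HabanaMemoryProfiler:
    """Context manager measuring device memory consumed inside the block."""

    @staticmethod
    def current_device_memory_usage() -> int:
        return htorch.hpu.memory_allocated()  # assumed API

    def __enter__(self):
        self.start_device_memory = self.current_device_memory_usage()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end_device_memory = self.current_device_memory_usage()
        self.consumed_device_memory = (self.end_device_memory -
                                       self.start_device_memory)

    def get_summary_string(self) -> str:
        return (f"{format_bytes(self.consumed_device_memory)} of device memory "
                f"({format_bytes(self.end_device_memory)} currently allocated)")
```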
@@ -932,12 +936,12 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, kv_caches) -> None:
         gc.collect()
 
     def log_warmup(self, phase, i, max_i, batch_size, seq_len):
-        free_mem = format_bytes(HabanaMemoryProfiler.current_free_memory())
+        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
         logger.info(f"[Warmup][{phase}][{i+1}/{max_i}] batch_size:{batch_size} seq_len:{seq_len} free_mem:{free_mem}")
 
     def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
-            mem_usage = 100.0 * HabanaMemoryProfiler.current_memory_usage() / HabanaMemoryProfiler.total_memory()
+            mem_usage = 100.0 * HabanaMemoryProfiler.current_device_memory_usage() / HabanaMemoryProfiler.total_device_memory()
             self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len)
             self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
@@ -966,7 +970,7 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem):
             self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
             with HabanaMemoryProfiler() as mem_prof:
                 self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
-            used_mem = align_workers(mem_prof.consumed_memory, torch.distributed.ReduceOp.MAX)
+            used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX)
             available_mem -= used_mem
             total_mem += used_mem
             total_batch_seq += batch_seq
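Note on align_workers (unchanged by this diff): it is used here so that all ranks agree on a single memory figure before adjusting their budgets. A rough sketch of that behavior, assuming a plain all-reduce over the default process group rather than the fork's actual implementation:

```python
# Sketch only: agree on one value across ranks by all-reducing it
# (MAX for consumed memory, MIN for free memory).
import torch
import torch.distributed as dist


def align_workers(value, reduce_op):
    # Nothing to align when running single-process.
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return value
    value_t = torch.tensor(float(value), dtype=torch.float64)
    dist.all_reduce(value_t, op=reduce_op)
    return value_t.item()
```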
@@ -980,14 +984,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             logger.info("Skipping warmup...")
             return
         self.profiler.start('internal', 'warmup')
-        start_mem = HabanaMemoryProfiler.current_memory_usage()
+        start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
         self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
         self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
 
         if not self.enforce_eager:
             mem_margin = 1.0 - float(os.environ.get('VLLM_GRAPH_MEM_MARGIN', '0.02'))
-            free_mem = mem_margin * HabanaMemoryProfiler.current_free_memory()
+            free_mem = mem_margin * HabanaMemoryProfiler.current_free_device_memory()
             free_mem = align_workers(free_mem, torch.distributed.ReduceOp.MIN)
             prompt_graph_mem_ratio = float(os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
             prompt_available_memory = prompt_graph_mem_ratio * free_mem
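To make the HPU graph memory budget concrete, a worked example using the defaults shown above (VLLM_GRAPH_MEM_MARGIN=0.02, VLLM_GRAPH_PROMPT_RATIO=0.5) and an invented 40 GiB of free device memory; the decode-side subtraction is an assumption, since that line is not part of these hunks:

```python
# Invented figure: 40 GiB of device memory free after bucket warmup.
free_device_memory = 40 * 1024**3

mem_margin = 1.0 - 0.02        # VLLM_GRAPH_MEM_MARGIN: keep 2% headroom
free_mem = mem_margin * free_device_memory                     # ~39.2 GiB usable
prompt_graph_mem_ratio = 0.5   # VLLM_GRAPH_PROMPT_RATIO
prompt_available_memory = prompt_graph_mem_ratio * free_mem    # ~19.6 GiB
decode_available_memory = free_mem - prompt_available_memory   # ~19.6 GiB (assumed split)
```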
@@ -998,7 +1002,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         self.warmup_graphs(decode_strategy, self.decode_buckets, False, kv_caches, decode_available_memory)
 
         end_time = time.perf_counter()
-        end_mem = HabanaMemoryProfiler.current_memory_usage()
+        end_mem = HabanaMemoryProfiler.current_device_memory_usage()
         elapsed_time = end_time - start_time
         logger.info(f"Warmup finished in {elapsed_time:.0f} secs, allocated {format_bytes(end_mem - start_mem)} of device memory")
         self.profiler.end()