
Commit cf6952d

Add host memory profiling to HabanaMemoryProfiler (vllm-project#51)
1 parent ab359ac commit cf6952d

File tree: 3 files changed, +56 −31 lines


vllm/executor/habana_executor.py

Lines changed: 1 addition & 2 deletions

@@ -80,8 +80,7 @@ def initialize_cache(self, num_gpu_blocks : int, num_cpu_blocks) -> None:
 
         with HabanaMemoryProfiler() as cache_init_m:
             self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
-        logger.info(f"init_cache_engine took "
-                    f"{format_bytes(cache_init_m.consumed_memory)} ({cache_init_m.consumed_memory/HabanaMemoryProfiler.total_memory():.2%} of total memory, gpu_memory_utilization: {self.cache_config.gpu_memory_utilization}, {format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)")
+        logger.info(f"init_cache_engine took {cache_init_m.get_summary_string()}")
 
     def execute_model(
         self,
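
The call site above now delegates all memory-report formatting to the profiler itself. One caveat, enforced by the guard added to get_summary_string() in vllm/utils.py below: the summary may only be requested after the with-block has exited, because the final readings are taken in __exit__. A minimal sketch of the pattern (the profiled call is a placeholder, not vLLM code):

with HabanaMemoryProfiler() as cache_init_m:
    run_cache_initialization()  # placeholder for the profiled work
    # calling cache_init_m.get_summary_string() here would raise RuntimeError
logger.info(f"init_cache_engine took {cache_init_m.get_summary_string()}")  # valid after __exit__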

vllm/utils.py

Lines changed: 31 additions & 9 deletions

@@ -496,33 +496,55 @@ class HabanaMemoryProfiler:
     def __init__(self, device=None):
         self.device = device
 
-    def current_memory_usage() -> float:
-        # Return the memory usage in bytes.
+    def current_device_memory_usage() -> float:
+        # Return the device memory usage in bytes.
         free_hpu_memory, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory - free_hpu_memory
 
-    def current_free_memory() -> float:
-        # Return the memory usage in bytes.
+    def current_free_device_memory() -> float:
+        # Return the device memory usage in bytes.
         free_hpu_memory, _ = torch.hpu.mem_get_info()
         return free_hpu_memory
 
-    def total_memory() -> float:
-        # Return the memory usage in bytes.
+    def total_device_memory() -> float:
+        # Return the device memory usage in bytes.
         _, total_hpu_memory = torch.hpu.mem_get_info()
         return total_hpu_memory
 
+    def current_host_memory_usage() -> float:
+        # Return the host memory usage in bytes.
+        return HabanaMemoryProfiler.total_host_memory() - HabanaMemoryProfiler.current_free_host_memory()
+
+    def current_free_host_memory() -> float:
+        # Return the host memory usage in bytes.
+        return psutil.virtual_memory().available
+
+    def total_host_memory() -> float:
+        # Return the host memory usage in bytes.
+        return psutil.virtual_memory().total
+
+    def get_summary_string(self):
+        if getattr(self, 'final_device_memory', None) is None or getattr(self, 'final_host_memory', None) is None:
+            raise RuntimeError("HabanaMemoryProfiler.get_summary_string() can only be called after closing context manager")
+        return (f"{format_bytes(self.consumed_device_memory)} of device memory ({format_bytes(self.final_device_memory)}/{format_bytes(HabanaMemoryProfiler.total_device_memory())} used) and "
+                f"{format_bytes(self.consumed_host_memory)} of host memory ({format_bytes(self.final_host_memory)}/{format_bytes(HabanaMemoryProfiler.total_host_memory())} used)")
+
     def __enter__(self):
         # Force garbage collection
         gc.collect()
-        self.initial_memory = HabanaMemoryProfiler.current_memory_usage()
+        self.initial_device_memory = HabanaMemoryProfiler.current_device_memory_usage()
+        self.initial_host_memory = HabanaMemoryProfiler.current_host_memory_usage()
        # This allows us to call methods of the context manager if needed
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         # Force garbage collection
         gc.collect()
-        self.final_memory = HabanaMemoryProfiler.current_memory_usage()
-        self.consumed_memory = self.final_memory - self.initial_memory
+        self.final_device_memory = HabanaMemoryProfiler.current_device_memory_usage()
+        self.final_host_memory = HabanaMemoryProfiler.current_host_memory_usage()
+        self.consumed_device_memory = self.final_device_memory - self.initial_device_memory
+        self.consumed_host_memory = self.final_host_memory - self.initial_host_memory
+
 
 
 # Adapted from https://stackoverflow.com/a/49361727
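
The device-side helpers require torch.hpu, but the new host-side helpers rely only on psutil, so the host half of the profiler can be exercised on any machine. A self-contained sketch of the same context-manager pattern restricted to host memory (the class name and demo are illustrative, not part of this patch):

import gc
import psutil


class HostMemoryProfiler:
    """Illustrative host-only analogue of HabanaMemoryProfiler."""

    @staticmethod
    def current_host_memory_usage() -> float:
        # Used host memory in bytes: total minus what is still available.
        vm = psutil.virtual_memory()
        return vm.total - vm.available

    def __enter__(self):
        # Force garbage collection so stale objects don't skew the baseline
        gc.collect()
        self.initial_host_memory = HostMemoryProfiler.current_host_memory_usage()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        gc.collect()
        self.final_host_memory = HostMemoryProfiler.current_host_memory_usage()
        self.consumed_host_memory = self.final_host_memory - self.initial_host_memory


if __name__ == "__main__":
    with HostMemoryProfiler() as m:
        buf = bytearray(64 * 1024 * 1024)  # hold ~64 MiB on the host
    print(f"block consumed roughly {m.consumed_host_memory / 2**20:.1f} MiB of host memory")

Because psutil reports system-wide figures, other processes can perturb these numbers; the host-side readings of the real profiler carry the same caveat.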

vllm/worker/habana_model_runner.py

Lines changed: 24 additions & 20 deletions

@@ -271,21 +271,25 @@ def __init__(
 
     def load_model(self) -> None:
         with HabanaMemoryProfiler() as m:
-            self.model = get_model(
-                model_config=self.model_config,
-                device_config=self.device_config,
-                load_config=self.load_config,
-                lora_config=self.lora_config,
-                vision_language_config=self.vision_language_config,
-                parallel_config=self.parallel_config,
-                scheduler_config=self.scheduler_config,
-            )
-            # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged
-            self.model = _maybe_wrap_in_hpu_graph(self.model)
+            with HabanaMemoryProfiler() as m_getmodel:
+                self.model = get_model(
+                    model_config=self.model_config,
+                    device_config=self.device_config,
+                    load_config=self.load_config,
+                    lora_config=self.lora_config,
+                    vision_language_config=self.vision_language_config,
+                    parallel_config=self.parallel_config,
+                    scheduler_config=self.scheduler_config,
+                )
+            logger.info(f"Pre-loading model weights on {next(self.model.parameters()).device} took {m_getmodel.get_summary_string()}")
 
-        self.model_memory_usage = m.consumed_memory
-        logger.info(f"Loading model weights took "
-                    f"{format_bytes(self.model_memory_usage)} ({format_bytes(HabanaMemoryProfiler.current_memory_usage())}/{format_bytes(HabanaMemoryProfiler.total_memory())} used)")
+            # FIXME: Running with disable_tensor_cache=True causes RuntimeErrors. This needs to be debugged
+            with HabanaMemoryProfiler() as m_wrap:
+                self.model = _maybe_wrap_in_hpu_graph(self.model)
+            logger.info(f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}")
+
+        self.model_memory_usage = m.consumed_device_memory
+        logger.info(f"Loading model weights took in total {m.get_summary_string()}")
 
         if self.lora_config:
             assert hasattr(self.model, "supported_lora_modules"
@@ -932,12 +936,12 @@ def warmup_scenario(self, batch_size, seq_len, is_prompt, kv_caches) -> None:
         gc.collect()
 
     def log_warmup(self, phase, i, max_i, batch_size, seq_len):
-        free_mem = format_bytes(HabanaMemoryProfiler.current_free_memory())
+        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
         logger.info(f"[Warmup][{phase}][{i+1}/{max_i}] batch_size:{batch_size} seq_len:{seq_len} free_mem:{free_mem}")
 
     def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
         for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
-            mem_usage = 100.0 * HabanaMemoryProfiler.current_memory_usage() / HabanaMemoryProfiler.total_memory()
+            mem_usage = 100.0 * HabanaMemoryProfiler.current_device_memory_usage() / HabanaMemoryProfiler.total_device_memory()
             self.log_warmup('Prompt' if is_prompt else 'Decode', i, len(buckets), batch_size, seq_len)
             self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
 
@@ -966,7 +970,7 @@ def warmup_graphs(self, strategy, buckets, is_prompt, kv_caches, available_mem):
             self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
             with HabanaMemoryProfiler() as mem_prof:
                 self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
-            used_mem = align_workers(mem_prof.consumed_memory, torch.distributed.ReduceOp.MAX)
+            used_mem = align_workers(mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX)
             available_mem -= used_mem
             total_mem += used_mem
             total_batch_seq += batch_seq
@@ -980,14 +984,14 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             logger.info("Skipping warmup...")
             return
         self.profiler.start('internal', 'warmup')
-        start_mem = HabanaMemoryProfiler.current_memory_usage()
+        start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
         self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
         self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
 
         if not self.enforce_eager:
             mem_margin = 1.0 - float(os.environ.get('VLLM_GRAPH_MEM_MARGIN', '0.02'))
-            free_mem = mem_margin * HabanaMemoryProfiler.current_free_memory()
+            free_mem = mem_margin * HabanaMemoryProfiler.current_free_device_memory()
             free_mem = align_workers(free_mem, torch.distributed.ReduceOp.MIN)
             prompt_graph_mem_ratio = float(os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
             prompt_available_memory = prompt_graph_mem_ratio * free_mem
@@ -998,7 +1002,7 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
             self.warmup_graphs(decode_strategy, self.decode_buckets, False, kv_caches, decode_available_memory)
 
         end_time = time.perf_counter()
-        end_mem = HabanaMemoryProfiler.current_memory_usage()
+        end_mem = HabanaMemoryProfiler.current_device_memory_usage()
         elapsed_time = end_time - start_time
         logger.info(f"Warmup finished in {elapsed_time:.0f} secs, allocated {format_bytes(end_mem - start_mem)} of device memory")
         self.profiler.end()
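
load_model now nests profilers so that each stage (weight loading, HPU Graph wrapping) gets its own summary line while the outer profiler still reports the total. A sketch of that nesting pattern, assuming an HPU-enabled build where HabanaMemoryProfiler's device calls work; load_stage_a and load_stage_b are placeholders for get_model() and _maybe_wrap_in_hpu_graph():

with HabanaMemoryProfiler() as m_total:
    with HabanaMemoryProfiler() as m_a:
        model = load_stage_a()        # placeholder stage 1
    logger.info(f"stage A took {m_a.get_summary_string()}")

    with HabanaMemoryProfiler() as m_b:
        model = load_stage_b(model)   # placeholder stage 2
    logger.info(f"stage B took {m_b.get_summary_string()}")
logger.info(f"both stages took in total {m_total.get_summary_string()}")

Note that the value vLLM records as self.model_memory_usage still comes from the outer profiler's consumed_device_memory, so the nested per-stage profilers only add logging and do not change that figure.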
