import json
import math
import os
+import types
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, EnumMeta
from pathlib import Path
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional,
-                    TypeAlias, Union)
+                    Type, TypeAlias, TypeVar, Union, get_args, get_origin)

import torch
import yaml

# TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import

+TypeBaseModel = TypeVar("T", bound=BaseModel)
+


def Field(default: Any = ...,
          *,
@@ -597,6 +600,62 @@ def pybind_equals(obj0, obj1):
                return False
        return True

+    @classmethod
+    def from_pybind(cls: Type[TypeBaseModel],
+                    pybind_instance: "PybindMirror") -> TypeBaseModel:
+        """Construct an instance of the given class from the fields of the
+        given pybind class instance.
+
+        Args:
+            cls: Type of the class to construct, must be a subclass of pydantic
+                BaseModel
+            pybind_instance: Instance of the pybind class whose fields are used
+                to construct the new instance
+
+        Notes:
+            When a field value is None in the pybind class but the field is
+            non-optional with a default value in the BaseModel class, the
+            default value defined in the BaseModel class is used.
+
+        Returns:
+            Instance of the given class, populated with the fields of the given
+            pybind instance
+        """  # noqa: D205
+        assert issubclass(cls, BaseModel)
+
+        # Some fields are optional in the C++ class but non-optional with a
+        # default value in the Python class. Copy the value from the C++
+        # instance only when it has one; otherwise keep the default value
+        # defined in the Python class.
+        def _is_optional_type(annotation: Any) -> bool:
+            """Return True if a type annotation represents an Optional type
+            (Optional[X]) or a Union type that includes None (Union[X, Y, None]
+            or X | Y | None).
+            """  # noqa: D205
+            origin = get_origin(annotation)
+            args = get_args(annotation)
+
+            # Union covers Optional[X];
+            # UnionType covers the new | syntax in Python 3.10+.
+            return (origin is Union
+                    or origin is types.UnionType) and type(None) in args
+
+        fields_non_optional_with_default_value_in_basemodel = {
+            field_name
+            for field_name, field_info in cls.model_fields.items()
+            if not (_is_optional_type(field_info.annotation)
+                    and field_info.is_required())
+        }
+
+        kwargs = {}
+        cpp_fields = PybindMirror.get_pybind_variable_fields(
+            type(pybind_instance))
+        for field_name in cpp_fields:
+            field_value = getattr(pybind_instance, field_name)
+            if field_value is not None or field_name not in fields_non_optional_with_default_value_in_basemodel:
+                kwargs[field_name] = field_value
+        return cls(**kwargs)
+

class PybindMirrorMeta(type(PybindMirror)):
    pass
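A hedged usage sketch, not part of the diff: since `PeftCacheConfig` mixes `PybindMirror` into a pydantic model, the new classmethod should allow round-tripping a C++ config object; the field value below is illustrative only.

# Illustrative round-trip, assuming PeftCacheConfig as the concrete mirror class.
cpp_cfg = PeftCacheConfig(num_device_module_layer=8)._to_pybind()
py_cfg = PeftCacheConfig.from_pybind(cpp_cfg)
# C++ fields that come back as None while the BaseModel field is non-optional
# with a default are expected to keep the BaseModel default instead of None.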
@@ -694,11 +753,12 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
        default=0,
        description=
        "number of max sized 1-layer 1-module adapterSize=1 sets of weights that can be stored in host cache"
-    )
+        ", affects host cache size and overrides value of host_cache_size")
    num_device_module_layer: int = Field(
        default=0,
        description=
-        "number of max sized 1-layer 1-module sets of weights that can be stored in host cache"
+        "number of max sized 1-layer 1-module sets of weights that can be stored in device cache"
+        ", affects device cache size and overrides value of device_cache_percent"
    )
    optimal_adapter_size: int = Field(
        default=
@@ -725,15 +785,17 @@ class PeftCacheConfig(StrictBaseModel, PybindMirror):
    max_pages_per_block_device: int = Field(
        default=8,
        description="Number of cache pages per allocation block (device)")
-    device_cache_percent: Optional[float] = Field(
-        default=None,
-        description="percent of memory after engine load to use for cache")
-    host_cache_size: Optional[int] = Field(
-        default=None, description="size in bytes to use for host cache")
+    device_cache_percent: float = Field(
+        default=0.02,
+        description=
+        "Proportion of free device memory after engine load to use for cache, as a fraction from 0 to 1"
+    )
+    host_cache_size: int = Field(
+        default=1024**3, description="size in bytes to use for host cache")
    lora_prefetch_dir: Optional[str] = Field(
        default=None,
        description=
-        "folder to store the LoRA weights we hope to load during engine initialization"
+        "folder to store the LoRA weights we hope to load during engine initialization, currently not supported"
    )

    def _to_pybind(self):
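For orientation only, not part of the change: with these fields now non-optional, a default-constructed config carries concrete cache sizes and callers can still override them; the override values below are illustrative.

# Defaults now match the new Field defaults above.
cfg = PeftCacheConfig()  # host_cache_size == 1024**3 (1 GiB), device_cache_percent == 0.02
cfg = PeftCacheConfig(host_cache_size=2 * 1024**3,   # 2 GiB of host memory
                      device_cache_percent=0.05)     # 5% of free device memory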
@@ -1083,27 +1145,6 @@ class BaseLlmArgs(StrictBaseModel):
    # LoRA arguments
    enable_lora: bool = Field(default=False, description="Enable LoRA.")

-    max_lora_rank: Optional[int] = Field(
-        default=None,
-        description="The maximum LoRA rank.",
-        deprecated="Use lora_config.max_lora_rank instead.",
-        status="deprecated",
-    )
-
-    max_loras: int = Field(
-        default=4,
-        description="The maximum number of LoRA.",
-        deprecated="Use lora_config.max_loras instead.",
-        status="deprecated",
-    )
-
-    max_cpu_loras: int = Field(
-        default=4,
-        description="The maximum number of LoRA on CPU.",
-        deprecated="Use lora_config.max_cpu_loras instead.",
-        status="deprecated",
-    )
-
    lora_config: Optional[LoraConfig] = Field(
        default=None, description="LoRA configuration for the model.")

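A possible migration path, sketched here rather than taken from the diff: values previously passed through the removed top-level arguments now go through `lora_config`, matching the old deprecation messages.

# Before (fields removed by this diff):
#     max_lora_rank=64, max_loras=4, max_cpu_loras=4
# After, assuming LoraConfig exposes the same field names referenced above:
lora_config = LoraConfig(max_lora_rank=64, max_loras=4, max_cpu_loras=4)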
@@ -1494,10 +1535,10 @@ def validate_build_config_remaining(self):
        if self.parallel_config._world_size == 1 and self.build_config:
            self.build_config.plugin_config.nccl_plugin = None

-        if self.enable_lora and self.lora_config is None and self.backend != 'pytorch':
+        if self.enable_lora and self.backend != 'pytorch':
            self.build_config.plugin_config.lora_plugin = 'auto'
-            if self.max_lora_rank is not None:
-                self.build_config.lora_config.max_lora_rank = self.max_lora_rank
+            if self.lora_config is not None:
+                self.build_config.lora_config.max_lora_rank = self.lora_config.max_lora_rank

        if hasattr(self,
                   'enable_prompt_adapter') and self.enable_prompt_adapter:
@@ -1601,16 +1642,6 @@ def validate_speculative_config(self):
    @model_validator(mode="after")
    def validate_lora_config_consistency(self):
        if self.lora_config:
-            if self.max_lora_rank is not None:
-                logger.warning(
-                    "max_lora_rank is ignored when lora_config is provided.")
-            if self.max_loras != self.lora_config.max_loras:
-                logger.warning(
-                    "max_loras is ignored when lora_config is provided.")
-            if self.max_cpu_loras != self.lora_config.max_cpu_loras:
-                logger.warning(
-                    "max_cpu_loras is ignored when lora_config is provided.")
-
            if len(self.lora_config.lora_dir) == 0:
                # TODO [TRTLLM-5173]
                logger.warning(
@@ -1637,6 +1668,14 @@ def validate_lora_config_consistency(self):
                default_trtllm_modules_to_hf_modules.keys())
        return self

+    @model_validator(mode="after")
+    def validate_peft_cache_config(self):
+        if self.peft_cache_config is not None and self.peft_cache_config.lora_prefetch_dir is not None:
+            raise ValueError(
+                f"lora_prefetch_dir was set to '{self.peft_cache_config.lora_prefetch_dir}' "
+                "but LoRA prefetch is not supported")
+        return self
+
    def _update_plugin_config(self, key: str, value: Any):
        setattr(self.build_config.plugin_config, key, value)

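Behavioral sketch with an illustrative path, not from the diff: a prefetch directory is now rejected at validation time instead of being silently ignored.

# Any BaseLlmArgs-derived config built with this peft_cache_config is now
# expected to raise ValueError("lora_prefetch_dir was set to '/tmp/lora' ...").
peft_cfg = PeftCacheConfig(lora_prefetch_dir="/tmp/lora")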