@@ -994,6 +994,8 @@ class DynamicCache(Cache):
     ```
     """
 
+    _export_registered = False
+
     def __init__(
         self,
         ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
@@ -1047,6 +1049,34 @@ def __init__(
         else:
             super().__init__(layers=layers)
 
+        self._register_export_support()
+
+    @classmethod
+    def _register_export_support(cls):
+        """
+        Utilities for `DynamicCache` <> torch.export support.
+        """
+        if cls._export_registered:
+            return
+
+        # Pytree registration causes a memory leak for FSDP runs, see here: https://github.com/huggingface/transformers/issues/39795
+        if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
+            torch.utils._pytree.register_pytree_node(
+                DynamicCache,
+                lambda dynamic_cache: torch.utils._pytree._dict_flatten(cls._get_cache_dict(dynamic_cache)),
+                cls._unflatten_dynamic_cache,
+                serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+                flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
+                    cls._get_cache_dict(dynamic_cache)
+                ),
+            )
+            # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(cls._get_cache_dict(cache), spec)
+            )
+
+        cls._export_registered = True
+
     def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
@@ -1070,12 +1100,9 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tens
                 cache.update(key_states, value_states, layer_idx)
         return cache
 
-
-# Utilities for `DynamicCache` <> torch.export support
-# Pytree registration is not supported for FSDP runs, see here: https://github.com/huggingface/transformers/issues/39795
-if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
-
-    def _get_cache_dict(cache: DynamicCache):
+    @staticmethod
+    def _get_cache_dict(cache):
+        """Convert cache to dictionary format for pytree operations."""
         if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
             raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
 
@@ -1089,12 +1116,10 @@ def _get_cache_dict(cache: DynamicCache):
10891116 "value_cache" : [layer .values for layer in cache .layers if layer .values is not None ],
10901117 }
10911118
1092- def _unflatten_dynamic_cache (
1093- values ,
1094- context : torch .utils ._pytree .Context ,
1095- ):
1119+ @classmethod
1120+ def _unflatten_dynamic_cache (cls , values , context : torch .utils ._pytree .Context ):
10961121 dictionary = torch .utils ._pytree ._dict_unflatten (values , context )
1097- cache = DynamicCache ()
1122+ cache = cls ()
10981123 # Reconstruct layers from keys and values lists
10991124 key_list = dictionary .get ("key_cache" , [])
11001125 value_list = dictionary .get ("value_cache" , [])
@@ -1104,20 +1129,6 @@ def _unflatten_dynamic_cache(
             cache.update(key, value, idx)
         return cache
 
-    torch.utils._pytree.register_pytree_node(
-        DynamicCache,
-        lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
-        _unflatten_dynamic_cache,
-        serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
-        flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
-            _get_cache_dict(dynamic_cache)
-        ),
-    )
-    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
-    torch.fx._pytree.register_pytree_flatten_spec(
-        DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
-    )
-
 
 class OffloadedCache(Cache):
     """