@@ -1199,7 +1199,7 @@ def __init__(
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype
+        self._dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
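The hunk above renames StaticCache's public `dtype` attribute to the private `self._dtype`. Not part of this diff, but as an illustration of how such a rename can stay backward compatible, a deprecated read-only `dtype` property could forward to the private attribute (the class name and warning text below are hypothetical):

    # Hypothetical illustration -- not part of this diff.
    import warnings
    import torch

    class _CacheWithPrivateDtype:
        def __init__(self, dtype: torch.dtype = torch.float32):
            self._dtype = dtype  # storage moved to a private attribute, as in the hunk above

        @property
        def dtype(self) -> torch.dtype:
            # Old external reads still resolve, but emit a deprecation warning.
            warnings.warn("`cache.dtype` is deprecated; it is now an internal attribute.", FutureWarning)
            return self._dtype

    cache = _CacheWithPrivateDtype(torch.float16)
    _ = cache.dtype  # still works, but warns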
@@ -1216,8 +1216,8 @@ def __init__(
                 layer_device = layer_device_map[idx]
             else:
                 layer_device = device
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
             # preventing compiled graph breaks when updating the cache.
             torch._dynamo.mark_static_address(new_layer_key_cache)
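This hunk keeps the existing allocation pattern: each layer's key/value buffer is created once at the full `cache_shape` and tagged with `torch._dynamo.mark_static_address` so compiled graphs can rely on a fixed data pointer. A minimal sketch of that pattern with made-up shapes (not taken from the diff):

    # Minimal sketch of the static-allocation pattern; shapes are illustrative.
    import torch

    batch, num_kv_heads, max_cache_len, head_dim = 1, 8, 128, 64
    cache_shape = (batch, num_kv_heads, max_cache_len, head_dim)

    key_cache = torch.zeros(cache_shape, dtype=torch.float16)
    value_cache = torch.zeros(cache_shape, dtype=torch.float16)

    # Tag the buffers as fixed data pointers so torch.compile / CUDA graphs can
    # specialize on their addresses and in-place updates do not force recompiles.
    torch._dynamo.mark_static_address(key_cache)
    torch._dynamo.mark_static_address(value_cache)

    # Updates then write into the pre-allocated buffers in place, e.g. with index_copy_:
    new_keys = torch.randn(batch, num_kv_heads, 4, head_dim, dtype=torch.float16)
    positions = torch.arange(4)
    key_cache.index_copy_(2, positions, new_keys)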
@@ -1680,7 +1680,7 @@ def __init__(
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype
+        self._dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
@@ -1707,8 +1707,8 @@ def __init__(
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
             cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
             torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
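In the HybridCache hunk above, `cache_shape` is chosen per layer from `global_cache_shape` or `sliding_cache_shape`. As an illustration only (all values invented), the two shapes typically differ just in the sequence dimension, which sliding-window layers cap at the window size:

    # Illustrative only -- how a global vs. sliding-window cache shape might differ.
    batch_size, num_kv_heads, head_dim = 1, 8, 64
    max_cache_len, sliding_window = 4096, 1024

    global_cache_shape = (batch_size, num_kv_heads, max_cache_len, head_dim)
    # Sliding layers only ever need to hold the last `sliding_window` positions.
    sliding_cache_shape = (batch_size, num_kv_heads, min(sliding_window, max_cache_len), head_dim)

    is_sliding = [True, False, True, False]  # e.g. alternating sliding / global attention layers
    layer_cache_shapes = [
        global_cache_shape if not is_sliding[i] else sliding_cache_shape
        for i in range(len(is_sliding))
    ]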
@@ -1853,8 +1853,8 @@ def __init__(
         dtype: torch.dtype = torch.float16,
         device: Union[torch.device, str, None] = None,
     ):
-        self.dtype = dtype
         self.max_batch_size = max_batch_size
+        self._dtype = dtype
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
@@ -1868,14 +1868,14 @@ def __init__(
                 self.intermediate_size,
                 self.conv_kernel_size,
                 device=device,
-                dtype=dtype,
+                dtype=self._dtype,
             )
             ssm_state: torch.Tensor = torch.zeros(
                 self.max_batch_size,
                 self.intermediate_size,
                 self.ssm_state_size,
                 device=device,
-                dtype=dtype,
+                dtype=self._dtype,
             )

             torch._dynamo.mark_static_address(conv_state)
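In MambaCache, the per-layer conv and ssm state buffers now take their dtype from `self._dtype` instead of the local `dtype` argument; both come from the same constructor parameter, so behaviour is unchanged. A usage sketch, assuming the constructor shown above and that `MambaCache` (in `transformers.cache_utils`) and `MambaConfig` are importable and expose `conv_states` / `ssm_states` lists (not verified against this exact revision):

    # Usage sketch; the imports and attribute names are assumptions about this revision.
    import torch
    from transformers import MambaConfig
    from transformers.cache_utils import MambaCache

    config = MambaConfig()  # default state-space sizes
    cache = MambaCache(config, max_batch_size=2, dtype=torch.bfloat16, device="cpu")

    # Both state buffers are pre-allocated in the requested dtype.
    assert cache.conv_states[0].dtype == torch.bfloat16
    assert cache.ssm_states[0].dtype == torch.bfloat16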
@@ -1972,7 +1972,7 @@ def __init__(
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
-        self.dtype = dtype if dtype is not None else torch.float32
+        self._dtype = dtype if dtype is not None else torch.float32

         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -2144,8 +2144,8 @@ def _create_key_value_cache_tensors(

         is_cpu_device = device == torch.device("cpu")

-        key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
-        value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
+        key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
+        value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)

         # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
         # preventing compiled graph breaks when updating the cache.
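The offloaded cache allocates its CPU-side key/value tensors with `pin_memory=True` so later host-to-device copies can run asynchronously. A rough sketch of why the pinned allocation matters (shapes illustrative; the CUDA path only runs when a GPU is available):

    # Sketch of the offloading pattern the pinned allocation enables.
    import torch

    shape = (1, 8, 128, 64)
    is_cpu_device = True  # the offload buffers live on CPU, as in the hunk above

    # Pinned (page-locked) host memory lets async copies overlap with compute.
    key_cache = torch.zeros(shape, dtype=torch.float16, device="cpu", pin_memory=is_cpu_device)

    if torch.cuda.is_available():
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            # non_blocking copies only truly overlap when the source tensor is pinned.
            key_on_gpu = key_cache.to("cuda", non_blocking=True)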