Commit bab605d

[Cache] rename dtype attribute 🚨 🚨 (#37044)
* yoink
* same pattern in all cache
1 parent 9fd9476 commit bab605d
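
Note: the 🚨 🚨 markers flag a breaking rename. The cache classes in `cache_utils.py` stop exposing a public `dtype` attribute and keep it on a private `_dtype` attribute instead, and all internal buffer allocations now read from `_dtype`. Below is a minimal, hypothetical sketch of the pattern every hunk applies; the class name, shapes, and defaults are illustrative only and are not the transformers API.

import torch

class CacheSketch:
    """Illustrative stand-in mirroring the rename pattern in the diff below."""

    def __init__(self, cache_shape=(1, 1, 8, 64), dtype: torch.dtype = torch.float32, device="cpu"):
        # Before this commit the dtype lived on a public attribute: self.dtype = dtype
        # After this commit it is stored on a private attribute instead:
        self._dtype = dtype
        self.key_cache, self.value_cache = [], []

        # Buffers are allocated from the private attribute, as in the diff.
        new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
        new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
        self.key_cache.append(new_layer_key_cache)
        self.value_cache.append(new_layer_value_cache)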

File tree

1 file changed, +12 −12 lines changed


src/transformers/cache_utils.py

+12-12
@@ -1199,7 +1199,7 @@ def __init__(
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )
 
-        self.dtype = dtype
+        self._dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1216,8 +1216,8 @@ def __init__(
                 layer_device = layer_device_map[idx]
             else:
                 layer_device = device
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
             # preventing compiled graph breaks when updating the cache.
             torch._dynamo.mark_static_address(new_layer_key_cache)
@@ -1680,7 +1680,7 @@ def __init__(
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )
 
-        self.dtype = dtype
+        self._dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
@@ -1707,8 +1707,8 @@ def __init__(
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache.
             cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
             torch._dynamo.mark_static_address(new_layer_key_cache)
             torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
@@ -1853,8 +1853,8 @@ def __init__(
         dtype: torch.dtype = torch.float16,
         device: Union[torch.device, str, None] = None,
     ):
-        self.dtype = dtype
         self.max_batch_size = max_batch_size
+        self._dtype = dtype
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
@@ -1868,14 +1868,14 @@ def __init__(
             self.intermediate_size,
             self.conv_kernel_size,
             device=device,
-            dtype=dtype,
+            dtype=self._dtype,
         )
         ssm_state: torch.Tensor = torch.zeros(
             self.max_batch_size,
             self.intermediate_size,
             self.ssm_state_size,
             device=device,
-            dtype=dtype,
+            dtype=self._dtype,
         )
 
         torch._dynamo.mark_static_address(conv_state)
@@ -1972,7 +1972,7 @@ def __init__(
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
-        self.dtype = dtype if dtype is not None else torch.float32
+        self._dtype = dtype if dtype is not None else torch.float32
 
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -2144,8 +2144,8 @@ def _create_key_value_cache_tensors(
 
         is_cpu_device = device == torch.device("cpu")
 
-        key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
-        value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
+        key_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
+        value_cache = torch.zeros(shape, dtype=self._dtype, device=device, pin_memory=is_cpu_device)
 
         # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
         # preventing compiled graph breaks when updating the cache.
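
Downstream code that read the public attribute will raise AttributeError once this lands. A hedged sketch of one way a caller could adapt follows; the helper is hypothetical and not part of transformers, and only `key_cache` and `_dtype` come from the diff above.

import torch

def resolve_cache_dtype(cache, default: torch.dtype = torch.float32) -> torch.dtype:
    # Hypothetical helper: prefer reading the dtype off the allocated key cache
    # tensors rather than the attribute this commit renames.
    key_cache = getattr(cache, "key_cache", None)
    if key_cache:
        return key_cache[0].dtype
    # Fall back to the private attribute introduced here, then to a default.
    return getattr(cache, "_dtype", default)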
