Merged
51 commits
d29881f
Cleanup some code
thomasw21 Jul 28, 2022
baa5d87
make style
thomasw21 Jul 28, 2022
69227ae
Woops
thomasw21 Jul 28, 2022
0078a6c
Woops
thomasw21 Jul 28, 2022
ed09d70
Improve signatures
thomasw21 Jul 28, 2022
1a8b80b
WIP
thomasw21 Jul 28, 2022
62adfce
Try to reduce the number of reshape/copies
thomasw21 Jul 28, 2022
ec7442c
Woops
thomasw21 Jul 28, 2022
298e3fd
I don't think we actually need the layer_num scaling trick
thomasw21 Jul 28, 2022
42e5954
Woops
thomasw21 Jul 28, 2022
96307a4
Woops
thomasw21 Jul 28, 2022
9b2c1ca
No need for duplication
thomasw21 Jul 28, 2022
a8cde02
Try to fix beam_search
thomasw21 Jul 28, 2022
3a095d0
Fix beam search
thomasw21 Jul 28, 2022
0ffecbf
Woops
thomasw21 Jul 28, 2022
d40ee96
Woops
thomasw21 Jul 28, 2022
2677a28
Removing layer num normalization seems to be breaking
thomasw21 Jul 28, 2022
ddbe33e
Nit
thomasw21 Jul 28, 2022
77f19b3
Not sure self.layer_number normalization actually matters
thomasw21 Jul 28, 2022
5ed059c
make style
thomasw21 Jul 28, 2022
6c3bf96
Try and be backward compatible
thomasw21 Jul 28, 2022
995d31a
Woops
thomasw21 Jul 28, 2022
7795e23
Try to fix beam_search
thomasw21 Jul 28, 2022
4596be9
Woops
thomasw21 Jul 28, 2022
47e1969
Revert attempt to be backward compatible
thomasw21 Jul 28, 2022
ad1bfe9
Nits
thomasw21 Jul 28, 2022
02bf51d
Woops
thomasw21 Jul 28, 2022
c02f409
I don't like kwargs
thomasw21 Jul 28, 2022
47608f8
Woops
thomasw21 Jul 28, 2022
69663c5
Improve documentation on past_key_values format
thomasw21 Jul 29, 2022
b4346e1
make style
thomasw21 Jul 29, 2022
4db82d4
Optimize the device allocation in case of hidden_states in multiple d…
thomasw21 Jul 29, 2022
323c073
No need to manually cast the values to a specific device
thomasw21 Jul 29, 2022
2d58b7e
Rename with long version of variables
thomasw21 Jul 29, 2022
8699509
Improve type hinting
thomasw21 Jul 29, 2022
98fdf99
Add comment that explains that some methods return views
thomasw21 Jul 29, 2022
1c93638
Make style
thomasw21 Jul 29, 2022
5fcc118
Woops
thomasw21 Jul 29, 2022
49aff18
Actually i think the attention casting only makes sense when we use t…
thomasw21 Jul 29, 2022
88814bf
We don't actually need layer_number to be passed anymore
thomasw21 Jul 29, 2022
87861b2
Merge remote-tracking branch 'origin/main' into thomas/bloom_clean_code
thomasw21 Jul 29, 2022
a3d50c0
Fix FX test
thomasw21 Jul 29, 2022
1e4ccaf
Revert "Fix FX test"
thomasw21 Jul 29, 2022
790e238
Does this help?
thomasw21 Jul 29, 2022
b58a6f7
Try passing a tuple
thomasw21 Jul 29, 2022
fd117da
Bypass torch.baddbmm
thomasw21 Jul 29, 2022
bd1ae60
Apply suggestions from code review
thomasw21 Aug 1, 2022
7c399e6
Add comment about support for torchScript v1.11
thomasw21 Aug 1, 2022
50e1f2f
Add back layer_number normalization
thomasw21 Aug 1, 2022
8f4d603
Revert "Add back layer_number normalization"
thomasw21 Aug 1, 2022
213ec2c
fix ONNX support for bloom (#18456)
NouamaneTazi Aug 4, 2022
src/transformers/models/bloom/configuration_bloom.py (15 changes: 10 additions and 5 deletions)
@@ -214,14 +214,19 @@ def generate_dummy_inputs(
         batch, seqlen = common_inputs["input_ids"].shape
         # Not using the same length for past_key_values
         past_key_values_length = seqlen + 2
-        past_shape = (
-            batch,
+        head_dim = self._config.hidden_size // self.num_attention_heads
+        past_key_shape = (
+            batch * self.num_attention_heads,
+            head_dim,
             past_key_values_length,
-            self.num_attention_heads,
-            self._config.hidden_size // self.num_attention_heads,
         )
+        past_value_shape = (
+            batch * self.num_attention_heads,
+            past_key_values_length,
+            head_dim,
+        )
         ordered_inputs["past_key_values"] = [
-            (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
+            (torch.zeros(past_key_shape), torch.zeros(past_value_shape)) for _ in range(self.num_layers)
         ]

         ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
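For readers skimming the diff: the dummy cache is no longer a single past_shape reused for both tensors. Keys are now laid out as (batch * num_attention_heads, head_dim, past_key_values_length) and values as (batch * num_attention_heads, past_key_values_length, head_dim), i.e. the key is stored pre-transposed, presumably so the batched matmul that computes attention scores can consume it directly. The sketch below only illustrates how those shapes fit together; the concrete sizes and the plain torch.bmm calls are illustrative assumptions, not code from this PR.

```python
import torch

# Illustrative sizes only (not taken from the PR).
batch, num_attention_heads, head_dim = 2, 8, 64
past_key_values_length = 10  # the ONNX dummy-input helper uses seqlen + 2

# Per-layer dummy cache entry, mirroring the shapes introduced in this diff:
# the key is (batch * heads, head_dim, past_len),
# the value is (batch * heads, past_len, head_dim).
past_key = torch.zeros(batch * num_attention_heads, head_dim, past_key_values_length)
past_value = torch.zeros(batch * num_attention_heads, past_key_values_length, head_dim)

# With this layout, a query of shape (batch * heads, q_len, head_dim) multiplies
# against the cached key without any extra transpose:
query = torch.zeros(batch * num_attention_heads, 1, head_dim)
scores = torch.bmm(query, past_key)                       # (batch * heads, 1, past_len)
context = torch.bmm(scores.softmax(dim=-1), past_value)   # (batch * heads, 1, head_dim)
print(scores.shape, context.shape)
```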