diff --git a/src/transformers/bart_mem_prof.py b/src/transformers/bart_mem_prof.py
new file mode 100644
index 000000000000..f3523f05d770
--- /dev/null
+++ b/src/transformers/bart_mem_prof.py
@@ -0,0 +1,67 @@
+from transformers import *
+import torch
+
+DEFAULT_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def runner(source_path, out_file, batch_size=8, device=DEFAULT_DEVICE, prof_generate=False):
+    tokenizer = BartTokenizer.from_pretrained('bart-large')
+    lns = [" " + x.rstrip() for x in open(source_path).readlines()][:batch_size]
+
+    dct = tokenizer.batch_encode_plus(lns, max_length=1024, return_tensors="pt", pad_to_max_length=True)
+    ids = dct['input_ids'].to(device)
+    msk = dct['attention_mask'].to(device)
+    model = BartForConditionalGeneration.from_pretrained('bart-large-cnn', output_past=prof_generate).to(device)
+    model.log_mem('starting')
+    if prof_generate:
+        summaries = model.generate(
+            input_ids=ids,
+            attention_mask=msk,
+            num_beams=4,
+            length_penalty=2.0,
+            max_length=140 + 2,  # +2 from original because we start at step=1 and stop before max_length
+            min_length=55 + 1,  # +1 from original because we start at step=1
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+            do_sample=False,
+            decoder_start_token_id=model.config.eos_token_ids[0],
+        )
+        model.log_mem('done')
+        dec = [tokenizer.decode(s) for s in summaries]
+        print(dec[0])
+    else:
+        # model.decoder.generation_mode = False
+        with torch.no_grad():
+            model(
+                input_ids=ids,
+                attention_mask=msk,
+            )
+
+    log_df = model.combine_logs()
+    log_df.to_csv(out_file)
+
+
+import argparse
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "output_path", type=str, help="where to save summaries",
+    )
+    parser.add_argument(
+        "--source_path", type=str, default="/home/shleifer/transformers_fork/notebooks/test.source",
+        help="like cnn_dm/test.source", required=False,
+    )
+    parser.add_argument(
+        "--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.",
+    )
+    parser.add_argument(
+        "--bs", type=int, default=8, required=False, help="batch size: how many to summarize at a time",
+    )
+    parser.add_argument(
+        "--do-generate", action='store_true', required=False, help="profile generate() instead of a plain forward pass",
+    )
+    args = parser.parse_args()
+    runner(args.source_path, args.output_path, batch_size=args.bs, device=args.device, prof_generate=args.do_generate)
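A rough usage sketch for this script (the import path transformers.bart_mem_prof, the output filename, and reading a cpu_mem column out of the CSV are assumptions based on the code above and on test_lm_uneven_forward further down, not verified output):

# command line: python src/transformers/bart_mem_prof.py fwd_logs.csv --source_path test.source --bs 8
# or in-process, assuming a source install so the new module is importable:
import pandas as pd
from transformers.bart_mem_prof import runner

runner("test.source", "fwd_logs.csv", batch_size=8, prof_generate=False)
log_df = pd.read_csv("fwd_logs.csv")  # one row per log_mem() checkpoint
print(log_df.cpu_mem.max() - log_df.cpu_mem.min())  # CPU memory growth across the forward pass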
diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py
index ac1764de8bd6..b8e59c5cddd1 100644
--- a/src/transformers/modeling_bart.py
+++ b/src/transformers/modeling_bart.py
@@ -25,7 +25,8 @@
 from .configuration_bart import BartConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
-
+from durbango.logging_utils import LoggingMixin
+from durbango.torch_utils import print_tensor_sizes, local_sizeof, get_tensor_shapes_and_pointers

 logger = logging.getLogger(__name__)
@@ -109,7 +110,7 @@ def _prepare_bart_decoder_inputs(
     return decoder_input_ids, decoder_attn_mask


-class PretrainedBartModel(PreTrainedModel):
+class PretrainedBartModel(PreTrainedModel, LoggingMixin):
     config_class = BartConfig
     base_model_prefix = "model"
     pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -185,9 +186,10 @@ def make_padding_mask(input_ids, padding_idx=1):


 # Helper Modules
+from durbango.torch_utils import get_shapes


-class EncoderLayer(nn.Module):
+class EncoderLayer(nn.Module, LoggingMixin):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -216,7 +218,8 @@ def forward(self, x, encoder_padding_mask):
             encoded output of shape `(seq_len, batch, embed_dim)`
         """
         residual = x
-        x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask,)
+        x, attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask, update_layer_state=False,)
+
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.self_attn_layer_norm(x)
@@ -230,8 +233,8 @@ def forward(self, x, encoder_padding_mask):
         x = self.final_layer_norm(x)
         return x, attn_weights

-
-class BartEncoder(nn.Module):
+import gc
+class BartEncoder(nn.Module, LoggingMixin):
     """
     Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer
     is a :class:`EncoderLayer`.
@@ -282,16 +285,19 @@ def forward(
             attention_mask = attention_mask.eq(0)

         inputs_embeds = self.embed_tokens(input_ids)
-        embed_pos = self.embed_positions(input_ids)
-        x = inputs_embeds + embed_pos
+        x = inputs_embeds + self.embed_positions(input_ids)
         x = self.layernorm_embedding(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
+        assert not (self.output_attentions or self.output_hidden_states)

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
-
+        self.log_mem('encoder: starting_loop')
         encoder_states, all_attentions = [], []
-        for encoder_layer in self.layers:
+        #rdd_start = print_tensor_sizes()
+        #rdd_start.to_csv(f'rdd_start.csv')
+        for i, encoder_layer in enumerate(self.layers):
+
             if self.output_hidden_states:
                 encoder_states.append(x)
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -299,19 +305,16 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 attn = None
             else:
-                x, attn = encoder_layer(x, attention_mask)
-
-            if self.output_attentions:
-                all_attentions.append(attn)
-
-        if self.output_hidden_states:
-            encoder_states.append(x)
-
+                x, _ = encoder_layer(x, attention_mask)
+            assert len(encoder_states) == 0
+            assert len(all_attentions) == 0
+            #self.log_mem(f'x: {x.shape}, attn: {attn.shape}')
+            self.log_mem(f'Encoder: called layer {i}')
         encoder_states = [hidden_state.transpose(0, 1) for hidden_state in encoder_states]

         return x, encoder_states, all_attentions


-class DecoderLayer(nn.Module):
+class DecoderLayer(nn.Module, LoggingMixin):
     def __init__(self, config: BartConfig):
         super().__init__()
         self.embed_dim = config.d_model
@@ -374,7 +377,7 @@ def forward(
         )  # just self_attn weights for now, following t5, layer_state = cache for decoding


-class BartDecoder(nn.Module):
+class BartDecoder(nn.Module, LoggingMixin):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer
     is a :class:`DecoderLayer`.
@@ -435,6 +438,8 @@ def forward(
             encoder_padding_mask = encoder_padding_mask.eq(0)

         # embed positions
+
+        self.log_mem('decoder: embedded positions')
         positions = self.embed_positions(input_ids, generation_mode=generation_mode)

         if generation_mode:
@@ -443,6 +448,7 @@ def forward(
         assert input_ids.ne(self.padding_idx).any()

         x = self.embed_tokens(input_ids)
+        self.log_mem('decoder: embedded tokens')
         x += positions
         x = self.layernorm_embedding(x)
@@ -464,6 +470,7 @@ def forward(
             x, layer_self_attn, layer_past = decoder_layer(
                 x, encoder_hidden_states, encoder_padding_mask, layer_state=layer_state, attention_mask=combined_mask,
             )
+            self.log_mem(f'decoder: called attn {i}')

             if self.output_past:
                 next_decoder_cache.append(layer_past.copy())
@@ -483,6 +490,7 @@ def forward(

         return x, next_cache, all_hidden_states, list(all_self_attns)


+
 def _reorder_buffer(attn_cache, new_order):
     for k, input_buffer_k in attn_cache.items():
         if input_buffer_k is not None:
@@ -490,8 +498,8 @@ def _reorder_buffer(attn_cache, new_order):
     return attn_cache


-class SelfAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
+class SelfAttention(nn.Module, LoggingMixin):
+    """Multi-headed attention from 'Attention Is All You Need'"""

     def __init__(
         self,
@@ -519,11 +527,16 @@ def __init__(
     def _shape(self, tensor, dim_0, bsz):
         return tensor.contiguous().view(dim_0, bsz * self.num_heads, self.head_dim).transpose(0, 1)

+
+    def log_mem(self, msg='', verbose=False):
+        super().log_mem(msg=f'{self.cache_key}_attn:{msg}', verbose=verbose)
+
     def forward(
         self,
         query,
         key: Optional[Tensor],
         key_padding_mask: Optional[Tensor] = None,
+        update_layer_state=True,
         layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
         attn_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -544,6 +557,7 @@ def forward(
             layer_state = {}

         q = self.q_proj(query) * self.scaling
+        self.log_mem('\tq_proj')
         if static_kv:
             if key is None:
                 k = v = None
@@ -554,29 +568,39 @@ def forward(
             k = self.k_proj(query)
             v = self.v_proj(query)

+
         q = self._shape(q, tgt_len, bsz)
+        self.log_mem(f'\tq_reshape -> {q.shape}')
         if k is not None:
             k = self._shape(k, -1, bsz)
+            self.log_mem(f'\t done reshaping k,v ->, {k.shape}')
         if v is not None:
             v = self._shape(v, -1, bsz)
+
         if saved_state is not None:
+            self.log_mem('\t about to use saved_state')
             k, v, key_padding_mask = self._use_saved_state(k, v, saved_state, key_padding_mask, static_kv, bsz)

         # Update cache
-        layer_state[self.cache_key] = {
-            "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
-            "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
-            "prev_key_padding_mask": key_padding_mask if not static_kv else None,
-        }
+        if update_layer_state:
+            layer_state[self.cache_key] = {
+                "prev_key": k.view(bsz, self.num_heads, -1, self.head_dim),
+                "prev_value": v.view(bsz, self.num_heads, -1, self.head_dim),
+                "prev_key_padding_mask": key_padding_mask if not static_kv else None,
+            }
+        self.log_mem('\t attn: done layer_state')

         assert k is not None
         src_len = k.size(1)
+        self.log_mem('\t attn: before BMM(q,k)')
         attn_weights = torch.bmm(q, k.transpose(1, 2))
+        self.log_mem('\t attn: done BMM(q,k)')
         assert attn_weights.size() == (bsz * self.num_heads, tgt_len, src_len)

         if attn_mask is not None:
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_mask
+            self.log_mem('\t attn: done causal mask')
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         # This is part of a workaround to get around fork/join parallelism not supporting Optional types.
@@ -584,21 +608,26 @@ def forward(
             key_padding_mask = None
         assert key_padding_mask is None or key_padding_mask.size()[:2] == (bsz, src_len,)

-        if key_padding_mask is not None:  # don't attend to padding symbols
+        if key_padding_mask is not None:  # shape (bsz, src_len)
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
-            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
+            attn_weights = attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
+            self.log_mem('\t attn: done masked_fill')
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
+        self.log_mem('\t attn: done softmax')
         attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
+
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
+        self.log_mem('\t attn: done BMM(probs, v)')
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
         attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+        self.log_mem('\t attn: done view(output)')
         attn_output = self.out_proj(attn_output)
-        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-        return attn_output, attn_weights
+        self.log_mem('\t attn: done out_proj')
+        #attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+        return attn_output, None

     def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):
         # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
@@ -655,7 +684,7 @@ def _cat_prev_key_padding_mask(
         return new_key_padding_mask


-class BartClassificationHead(nn.Module):
+class BartClassificationHead(nn.Module, LoggingMixin):
     """Head for sentence-level classification tasks."""

     # This can trivially be shared with RobertaClassificationHead
@@ -727,6 +756,8 @@ def _filter_out_falsey_values(tup) -> Tuple:


 # Public API
+import time
+import pandas as pd
 @add_start_docstrings(
     "The bare BART Model outputting raw hidden-states without any specific head on top.", BART_START_DOCSTRING,
 )
@@ -743,6 +774,7 @@ def __init__(self, config: BartConfig):

         self.encoder = BartEncoder(config, self.shared)
         self.decoder = BartDecoder(config, self.shared)
+
         self.init_weights()

     @add_start_docstrings_to_callable(BART_INPUTS_DOCSTRING)
@@ -758,6 +790,9 @@ def forward(
     ):

         # make masks if user doesn't supply
+        if encoder_outputs is None:
+            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+        assert isinstance(encoder_outputs, tuple)
         if not generation_mode:
             decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
                 self.config,
@@ -767,9 +802,6 @@ def forward(
                 mask_dtype=self.shared.weight.dtype,
             )
         assert decoder_input_ids is not None
-        if encoder_outputs is None:
-            encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
-        assert isinstance(encoder_outputs, tuple)
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         decoder_outputs = self.decoder(
             decoder_input_ids,
@@ -804,10 +836,10 @@ class BartForConditionalGeneration(PretrainedBartModel):

     def __init__(self, config: BartConfig):
         super().__init__(config)
-        # if base_model is None:
-        base_model = BartModel(config)
-        self.model = base_model
-        self.lm_head = _make_linear_from_emb(self.model.shared)
+        # if base_model is None:
+        #self.log_mem('pre-init')
+        self.model = BartModel(config)
+        #self.lm_head = _make_linear_from_emb(self.model.shared)

     def tie_weights(self):
         pass  # hack to prevent changing lm_head.out_features. The input and output embeddings are still the same.

@@ -866,6 +898,7 @@ def forward(
             tokenizer.decode(predictions).split()
             # ['good', 'great', 'all', 'really', 'very']
         """
+        self.model.log_mem('before BartModel.forward')
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -875,7 +908,10 @@ def forward(
             decoder_cached_states=decoder_cached_states,
             generation_mode=generation_mode,
         )
-        lm_logits = self.lm_head(outputs[0])
+        self.model.log_mem('after call, before lm_head')
+        lm_logits = F.linear(outputs[0], self.model.shared.weight)
+        #lm_logits = self.lm_head(outputs[0])
+        self.model.log_mem('after lm_head')
         outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             loss_fct = nn.CrossEntropyLoss()
@@ -885,6 +921,7 @@ def forward(

         return outputs

+
     def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
         assert past is not None, "past has to be defined for encoder_outputs"

@@ -932,7 +969,7 @@ def get_encoder(self):
         return self.model.encoder

     def get_output_embeddings(self):
-        return self.lm_head
+        return _make_linear_from_emb(self.model.shared)  # make it on the fly

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 808c16009477..99586b066b41 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -896,6 +896,16 @@ def generate(
             effective_batch_size = batch_size
             effective_batch_mult = 1

+        if self.config.is_encoder_decoder:
+            assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
+            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
+            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
+
+            # get encoder and store encoder outputs
+            encoder = self.get_encoder()
+
+            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
+
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = input_ids.shape[-1]
@@ -911,16 +921,8 @@ def generate(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

-        if self.config.is_encoder_decoder:
-            assert bos_token_id is not None, "Encoder Decoder Models need to have a bos_token_id"
-            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
-            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
-
-            # get encoder and store encoder outputs
-            encoder = self.get_encoder()
-
-            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
+        if self.config.is_encoder_decoder:

             # create empty decoder_input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
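The log_mem, reset_logs, combine_logs, save_logs, and summary calls used throughout this patch come from durbango, the author's personal utility library, which is not part of this repository. A minimal sketch of the interface the patch appears to assume (the psutil/torch.cuda counters and the cpu_mem/gpu_mem column names are guesses, not durbango's actual implementation):

import time

import pandas as pd
import psutil
import torch


class LoggingMixin:
    """Hypothetical stand-in for durbango.logging_utils.LoggingMixin."""

    def reset_logs(self):
        self.logs = []

    def log_mem(self, msg="", verbose=False):
        # record CPU RSS and (if available) CUDA allocated bytes at a named checkpoint
        if not hasattr(self, "logs"):
            self.reset_logs()
        entry = {
            "time": time.time(),
            "msg": msg,
            "cpu_mem": psutil.Process().memory_info().rss,
            "gpu_mem": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0,
        }
        self.logs.append(entry)
        if verbose:
            print(entry)

    def combine_logs(self):
        # gather checkpoints from this module and any submodules that also log
        rows = list(getattr(self, "logs", []))
        if hasattr(self, "modules"):  # nn.Module subclasses
            for child in self.modules():
                if child is not self and isinstance(child, LoggingMixin):
                    rows.extend(getattr(child, "logs", []))
        return pd.DataFrame(rows)

    def save_logs(self, path):
        self.combine_logs().to_csv(path)

    @property
    def summary(self):
        df = self.combine_logs()
        return "no logs" if df.empty else f"peak cpu_mem: {df.cpu_mem.max()}"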
[" " + x.rstrip() for x in open(source_path).readlines()][:6] + tokenizer = BartTokenizer.from_pretrained('bart-large') + dct = tokenizer.batch_encode_plus(cls.lns, max_length=1024, return_tensors="pt", pad_to_max_length=True) + cls.ids = dct['input_ids'].to(DEFAULT_DEVICE) + cls.prev_output_tokens = shift_tokens_right(cls.ids, 1).to(DEFAULT_DEVICE) + cls.model = BartModel.from_pretrained('bart-large').to(DEFAULT_DEVICE) + #cls.lns = pickle_load('/Users/shleifer/transformers_fork/lns.pkl') + return cls + + def test_hf_fwd_batch(self): + bart = self.model + bart.reset_logs() + with torch.no_grad(): + bart(self.ids) + try: + log_df = bart.combine_logs() + #log_df.to_csv('hf_batch_fwd_logs.csv') + bart.save_logs('hf_batch_fwd_logs.txt') + print(bart.summary) + except AttributeError as e: + print(e) + diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index d064f0f780e8..6c33cdc9a55a 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device + if is_torch_available(): import torch from transformers import ( @@ -242,6 +243,7 @@ def test_lm_forward(self): self.assertEqual(logits.shape, expected_shape) self.assertIsInstance(loss.item(), float) + def test_lm_uneven_forward(self): config = BartConfig( vocab_size=self.vocab_size, @@ -255,11 +257,12 @@ def test_lm_uneven_forward(self): max_position_embeddings=48, ) lm_model = BartForConditionalGeneration(config).to(torch_device) + lm_model.log_mem('starting') context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device) summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long().to(torch_device) loss, logits, enc_features = lm_model(input_ids=context, decoder_input_ids=summary, lm_labels=summary) - expected_shape = (*summary.shape, config.vocab_size) - self.assertEqual(logits.shape, expected_shape) + log_df = lm_model.combine_logs() + tot = log_df.cpu_mem.max()-log_df.cpu_mem.min() def test_generate_beam_search(self): input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long().to(torch_device) @@ -282,6 +285,7 @@ def test_generate_beam_search(self): lm_model.eval() max_length = 5 + new_input_ids = lm_model.generate( input_ids.clone(), do_sample=True, @@ -294,6 +298,7 @@ def test_generate_beam_search(self): # TODO(SS): uneven length batches, empty inputs def test_shift_tokens_right(self): + input_ids = torch.Tensor([[71, 82, 18, 33, 2, 1, 1], [68, 34, 26, 58, 30, 82, 2]]).long() shifted = shift_tokens_right(input_ids, 1) n_pad_before = input_ids.eq(1).float().sum() @@ -319,7 +324,12 @@ def test_generate_fp16(self): config, input_ids, batch_size = self._get_config_and_data(output_past=True) attention_mask = input_ids.ne(1).to(torch_device) model = BartForConditionalGeneration(config).eval().to(torch_device).half() + #trace = start_memory_tracing(modules_to_trace="transformers") model.generate(input_ids, attention_mask=attention_mask, do_sample=False, early_stopping=True) + #summary = MemoryViewer(stop_memory_tracing(trace)) + #summary.save_line_by_line('hf_mem_half_gen.txt') + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_base_model_fp16(self): @@ -415,14 +425,13 @@ def test_mnli_inference(self): example_b = [0, 31414, 232, 328, 740, 1140, 69, 46078, 1588, 2, 1] input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2], example_b]) - - model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli").to( - 
-            torch_device
-        )  # eval called in from_pre
+        model = AutoModelForSequenceClassification.from_pretrained("bart-large-mnli").to(torch_device)
         inputs_dict = prepare_bart_inputs_dict(model.config, input_ids)
         # Test that model hasn't changed
+        #trace = start_memory_tracing(modules_to_trace="transformers")
+
         with torch.no_grad():
-            batched_logits, features = model(**inputs_dict)
+            batched_logits, features = model.forward(**inputs_dict)
         expected_shape = torch.Size((2, 3))
         self.assertEqual(batched_logits.shape, expected_shape)
         expected_slice = torch.Tensor([[0.1907, 1.4342, -1.0289]]).to(torch_device)
@@ -445,7 +454,37 @@ def test_model_from_pretrained(self):
         self.assertIsNotNone(model)

     @slow
-    def test_cnn_summarization_same_as_fairseq(self):
+    def test_compare_generation_mem(self):
+        hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
+        if torch_device == 'cuda':
+            hf = hf.half()
+        tok = BartTokenizer.from_pretrained("bart-large")
+        text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
+        tokens = tok.encode(text, return_tensors="pt").to(torch_device)
+
+    @slow
+    def test_cnn_easy_summarization_same_as_fairseq(self):
+        hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
+        if torch_device == 'cuda':
+            hf = hf.half()
+        tok = BartTokenizer.from_pretrained("bart-large")
+        text = " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian"
+        tokens = tok.encode(text, return_tensors="pt").to(torch_device)
+        extra_len = 20
+
+        gen_tokens = hf.generate(
+            tokens,
+            num_beams=4,
+            max_length=extra_len + 2,
+            do_sample=False,
+            decoder_start_token_id=hf.config.eos_token_ids[0],
+        )  # repetition_penalty=10.,
+        expected_result = "The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
+        generated = [tok.decode(g,) for g in gen_tokens]
+        self.assertEqual(expected_result, generated[0])
+
+    @slow
+    def test_cnn_summarization_same_as_fairseq_hard(self):
         hf = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(torch_device)
         tok = BartTokenizer.from_pretrained("bart-large")