Skip to content

Commit

Permalink
Style
Browse files Browse the repository at this point in the history
Signed-off-by: MaximumEntropy <[email protected]>
  • Loading branch information
MaximumEntropy committed Mar 19, 2022
1 parent 8968a5b commit 5816680
Showing 1 changed file with 10 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import random
from typing import Optional

import torch
import numpy as np
import random
import itertools
from sacrebleu import corpus_bleu
import torch
from omegaconf.dictconfig import DictConfig
from omegaconf.listconfig import ListConfig
from pytorch_lightning.trainer.trainer import Trainer
from sacrebleu import corpus_bleu

from nemo.collections.nlp.models.language_modeling.megatron_lm_encoder_decoder_model import (
MegatronLMEncoderDecoderModel,
Expand Down Expand Up @@ -278,8 +278,9 @@ def eval_epoch_end(self, outputs, mode):
tr_gt_inp = [None for _ in range(parallel_state.get_data_parallel_world_size())]
# we also need to drop pairs where ground truth is an empty string
torch.distributed.all_gather_object(
tr_gt_inp, [(t, g, i) for (t, g, i) in zip(translations, ground_truths, inputs)],
group=parallel_state.get_data_parallel_group()
tr_gt_inp,
[(t, g, i) for (t, g, i) in zip(translations, ground_truths, inputs)],
group=parallel_state.get_data_parallel_group(),
)
if parallel_state.get_data_parallel_rank() == 0:
_translations = []
Expand All @@ -300,15 +301,9 @@ def eval_epoch_end(self, outputs, mode):
bleu_score = sacre_bleu.score * parallel_state.get_data_parallel_world_size()

dataset_name = "Validation" if mode == 'val' else "Test"
logging.info(
f"{dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(_translations)}"
)
logging.info(
f"{dataset_name}, Dataloader index: {dataloader_idx}, SacreBLEU = {bleu_score}"
)
logging.info(
f"{dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:"
)
logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Set size: {len(_translations)}")
logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, SacreBLEU = {bleu_score}")
logging.info(f"{dataset_name}, Dataloader index: {dataloader_idx}, Translation Examples:")
logging.info('============================================================')
for example_idx in range(0, 3):
random_index = random.randint(0, len(_translations) - 1)
Expand Down

0 comments on commit 5816680

Please sign in to comment.