diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
index fe339c6f9a8b..1818d33dc0d3 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
@@ -268,7 +268,7 @@ def __init__(
         # (@adithyare) the persistent=False will not pollute the indices into the state_dict of this module.
         self.register_buffer("indices", torch.LongTensor(list(range(self.virtual_tokens))), persistent=False)
         self.embedding = torch.nn.Embedding(self.virtual_tokens, self.embedding_dim)
-        self.inference_table = InferenceTable("taskname", self.embedding_dim, self.virtual_tokens)
+        self.inference_table = InferenceTable("taskname", self.output_dim, self.virtual_tokens)
         self.first = ColumnParallelLinear(
             self.embedding_dim,
             self.bottleneck_dim,
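
The diff sizes the inference table with `self.output_dim` instead of `self.embedding_dim`, which matters whenever the two differ: the table caches what the adapter *produces*, not what it embeds. The following is a minimal sketch, not the NeMo implementation, assuming the table caches the projected prompt embeddings; `TinyPromptEncoder` and its parameters are simplified stand-ins for illustration only.

```python
# Minimal sketch (hypothetical, not NeMo code): the cached table must be as
# wide as the adapter's *output*, not its input embedding width.
import torch
import torch.nn as nn


class TinyPromptEncoder(nn.Module):
    def __init__(self, virtual_tokens: int, embedding_dim: int, bottleneck_dim: int, output_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(virtual_tokens, embedding_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, bottleneck_dim),
            nn.Tanh(),
            nn.Linear(bottleneck_dim, output_dim),
        )
        # The cache stores the MLP output, so its width must be output_dim;
        # sizing it with embedding_dim breaks when embedding_dim != output_dim.
        self.register_buffer("inference_table", torch.zeros(virtual_tokens, output_dim), persistent=False)

    def forward(self) -> torch.Tensor:
        prompts = self.mlp(self.embedding.weight)  # (virtual_tokens, output_dim)
        self.inference_table.copy_(prompts)        # shapes match only with output_dim
        return prompts


enc = TinyPromptEncoder(virtual_tokens=8, embedding_dim=1024, bottleneck_dim=256, output_dim=4096)
assert enc().shape == enc.inference_table.shape == (8, 4096)
```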