From a69f0f4417e071965b2f0dc1910687aec4bfc90e Mon Sep 17 00:00:00 2001 From: Adi Renduchintala Date: Wed, 12 Jul 2023 10:21:41 -0700 Subject: [PATCH] ptuning inference table bug fix (#7015) * remove hardcoded input and output Signed-off-by: arendu * fix inf table Signed-off-by: arendu --------- Signed-off-by: arendu Signed-off-by: Adi Renduchintala --- .../nlp/modules/common/megatron/adapters/parallel_adapters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index fe339c6f9a8b..1818d33dc0d3 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -268,7 +268,7 @@ def __init__( # (@adithyare) the persistent=False will not pollute the indices into the state_dict of this module. self.register_buffer("indices", torch.LongTensor(list(range(self.virtual_tokens))), persistent=False) self.embedding = torch.nn.Embedding(self.virtual_tokens, self.embedding_dim) - self.inference_table = InferenceTable("taskname", self.embedding_dim, self.virtual_tokens) + self.inference_table = InferenceTable("taskname", self.output_dim, self.virtual_tokens) self.first = ColumnParallelLinear( self.embedding_dim, self.bottleneck_dim,