
Commit

Merge pull request #26 from danny-1k/devdev
Minor bug fix: stop passing the dropout probability to the norm layers as their epsilon
HMUNACHI committed Mar 15, 2024
2 parents bccf348 + 190b03d · commit 833aebe
Showing 5 changed files with 20 additions and 20 deletions.
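
Background on the fix: in Flax's linen API, epsilon is the first constructor argument of both nn.LayerNorm and nn.RMSNorm, so passing self.dropout positionally silently set each norm layer's epsilon to the dropout probability instead of configuring dropout. A minimal sketch of the difference, assuming a Flax version that provides nn.RMSNorm:

import flax.linen as nn

# epsilon is the first field of the linen norm layers, so a positional
# argument intended as a dropout rate is interpreted as epsilon.
buggy = nn.RMSNorm(0.1)   # epsilon == 0.1 (the dropout probability)
fixed = nn.RMSNorm()      # epsilon == 1e-6, the default

print(buggy.epsilon)  # 0.1
print(fixed.epsilon)  # 1e-06
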
6 changes: 3 additions & 3 deletions docs/examples/mistral_copy_example.ipynb
@@ -311,9 +311,9 @@
     "            shift_size=self.shift_size)\n",
     "        \n",
     "        self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim)\n",
-    "        self.norm1 = nn.RMSNorm(self.dropout)\n",
-    "        self.norm2 = nn.RMSNorm(self.dropout)\n",
-    "        self.norm3 = nn.RMSNorm(self.dropout)\n",
+    "        self.norm1 = nn.RMSNorm()\n",
+    "        self.norm2 = nn.RMSNorm()\n",
+    "        self.norm3 = nn.RMSNorm()\n",
     "        self.dropout1 = nn.Dropout(self.dropout)\n",
     "        self.dropout2 = nn.Dropout(self.dropout)\n",
     "        self.dropout3 = nn.Dropout(self.dropout)\n",
4 changes: 2 additions & 2 deletions nanodl/__src/models/gemma.py
@@ -206,8 +206,8 @@ def setup(self):
                 num_heads=self.num_heads,
                 num_groups=self.num_groups)
         self.feed_forward = GemmaMLP(self.feedforward_dim, self.hidden_dim)
-        self.norm1 = nn.RMSNorm(self.dropout)
-        self.norm2 = nn.RMSNorm(self.dropout)
+        self.norm1 = nn.RMSNorm()
+        self.norm2 = nn.RMSNorm()
         self.dropout1 = nn.Dropout(self.dropout)
         self.dropout2 = nn.Dropout(self.dropout)

12 changes: 6 additions & 6 deletions nanodl/__src/models/gpt.py
@@ -141,9 +141,9 @@ def setup(self):
         self.attention1 = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads)
         self.attention2 = SelfMultiHeadAttention(hidden_dim=self.hidden_dim, num_heads=self.num_heads)
         self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim)
-        self.norm1 = nn.LayerNorm(self.dropout)
-        self.norm2 = nn.LayerNorm(self.dropout)
-        self.norm3 = nn.LayerNorm(self.dropout)
+        self.norm1 = nn.LayerNorm()
+        self.norm2 = nn.LayerNorm()
+        self.norm3 = nn.LayerNorm()
         self.dropout1 = nn.Dropout(self.dropout)
         self.dropout2 = nn.Dropout(self.dropout)
         self.dropout3 = nn.Dropout(self.dropout)
@@ -554,9 +554,9 @@ def setup(self):
                 self.hidden_dim,
                 self.num_experts,
                 self.top_k)
-        self.norm1 = nn.LayerNorm(self.dropout)
-        self.norm2 = nn.LayerNorm(self.dropout)
-        self.norm3 = nn.LayerNorm(self.dropout)
+        self.norm1 = nn.LayerNorm()
+        self.norm2 = nn.LayerNorm()
+        self.norm3 = nn.LayerNorm()
         self.dropout1 = nn.Dropout(self.dropout)
         self.dropout2 = nn.Dropout(self.dropout)
         self.dropout3 = nn.Dropout(self.dropout)
6 changes: 3 additions & 3 deletions nanodl/__src/models/llama.py
@@ -208,9 +208,9 @@ def setup(self):
                 num_heads=self.num_heads,
                 num_groups=self.num_groups)
         self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim)
-        self.norm1 = nn.RMSNorm(self.dropout)
-        self.norm2 = nn.RMSNorm(self.dropout)
-        self.norm3 = nn.RMSNorm(self.dropout)
+        self.norm1 = nn.RMSNorm()
+        self.norm2 = nn.RMSNorm()
+        self.norm3 = nn.RMSNorm()
         self.dropout1 = nn.Dropout(self.dropout)
         self.dropout2 = nn.Dropout(self.dropout)
         self.dropout3 = nn.Dropout(self.dropout)
12 changes: 6 additions & 6 deletions nanodl/__src/models/mistral.py
@@ -258,9 +258,9 @@ def setup(self):
                 shift_size=self.shift_size)

         self.feed_forward = PositionWiseFFN(self.feedforward_dim, self.hidden_dim)
-        self.norm1 = nn.RMSNorm(self.dropout)
-        self.norm2 = nn.RMSNorm(self.dropout)
-        self.norm3 = nn.RMSNorm(self.dropout)
+        self.norm1 = nn.RMSNorm()
+        self.norm2 = nn.RMSNorm()
+        self.norm3 = nn.RMSNorm()
         self.dropout1 = nn.Dropout(self.dropout)
         self.dropout2 = nn.Dropout(self.dropout)
         self.dropout3 = nn.Dropout(self.dropout)
@@ -686,9 +686,9 @@ def setup(self):
                 shift_size=self.shift_size)

         self.feed_forward = SparseMixtureOfExperts(self.feedforward_dim, self.hidden_dim)
-        self.norm1 = nn.RMSNorm(self.dropout)
-        self.norm2 = nn.RMSNorm(self.dropout)
-        self.norm3 = nn.RMSNorm(self.dropout)
+        self.norm1 = nn.RMSNorm()
+        self.norm2 = nn.RMSNorm()
+        self.norm3 = nn.RMSNorm()
         self.dropout1 = nn.Dropout(self.dropout)
         self.dropout2 = nn.Dropout(self.dropout)
         self.dropout3 = nn.Dropout(self.dropout)
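
For reference, a minimal sketch of the corrected pattern shared by all five files; the class name and fields below are illustrative, not taken from nanodl:

import flax.linen as nn

class Block(nn.Module):  # illustrative name, not from the repository
    hidden_dim: int
    dropout: float

    def setup(self):
        # The norm layer keeps its default epsilon; the dropout
        # probability is passed only to nn.Dropout.
        self.norm1 = nn.RMSNorm()
        self.dropout1 = nn.Dropout(self.dropout)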
