Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Switch to forward interface
Browse files Browse the repository at this point in the history
  • Loading branch information
bgawrych committed Jun 29, 2021
1 parent 70d3c89 commit 88401a0
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions tests/python/mkl/subgraphs/test_transformer_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ def __init__(self, units, num_heads, dtype='float32', **kwargs):
self._fc = nn.Dense(in_units=self._units, units=3*self._units, flatten=False, dtype=dtype)
self._scale = math.sqrt(self._units // self._num_heads)

def hybrid_forward(self, F, x, mask):
x = F.np.copy(x)
def forward(self, x, mask):
x = mx.np.copy(x)
out = self._fc(x)
query, key, value = F.np.split(out, 3, axis=-1)
query = F.npx.reshape(query, (-2, -2, self._num_heads, -1))
key = F.npx.reshape(key, (-2, -2, self._num_heads, -1))
value = F.npx.reshape(value, (-2, -2, self._num_heads, -1))
scores = F.npx.batch_dot(F.np.swapaxes(query, 1, 2), F.np.swapaxes(key, 1, 2),
query, key, value = mx.np.split(out, 3, axis=-1)
query = mx.npx.reshape(query, (-2, -2, self._num_heads, -1))
key = mx.npx.reshape(key, (-2, -2, self._num_heads, -1))
value = mx.npx.reshape(value, (-2, -2, self._num_heads, -1))
scores = mx.npx.batch_dot(mx.np.swapaxes(query, 1, 2), mx.np.swapaxes(key, 1, 2),
transpose_b=True)
mask = F.np.expand_dims(mask, axis=1).astype(np.bool)
attn_weights = F.npx.masked_softmax(scores, mask=mask.astype(np.bool),
mask = mx.np.expand_dims(mask, axis=1).astype(np.bool)
attn_weights = mx.npx.masked_softmax(scores, mask=mask.astype(np.bool),
axis=-1, temperature=self._scale)
attn_weights = F.npx.dropout(attn_weights, p=0.1)
context_vec = F.npx.batch_dot(attn_weights,
F.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
context_vec = F.npx.reshape(context_vec, (-2, -2, -1))
attn_weights = mx.npx.dropout(attn_weights, p=0.1)
context_vec = mx.npx.batch_dot(attn_weights,
mx.np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
context_vec = mx.npx.reshape(context_vec, (-2, -2, -1))

return context_vec

Expand Down

0 comments on commit 88401a0

Please sign in to comment.