try to fix
sxjscience committed Oct 14, 2019
1 parent c1e72d0 commit bfe9600
Showing 1 changed file with 8 additions and 8 deletions.

````diff
@@ -256,24 +256,24 @@ When the number of samples for labels are very unbalanced, applying different we
 ```{.python .input}
-class WeightedSoftmaxCE(nn.HybridBlock):
+class WeightedSoftmaxCE(nn.Block):
     def __init__(self, sparse_label=True, from_logits=False, **kwargs):
         super(WeightedSoftmaxCE, self).__init__(**kwargs)
         with self.name_scope():
             self.sparse_label = sparse_label
             self.from_logits = from_logits
 
-    def hybrid_forward(self, F, pred, label, class_weight, depth=None):
+    def forward(self, pred, label, class_weight, depth=None):
         if self.sparse_label:
-            label = F.reshape(label, shape=(-1, ))
-            label = F.one_hot(label, depth)
+            label = nd.reshape(label, shape=(-1, ))
+            label = nd.one_hot(label, depth)
         if not self.from_logits:
-            pred = F.log_softmax(pred, -1)
+            pred = nd.log_softmax(pred, -1)
 
-        weight_label = F.broadcast_mul(label, class_weight)
-        loss = -F.sum(pred * weight_label, axis=-1)
+        weight_label = nd.broadcast_mul(label, class_weight)
+        loss = -nd.sum(pred * weight_label, axis=-1)
 
-        # return F.mean(loss, axis=0, exclude=True)
+        # return nd.mean(loss, axis=0, exclude=True)
         return loss
 ```
````
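For reference, here is a minimal usage sketch of the class as it reads on the "+" (new, `nn.Block`/`nd`) side of the diff. The class definition above is assumed to be in scope, and the logits, labels, and per-class weights below are made-up placeholder values, not data from this repository.

```python
# Hypothetical usage sketch of the post-change WeightedSoftmaxCE (assumes the
# class definition from the "+" side of the diff above is already defined).
from mxnet import nd

loss_fn = WeightedSoftmaxCE(sparse_label=True, from_logits=False)

pred = nd.random.normal(shape=(4, 3))     # raw logits: 4 samples, 3 classes (toy values)
label = nd.array([0, 2, 1, 2])            # sparse class indices
class_weight = nd.array([1.0, 1.0, 5.0])  # up-weight the rarer class 2

# depth is passed positionally; it tells one_hot how many classes to expand to
loss = loss_fn(pred, label, class_weight, class_weight.shape[0])
print(loss)  # one weighted loss value per sample, shape (4,)
```

With `sparse_label=True`, the class one-hot encodes the integer labels internally, so only the class indices and the number of classes (`depth`) need to be supplied along with the per-class weights.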
