Merged · Changes from 5 commits
9 changes: 9 additions & 0 deletions tester/api_config/config_analyzer.py
@@ -1054,6 +1054,9 @@ def get_padding_offset(bsz, max_seq_len, seq_lens_this_time):
self.numpy_tensor = soft_labels.astype(self.dtype)
else:
self.numpy_tensor = numpy.random.randint(0, num_classes, size=self.shape).astype(self.dtype)
elif self.check_arg(api_config, 3, "weight"):
self.numpy_tensor = numpy.random.random(size=self.shape)
self.numpy_tensor = self.numpy_tensor / self.numpy_tensor.sum()

elif api_config.api_name == "paddle.nn.functional.ctc_loss":
if self.check_arg(api_config, 1, "labels"):
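Side note on the new `weight` branch above: the generated vector is normalized so the class weights sum to 1, which keeps the weighted loss on the same scale as the unweighted one. A minimal standalone sketch of the same idea (the shape is an illustrative assumption; the analyzer uses `self.shape`):

```python
import numpy

# Assumed weight shape for illustration; the analyzer uses self.shape.
num_classes = 10
weight = numpy.random.random(size=(num_classes,))  # uniform in [0, 1)
weight = weight / weight.sum()                     # class weights now sum to 1
assert abs(weight.sum() - 1.0) < 1e-9
```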
@@ -1880,6 +1883,12 @@ def get_paddle_tensor(self, api_config):
dtype="float32" if self.dtype == 'bfloat16' else self.dtype,
place=self.place
)
if api_config.api_name == "paddle.nn.functional.cross_entropy":
if self.check_arg(api_config, 0, "input") and not self.get_arg(api_config, 8, "use_softmax", True):
axis = self.get_arg(api_config, 7, "axis", -1)
self.paddle_tensor = paddle.exp(self.paddle_tensor)
self.paddle_tensor = self.paddle_tensor / self.paddle_tensor.sum(axis=axis, keepdim=True)

self.paddle_tensor.stop_gradient = False
if self.dtype == "bfloat16":
self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype="uint16")
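For readers unfamiliar with `use_softmax=False`: paddle's `cross_entropy` then treats `input` as probabilities rather than logits, so the generator maps the random tensor through exp / sum (i.e., a softmax) along `axis` before handing it over. A minimal sketch of that normalization under assumed shapes:

```python
import paddle

# Assumed shape for illustration: batch of 4 samples over 10 classes.
logits = paddle.randn([4, 10])
axis = -1
probs = paddle.exp(logits)
probs = probs / probs.sum(axis=axis, keepdim=True)
# Each row now sums to 1, as cross_entropy(..., use_softmax=False) expects.
```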
18 changes: 17 additions & 1 deletion tester/paddle_to_torch/mapping.json
@@ -2680,7 +2680,23 @@
}
},
"paddle.nn.functional.cross_entropy": {
"Rule": "CrossEntropyRule"
"Rule": "CrossEntropyRule",
"torch_api": "torch.nn.functional.cross_entropy",
"set_defaults": {
"weight": "None",
"reduction": "'mean'",
"soft_label": "False",
"label_smoothing": "0.0",
"reduction_original":"None"
},
"paddle_torch_args_map": {
"input": "input",
"label": "target",
"weight": "weight",
"reduction": "reduction",
"ignore_index": "ignore_index",
"label_smoothing": "label_smoothing"
}
},
"paddle.nn.functional.ctc_loss": {
"Rule": "CtcLossRule"
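To make the mapping entry concrete, here is a hand-written (not generated) illustration of the torch call it encodes, with the `set_defaults` values filled in; the tensor shapes are assumptions:

```python
import torch

# Assumed shapes; paddle's `label` argument becomes torch's `target`.
input = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
loss = torch.nn.functional.cross_entropy(
    input,
    target,
    weight=None,           # "set_defaults": weight -> None
    reduction="mean",      # "set_defaults": reduction -> 'mean'
    ignore_index=-100,     # mapped through unchanged (both frameworks default to -100)
    label_smoothing=0.0,   # "set_defaults": label_smoothing -> 0.0
)
```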
64 changes: 36 additions & 28 deletions tester/paddle_to_torch/rules.py
@@ -664,43 +664,48 @@ def apply(self, paddle_api: str) -> ConvertResult:

class CrossEntropyRule(BaseRule):
def apply(self, paddle_api: str) -> ConvertResult:
defaults_code, map_code = self.apply_generic()
pre = """
_kwargs = {}
for paddle_param, torch_param in {
"input": "input",
"label": "target",
"weight": "weight",
"ignore_index": "ignore_index",
"reduction": "reduction",
"label_smoothing": "label_smoothing"
}.items():
if paddle_param in locals() and locals()[paddle_param] is not None:
_kwargs[torch_param] = locals()[paddle_param]
shp = _kwargs['target'].shape
if len(_kwargs["input"].shape) > 2:
perm = [0] + [len(_kwargs["input"].shape)-1]+ [i for i in range(1,len(_kwargs["input"].shape)-1)]
_kwargs['input'] = _kwargs['input'].permute(*perm)
soft_label = locals().get('soft_label',False)
shp = label.shape
if len(input.shape) > 2:
perm = [0] + [len(input.shape)-1]+ [i for i in range(1,len(input.shape)-1)]
input = input.permute(*perm)
axis = locals().get('axis',-1)
use_softmax = locals().get('use_softmax',True)
_kwargs['target'] = _kwargs['target'].squeeze(-1)
if "weight" in _kwargs:
_kwargs['weight'].requires_grad = False
if _kwargs['target'].dtype == torch.int32:
_kwargs['target'] = _kwargs['target'].long()
label = label.squeeze(-1)
if weight is not None:
weight.requires_grad = False
if label.dtype == torch.int32:
label = label.long()
if soft_label and weight is not None and shp == input.shape:
reduction_original = reduction
weight_original = weight
reduction = "none"
weight = None
"""
core = """
result = torch.nn.functional.cross_entropy(**_kwargs)
core = f"""
result = {self.torch_api}(**_kwargs)
"""
post = """
if "reduction" in _kwargs and _kwargs['reduction'] == "none":
if reduction_original is not None:
reduction = reduction_original
loss_weight = label@weight_original
sum_weight = loss_weight.sum()
result *= loss_weight
else:
sum_weight = result.numel()

if reduction == "none":
if soft_label:
result = result.unsqueeze(-1)
else:
result = result.reshape(shp)
elif reduction == "sum":
result = result.sum()
else:
result = result.sum()/sum_weight
"""
code = Code(
preprocess=pre.splitlines(),
preprocess=defaults_code + pre.splitlines() + map_code,
core=core.splitlines(),
postprocess=post.splitlines(),
)
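The interesting case in this rewrite is `soft_label=True` combined with a `weight` tensor: the rule runs torch's `cross_entropy` unweighted with `reduction="none"`, then scales each sample by `label @ weight` in the postprocess and renormalizes so "mean" divides by the total weight rather than the batch size. A self-contained sketch of that reweighting (shapes and values are illustrative assumptions):

```python
import torch

# Illustrative inputs: soft labels share the logits' shape, per the shp == input.shape check.
input = torch.randn(4, 10)
label = torch.softmax(torch.randn(4, 10), dim=-1)  # rows are probability distributions
weight = torch.rand(10)                            # per-class weights

# Step 1: unweighted per-sample loss, as the rule's core call produces.
per_sample = torch.nn.functional.cross_entropy(input, label, reduction="none")

# Step 2: postprocess equivalent — weight each sample by label @ weight,
# then divide by the summed weights instead of the batch size.
loss_weight = label @ weight
weighted_mean = (per_sample * loss_weight).sum() / loss_weight.sum()
print(weighted_mean)
```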
@@ -948,7 +953,7 @@ def apply(self, paddle_api: str) -> ConvertResult:
result = torch.clamp(**_kwargs)
"""
elif paddle_api == "paddle.Tensor.clip":
core = """
core = """
if min is None and max is None:
result = x
else:
@@ -958,7 +963,10 @@
return ConvertResult.error(
paddle_api, f"Unsupported clip api: {paddle_api}"
)
code = Code(preprocess=defaults_code + pre.splitlines() + map_code, core=core.splitlines())
code = Code(
preprocess=defaults_code + pre.splitlines() + map_code,
core=core.splitlines(),
)
return ConvertResult.success(paddle_api, code)

