please lower std_mean and var_mean #2790

@mobassir94

torch.std_mean is not lowered to XLA.

In #2776, @taylanbil shared a minimal test with the following code:

import torch.nn as nn
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

layer = nn.Linear(100, 100)
d = xm.xla_device()
layer = layer.to(d)
std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
rep = met.metrics_report()
print(rep)

and observed that the aten::std_mean counter appears in the metrics report:

...
Counter: XrtSubTuple_Empty
Value: 128
Counter: aten::std_mean
Value: 1
Counter: xla::_copy_from
Value: 2
...

So this op needs to be lowered.
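The check above can also be scripted. The following is a minimal sketch that scans the text of a metrics report for aten:: counters. The helper name aten_fallback_counters is hypothetical, and the "Counter: / Value:" layout is assumed from the excerpt quoted in this issue; the sample string stands in for the output of met.metrics_report().

```python
import re


def aten_fallback_counters(report: str) -> dict:
    """Return {counter_name: value} for every aten:: counter in a report string.

    Hypothetical helper; assumes each counter is printed as a
    'Counter: <name>' line followed by a 'Value: <int>' line.
    """
    pattern = re.compile(r"Counter:\s*(aten::\S+)\s+Value:\s*(\d+)")
    return {name: int(value) for name, value in pattern.findall(report)}


# Sample text copied from the report excerpt in this issue.
sample_report = """\
Counter: XrtSubTuple_Empty
  Value: 128
Counter: aten::std_mean
  Value: 1
Counter: xla::_copy_from
  Value: 2
"""

print(aten_fallback_counters(sample_report))
```

On a real run, the string returned by met.metrics_report() could be passed in place of sample_report; an empty result would suggest that no aten:: counter was recorded.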

torch.var_mean is not lowered either.

However, calling std and mean separately works, as shown by @taylanbil:

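import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

layer = nn.Linear(100, 100)
d = xm.xla_device()
layer = layer.to(d)
# Fused calls, commented out because they show up as aten:: counters:
# var, mean = torch.var_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
# std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
std = torch.std(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
mean = torch.mean(layer.weight.data, dim=[1, 0], keepdim=True)
rep = met.metrics_report()
print(rep)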

This results in:

Counter: xla::mean
Value: 1
Counter: xla::std
Value: 1
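Until the fused ops are lowered, the separate calls can be wrapped so that call sites keep the std_mean / var_mean shape. This is only a sketch: the names std_mean_separate and var_mean_separate are hypothetical, the check below runs on a CPU tensor rather than an XLA device, and whether torch.var is itself lowered is not confirmed in this issue (only xla::std and xla::mean appear in the report above).

```python
import torch


def std_mean_separate(x, dim, keepdim=False, unbiased=True):
    # Hypothetical stand-in for torch.std_mean: two separate reductions.
    std = torch.std(x, dim=dim, keepdim=keepdim, unbiased=unbiased)
    mean = torch.mean(x, dim=dim, keepdim=keepdim)
    return std, mean


def var_mean_separate(x, dim, keepdim=False, unbiased=True):
    # Hypothetical stand-in for torch.var_mean, built the same way.
    # Assumption: torch.var behaves like torch.std on the target device.
    var = torch.var(x, dim=dim, keepdim=keepdim, unbiased=unbiased)
    mean = torch.mean(x, dim=dim, keepdim=keepdim)
    return var, mean


# Quick CPU check with the same arguments as the snippet in this issue.
x = torch.randn(100, 100)
std, mean = std_mean_separate(x, dim=[1, 0], keepdim=True, unbiased=False)
var, mean2 = var_mean_separate(x, dim=[1, 0], keepdim=True, unbiased=False)
print(std.shape, mean.shape, var.shape)
```

The trade-off is two reductions over the input instead of one fused op.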

Labels

lowering (ATen Operation lowering), nostale (Do not consider for staleness)
