Closed
Labels: lowering (ATen Operation lowering), nostale (Do not consider for staleness)
Description
`torch.std_mean` was not lowered to XLA.
In #2776, @taylanbil shared a minimal test with the following code:
```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

layer = nn.Linear(100, 100)
d = xm.xla_device()
layer = layer.to(d)
# Fused std/mean reduction over both dims of the weight matrix.
std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
rep = met.metrics_report()
print(rep)
```
He observed an `aten::std_mean` counter in the metrics report, which means the op fell back to the CPU instead of going through an XLA lowering:
```
...
Counter: XrtSubTuple_Empty
  Value: 128
Counter: aten::std_mean
  Value: 1
Counter: xla::_copy_from
  Value: 2
...
```
So this op needs to be lowered.
`torch.var_mean` was not lowered either.
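For reference, both fused ops come down to the same arithmetic: a mean, then the sum of squared deviations from that mean divided by N (`unbiased=False`) or N - 1 (`unbiased=True`), with `std_mean` taking a square root at the end. The following pure-Python sketch illustrates only that arithmetic; `var_mean_reference` and `std_mean_reference` are hypothetical names, and this is not how PyTorch or PyTorch/XLA implement the ops. Because `dim=[1, 0]` on a 2-D weight reduces over every element, a flat list of numbers is enough to show it.

```python
import math


def var_mean_reference(values, unbiased=True):
    """Hypothetical reference: return (var, mean) over a flat list of numbers."""
    n = len(values)
    mean = sum(values) / n
    # unbiased=True applies Bessel's correction (divide by n - 1, so it
    # assumes n > 1); unbiased=False divides by n.
    divisor = n - 1 if unbiased else n
    var = sum((v - mean) ** 2 for v in values) / divisor
    return var, mean


def std_mean_reference(values, unbiased=True):
    """Hypothetical reference: return (std, mean), the same order as torch.std_mean."""
    var, mean = var_mean_reference(values, unbiased)
    return math.sqrt(var), mean
```

For example, `std_mean_reference([2, 4, 4, 4, 5, 5, 7, 9], unbiased=False)` gives `(2.0, 5.0)`. Since the standard deviation only depends on the mean and a second reduction, computing the two statistics with separate ops yields the same values as the fused op.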
However, computing the standard deviation and the mean separately works, as @taylanbil showed:
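```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
layer = nn.Linear(100, 100)
d = xm.xla_device()
layer = layer.to(d)
# Fused versions that fall back to the CPU:
# var, mean = torch.var_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
# std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
# Separate reductions, each of which has an XLA lowering:
std = torch.std(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
mean = torch.mean(layer.weight.data, dim=[1, 0], keepdim=True)
rep = met.metrics_report()
print(rep)
```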
This reports `xla::` counters for both ops:
```
Counter: xla::mean
  Value: 1
Counter: xla::std
  Value: 1
```
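Until the fused ops are lowered, the same workaround can be wrapped in small helpers. This is a minimal sketch, assuming a PyTorch version whose `torch.std` and `torch.var` still accept the `unbiased` keyword with a list of dims; `std_mean_separate` and `var_mean_separate` are hypothetical names, not part of PyTorch or PyTorch/XLA. They return results in the same order as `torch.std_mean` and `torch.var_mean`.

```python
import torch


def std_mean_separate(x, dim, keepdim=False, unbiased=True):
    """Hypothetical helper: (std, mean) computed with two separate reductions.

    torch.std and torch.mean each show up as xla:: counters on an XLA device,
    whereas the fused torch.std_mean showed up as an aten:: (CPU) counter.
    """
    std = torch.std(x, dim=dim, keepdim=keepdim, unbiased=unbiased)
    mean = torch.mean(x, dim=dim, keepdim=keepdim)
    return std, mean


def var_mean_separate(x, dim, keepdim=False, unbiased=True):
    """Hypothetical helper: (var, mean) computed with two separate reductions."""
    var = torch.var(x, dim=dim, keepdim=keepdim, unbiased=unbiased)
    mean = torch.mean(x, dim=dim, keepdim=keepdim)
    return var, mean
```

On the CPU the helpers can be checked against the fused ops; on an XLA device, the point is that the metrics report should list `xla::` counters for each call instead of an `aten::` fallback.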