How to use OneFlow to achieve functionality similar to torch.nn.PixelShuffle() #3790
Thanks for your feedback. This op is not yet available, but it will be added to the development plan. Anyone else who is interested in helping us implement this op can reply here, and we will help as much as we can.
@416047850 I have implemented it with flow.reshape and flow.transpose. You can give it a try :) Reference: apache/mxnet#13571

```python
import numpy as np
import oneflow as flow
import oneflow.typing as oft
import torch


def PixelShuffle(input, h_factor, w_factor):
    # input: (b, c, h, w), where c must be divisible by h_factor * w_factor
    b, c, h, w = input.shape
    assert c % (h_factor * w_factor) == 0
    new_c = c // (h_factor * w_factor)
    # Split the channel axis: (b, new_c, h_factor * w_factor, h, w)
    out = flow.reshape(input, [b, new_c, h_factor * w_factor, h, w])
    # Fold batch and output channels together and expose both sub-pixel
    # offsets: (b * new_c, h_factor, w_factor, h, w)
    out = flow.reshape(out, [b * new_c, h_factor, w_factor, h, w])
    # Interleave each offset with its spatial axis:
    # (b * new_c, h, h_factor, w, w_factor)
    out = flow.transpose(out, [0, 3, 1, 4, 2])
    # Merge into the upscaled image: (b, new_c, h * h_factor, w * w_factor)
    out = flow.reshape(out, [b, new_c, h * h_factor, w * w_factor])
    return out


input_shape = (3, 4, 2, 2)
h_factor = 2
w_factor = 2

func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)


@flow.global_function(type="predict", function_config=func_config)
def PixelShuffleJob(input: oft.Numpy.Placeholder(input_shape, dtype=flow.float32)):
    with flow.scope.placement("gpu", "0:0"):
        out = PixelShuffle(input, h_factor, w_factor)
    return out


# OneFlow
check_point = flow.train.CheckPoint()
check_point.init()
input = np.random.uniform(size=input_shape).astype(np.float32)
of_out = PixelShuffleJob(input).get()
# print(input)
# print(of_out.numpy())

# PyTorch reference: torch.pixel_shuffle takes a single upscale factor,
# so this comparison assumes h_factor == w_factor
arr = torch.tensor(input)
t_out = torch.pixel_shuffle(arr, upscale_factor=h_factor)
# print(t_out.numpy())
print(np.allclose(of_out.numpy(), t_out.numpy()))
```
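To check the reshape/transpose trick without OneFlow or a GPU, here is a minimal NumPy-only sketch of the same rearrangement, assuming NCHW layout as in the snippet above. The helper names `pixel_shuffle_np` and `pixel_shuffle_loop` are hypothetical and are not part of OneFlow or PyTorch. The vectorized version is compared against a plain loop that applies the usual pixel-shuffle index mapping directly: out[b, k, y*h_factor + i, x*w_factor + j] = in[b, k*h_factor*w_factor + i*w_factor + j, y, x].

```python
import numpy as np


def pixel_shuffle_np(x: np.ndarray, h_factor: int, w_factor: int) -> np.ndarray:
    """Rearrange (b, c, h, w) -> (b, c // (h_factor*w_factor), h*h_factor, w*w_factor)."""
    b, c, h, w = x.shape
    assert c % (h_factor * w_factor) == 0, "c must be divisible by h_factor * w_factor"
    new_c = c // (h_factor * w_factor)
    # Split the channel axis into (new_c, h_factor, w_factor).
    out = x.reshape(b, new_c, h_factor, w_factor, h, w)
    # Interleave each sub-pixel offset with its spatial axis:
    # (b, new_c, h, h_factor, w, w_factor)
    out = out.transpose(0, 1, 4, 2, 5, 3)
    # Merge (h, h_factor) and (w, w_factor) into the upscaled spatial axes.
    return out.reshape(b, new_c, h * h_factor, w * w_factor)


def pixel_shuffle_loop(x: np.ndarray, h_factor: int, w_factor: int) -> np.ndarray:
    """Apply the index mapping element by element (slow, for checking only)."""
    b, c, h, w = x.shape
    new_c = c // (h_factor * w_factor)
    out = np.empty((b, new_c, h * h_factor, w * w_factor), dtype=x.dtype)
    for n in range(b):
        for k in range(new_c):
            for y in range(h * h_factor):
                for col in range(w * w_factor):
                    i, j = y % h_factor, col % w_factor
                    src_c = k * h_factor * w_factor + i * w_factor + j
                    out[n, k, y, col] = x[n, src_c, y // h_factor, col // w_factor]
    return out


x = np.arange(1 * 8 * 2 * 3, dtype=np.float32).reshape(1, 8, 2, 3)
y = pixel_shuffle_np(x, 2, 2)
print(y.shape)  # (1, 2, 4, 6)
print(np.array_equal(y, pixel_shuffle_loop(x, 2, 2)))  # True
```

The snippet above folds `b` and `new_c` into one axis so that it only needs a 5-D transpose; this sketch keeps them separate in a 6-D view, which produces the same result. Unlike `torch.pixel_shuffle`, both versions here also accept different `h_factor` and `w_factor`.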
Thanks. You're my God! :)