WIP, conv_bn fuse example using paddlefx #33
base: main
Conversation
This is fixed now; I hadn't considered the case of erasing the current node while iterating over the loop 😂
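For reference, a minimal sketch of that fix, assuming a torch.fx-style graph API on the paddlefx side (`graph.nodes`, `replace_all_uses_with`, and `erase_node` are assumptions here, as is the `_is_bn_after_conv` predicate): iterate over a snapshot of the node list so erasing the current node cannot break the loop.

```python
# Sketch only: graph API names are assumed to mirror torch.fx.
def remove_bn_nodes(traced):
    # Take a snapshot of the nodes first; erasing the current node while
    # iterating the live node list is what caused the original bug.
    for node in list(traced.graph.nodes):
        if node.op != 'call_module':
            continue
        if not _is_bn_after_conv(node, traced):  # hypothetical predicate
            continue
        # Re-route all users of the BN node to its conv input, then erase it.
        node.replace_all_uses_with(node.args[0])
        traced.graph.erase_node(node)
```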
Right, we noticed this issue before. We can add the following special case as an "optimization":

```python
if paddle.allclose(conv_b_param, paddle.zeros_like(conv_b_param)):
    conv_b_param = None
```

For an untrained resnet18 this does give a speedup, because the parameters obtained from the BN layer are 0 and the extra bias can be optimized away to None:

```
Fused time: 2.448992967605591
Unfused time: 2.5072388648986816
Traced time: 2.4950180053710938
```

But for a trained resnet18 (e.g. with … enabled):

```
Fused time: 3.230083703994751
Unfused time: 2.48759388923645
Traced time: 2.4984631538391113
```

So this can hardly be called an optimization; we should find out why adding the bias slows things down so badly.
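A minimal sketch of where that special case could sit when the fused Conv2D is built; `maybe_drop_zero_bias` and the surrounding names are illustrative, not the PR's actual code:

```python
import paddle


def maybe_drop_zero_bias(fused_b):
    # For an untrained BN (running mean 0, beta 0) the fused bias is all
    # zeros, so it can be dropped and the fused Conv2D created with
    # bias_attr=False, keeping the faster no-bias kernel.
    if paddle.allclose(fused_b, paddle.zeros_like(fused_b)):
        return None
    return fused_b

# Usage sketch (illustrative):
# fused_b = maybe_drop_zero_bias(fused_b)
# fused_conv = paddle.nn.Conv2D(..., bias_attr=False if fused_b is None else None)
```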
I think the cause is Paddle's Conv2D implementation: there is a large performance gap between the with-bias and without-bias cases. For a no-bias conv + batch_norm, the small gain the fuse brings is wiped out because the fused conv has to carry a bias, so it actually becomes slower. Run the code below and the difference is easy to see:

```python
import paddle
import paddle.nn.functional as F
import time


class MyNet(paddle.nn.Layer):
    def __init__(self, bias=False):
        super(MyNet, self).__init__()
        self.conv1 = paddle.nn.Conv2D(
            in_channels=3, out_channels=32, kernel_size=(3, 3), bias_attr=bias
        )

    def forward(self, x):
        x = self.conv1(x)
        return x


bias_model = MyNet(bias=True)
no_bias_model = MyNet(bias=False)
inp = paddle.rand((128, 3, 224, 224))


def benchmark(model, iters=1000):
    # warm-up iterations
    for _ in range(50):
        model(inp)
    # paddle.device.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        model(inp)
    # paddle.device.cuda.synchronize()
    return str(time.time() - begin)


print("no bias time: ", benchmark(no_bias_model))
print("bias time: ", benchmark(bias_model))
```

On a V100:

```
no bias time: 1.829711675643921
```
Because of this issue in Paddle's Conv2D implementation, in the short term we could construct a conv-with-bias + bn network to showcase paddlefx's fuse capability, so that this PR can be merged first (a rough sketch of such a network follows).
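A possible toy network for that, sketched under the assumption that the example only needs a Conv2D that already carries a bias in front of a BatchNorm2D; the class and layer names are illustrative:

```python
import paddle


class BiasConvBNNet(paddle.nn.Layer):
    """Toy conv(bias) + BN stack: fusing the BN here does not introduce a new
    bias term, so the fused model should not hit the no-bias -> bias slowdown
    discussed above."""

    def __init__(self):
        super().__init__()
        self.conv1 = paddle.nn.Conv2D(3, 32, kernel_size=3, bias_attr=True)
        self.bn1 = paddle.nn.BatchNorm2D(32)
        self.conv2 = paddle.nn.Conv2D(32, 64, kernel_size=3, bias_attr=True)
        self.bn2 = paddle.nn.BatchNorm2D(64)

    def forward(self, x):
        x = paddle.nn.functional.relu(self.bn1(self.conv1(x)))
        x = paddle.nn.functional.relu(self.bn2(self.conv2(x)))
        return x
```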
Generally the conv right before a bn doesn't have a bias, though; maybe we could try fusing a network like RepVGG instead (a rough block sketch follows).
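For context, a RepVGG-style training-time block is roughly parallel bias-free conv3x3+BN and conv1x1+BN branches (plus an identity BN when the shapes match) summed before the activation. A purely illustrative Paddle sketch, not taken from RepVGG's actual code:

```python
import paddle


class RepVGGLikeBlock(paddle.nn.Layer):
    """Every conv is bias-free and immediately followed by a BN, which is
    exactly the pattern a conv+BN fuser wants to target."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv3 = paddle.nn.Conv2D(in_ch, out_ch, 3, padding=1, bias_attr=False)
        self.bn3 = paddle.nn.BatchNorm2D(out_ch)
        self.conv1 = paddle.nn.Conv2D(in_ch, out_ch, 1, bias_attr=False)
        self.bn1 = paddle.nn.BatchNorm2D(out_ch)
        # Identity branch only when input and output shapes already match.
        self.bn_id = paddle.nn.BatchNorm2D(out_ch) if in_ch == out_ch else None

    def forward(self, x):
        y = self.bn3(self.conv3(x)) + self.bn1(self.conv1(x))
        if self.bn_id is not None:
            y = y + self.bn_id(x)
        return paddle.nn.functional.relu(y)
```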
This is ported from https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html, but there are still some critical issues to solve.
See the TODO in examples/conv_bn_fuse.py for details.
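For context, the core of the ported tutorial folds the BN statistics into the conv parameters. A minimal Paddle sketch of that math; the function name and signature are illustrative and not necessarily what examples/conv_bn_fuse.py uses:

```python
import paddle


def fuse_conv_bn_weights(conv_w, conv_b, bn_mean, bn_var, bn_eps, bn_gamma, bn_beta):
    # Fold y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta into the conv:
    #   w_fused = w * gamma / sqrt(var + eps)            (per output channel)
    #   b_fused = (b - mean) * gamma / sqrt(var + eps) + beta
    if conv_b is None:
        conv_b = paddle.zeros_like(bn_mean)
    scale = bn_gamma / paddle.sqrt(bn_var + bn_eps)
    fused_w = conv_w * scale.reshape([-1, 1, 1, 1])
    fused_b = (conv_b - bn_mean) * scale + bn_beta
    return fused_w, fused_b
```

Note that when the original conv has no bias, the fused bias is generally non-zero for a trained BN, which is exactly the extra bias the benchmark above is measuring.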