Skip to content

Commit

Permalink
improve python coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
sneaxiy committed Oct 6, 2021
1 parent 3a53b0e commit 18681d9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/paddle/fluid/tests/unittests/test_cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
label.persistable = True
loss = simple_fc_net_with_inputs(image, label, class_num)
loss.persistable = True
optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
lr = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2, 3, 4], values=[0.01, 0.02, 0.03, 0.04])
optimizer = paddle.optimizer.SGD(learning_rate=lr)
optimizer.minimize(loss)
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
Expand All @@ -86,6 +88,9 @@ def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
image_t = scope.var(image.name).get_tensor()
label_t = scope.var(label.name).get_tensor()
loss_t = scope.var(loss.name).get_tensor()
lr_var = main.global_block().var(lr._var_name)
self.assertTrue(lr_var.persistable)
lr_t = scope.var(lr_var.name).get_tensor()
cuda_graph = None
for batch_id in range(20):
image_t.set(
Expand All @@ -101,9 +106,11 @@ def cuda_graph_static_graph_main(self, seed, use_cuda_graph):
cuda_graph.capture_end()

if cuda_graph:
lr_t.set(np.array([lr()], dtype='float32'), place)
cuda_graph.replay()
else:
exe.run(compiled_program)
lr.step()
if cuda_graph:
cuda_graph.reset()
return np.array(loss_t)
Expand Down

0 comments on commit 18681d9

Please sign in to comment.