[npu] add update_loss_scaling npu min value (#35270)
Baibaifan authored Sep 2, 2021
1 parent df57df9 commit 280d742
Showing 5 changed files with 96 additions and 8 deletions.
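The commit adds a gflag, FLAGS_min_loss_scaling, that puts a configurable floor under the dynamically decreased loss scaling on NPU. A rough sketch of the two ways the flag can be set, inferred from the __init__.py and test changes below (the value 1639 is just the one the new test uses):

# Hedged sketch: both paths are inferred from this commit's diff.
import os

# 1) Environment variable, picked up by core.init_gflags via the
#    --tryfromenv list; must be set before paddle is imported.
os.environ['FLAGS_min_loss_scaling'] = '1639'

import paddle.fluid as fluid

# 2) The getter/setter registered in global_value_getter_setter.cc,
#    usable after import (this is what the new unit test does).
fluid.core.globals()['FLAGS_min_loss_scaling'] = 1639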
23 changes: 15 additions & 8 deletions paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"

+DECLARE_int32(min_loss_scaling);

namespace paddle {
namespace operators {

@@ -49,7 +51,7 @@ void Update(const platform::NPUDeviceContext& ctx,

    std::vector<int> bad_out_data;
    TensorToVector(*bad_out_tensor, ctx, &bad_out_data);
-    if (bad_out_data[0] == decr_every_n_nan_or_inf) {
+    if (bad_out_data[0] >= decr_every_n_nan_or_inf) {
      const auto& runner_p3 = NpuOpRunner("Power", {*pre_loss_scaling_tensor},
                                          {*updated_loss_scaling_tensor},
                                          {{"power", static_cast<float>(1)},
@@ -60,13 +62,18 @@

    std::vector<T> new_loss_scaling;
    TensorToVector(*updated_loss_scaling_tensor, ctx, &new_loss_scaling);
-    if (new_loss_scaling[0] < static_cast<T>(1)) {
+    float min_value = 1.0;
+    if (FLAGS_min_loss_scaling > 1) {
+      min_value = static_cast<float>(FLAGS_min_loss_scaling);
+    }
+
+    if (new_loss_scaling[0] < min_value) {
      // updated_loss_scaling_data = 1
-      const auto& runner_p4 = NpuOpRunner("Power", {*pre_loss_scaling_tensor},
-                                          {*updated_loss_scaling_tensor},
-                                          {{"power", static_cast<float>(1)},
-                                           {"scale", static_cast<float>(0)},
-                                           {"shift", static_cast<float>(1)}});
+      const auto& runner_p4 = NpuOpRunner(
+          "Power", {*pre_loss_scaling_tensor}, {*updated_loss_scaling_tensor},
+          {{"power", static_cast<float>(1)},
+           {"scale", static_cast<float>(0)},
+           {"shift", static_cast<float>(min_value)}});

      runner_p4.Run(stream);
    }
@@ -93,7 +100,7 @@ void Update(const platform::NPUDeviceContext& ctx,
    std::vector<int> good_out_data;
    TensorToVector(*good_out_tensor, ctx, &good_out_data);

-    if (good_out_data[0] == incr_every_n_steps) {
+    if (good_out_data[0] >= incr_every_n_steps) {
      const auto& runner_p3 = NpuOpRunner("Power", {*pre_loss_scaling_tensor},
                                          {*updated_loss_scaling_tensor},
                                          {{"power", static_cast<float>(1)},
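Read together, the hunks above change the decrease path in two ways: the bad-step comparison becomes >= instead of ==, and the updated scaling is clamped to FLAGS_min_loss_scaling (when > 1) instead of the hard-coded 1. A minimal Python sketch of that logic, not the NPU kernel itself; the function name and signature are illustrative:

# Hedged sketch of the decrease path after this commit.
def decrease_loss_scaling(prev_scaling, bad_steps, decr_every_n_nan_or_inf,
                          decr_ratio, flags_min_loss_scaling=1):
    # >= rather than ==: an overshot bad-step counter still triggers a decrease.
    if bad_steps < decr_every_n_nan_or_inf:
        return prev_scaling
    new_scaling = prev_scaling * decr_ratio
    # Clamp to the configurable floor instead of the hard-coded 1.
    min_value = 1.0
    if flags_min_loss_scaling > 1:
        min_value = float(flags_min_loss_scaling)
    if new_scaling < min_value:
        new_scaling = min_value
    return new_scaling

With the new test's numbers (prev_scaling=2048, decr_ratio=0.8, floor=1639) this returns max(1638.4, 1639) = 1639, matching the expected LossScaling below.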
1 change: 1 addition & 0 deletions paddle/fluid/platform/flags.cc
@@ -100,6 +100,7 @@ DEFINE_string(
    npu_config_path, "",
    "The absolute path of configuration json file, like: /tmp/config.json. "
    "If provided, it will be passed to aclInit().");
+DEFINE_int32(min_loss_scaling, 1, "set minimum loss scaling value!");
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/global_value_getter_setter.cc
@@ -98,6 +98,8 @@ DECLARE_string(selected_xpus);
#ifdef PADDLE_WITH_ASCEND_CL
// device management
DECLARE_string(selected_npus);
+// set minimum loss scaling value
+DECLARE_int32(min_loss_scaling);
#endif

#ifdef PADDLE_WITH_DISTRIBUTE
@@ -385,6 +387,7 @@ static void RegisterGlobalVarGetterSetter() {

#ifdef PADDLE_WITH_ASCEND_CL
  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_selected_npus);
+  REGISTER_PUBLIC_GLOBAL_VAR(FLAGS_min_loss_scaling);
#endif

#ifdef PADDLE_WITH_DISTRIBUTE
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
@@ -249,6 +249,7 @@ def __bootstrap__():
        'npu_config_path',
        'get_host_by_name_time',
        'hccl_check_nan',
+       'min_loss_scaling',
    ]

    core.init_gflags(["--tryfromenv=" + ",".join(read_env_flags)])
76 changes: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import sys
import os
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn
from test_update_loss_scaling_op_npu import TestUpdateLossScalingOpBad

paddle.enable_static()
SEED = 2021


class TestUpdateLossScalingOpMinLossScalingBad(TestUpdateLossScalingOpBad):
    def setUp(self):
        self.set_npu()
        self.op_type = "update_loss_scaling"
        self.place = paddle.NPUPlace(0)

        self.init()
        fluid.core.globals()['FLAGS_min_loss_scaling'] = 1639
        found_inf = np.array([True], dtype=np.bool)
        x = np.random.random((1024, 1024)).astype(self.dtype)
        i = np.random.randint(0, 1024, 1)
        j = np.random.randint(0, 1024, 1)
        x[i[0]][j[0]] = np.inf

        self.inputs = {
            'X': [('x0', x)],
            'FoundInfinite': found_inf,
            'PrevLossScaling': self.prev_loss_scaling,
            'InGoodSteps': self.num_good_steps,
            'InBadSteps': self.num_bad_steps
        }

        self.outputs = {
            'Out': [('out0', np.zeros_like(x))],
            'LossScaling': np.array([1639.0]).astype(self.dtype),
            'OutGoodSteps': self.zero_steps,
            'OutBadSteps': self.zero_steps
        }

    def init(self):
        self.incr_ratio = 2.0
        self.decr_ratio = 0.8
        self.dtype = np.float32
        self.prev_loss_scaling = np.array([2048]).astype(self.dtype)
        self.num_good_steps = np.array([999], dtype=np.int32)
        self.num_bad_steps = np.array([1], dtype=np.int32)
        self.zero_steps = np.array([0], dtype=np.int32)
        self.attrs = {
            'incr_every_n_steps': 1000,
            'decr_every_n_nan_or_inf': 2,
            'incr_ratio': self.incr_ratio,
            'decr_ratio': self.decr_ratio,
        }


if __name__ == '__main__':
    unittest.main()
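As a quick sanity check of the expected output above, under the clamping rule this commit introduces:

# Why the test expects LossScaling == 1639.0 (values taken from the test).
prev_loss_scaling = 2048.0
decr_ratio = 0.8
min_loss_scaling = 1639.0  # set via FLAGS_min_loss_scaling in setUp

candidate = prev_loss_scaling * decr_ratio          # 1638.4, below the floor
assert max(candidate, min_loss_scaling) == 1639.0   # hence the clamped result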
