Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Dy2static ] infer_program may be incorrect in amp mode. #44487

Merged
merged 7 commits into from
Jul 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ def _train_pure_fp16_program(self):
def _infer_program_id(self):
return _hash_with_id(self._infer_program, self)

@LazyInitialized
def _infer_pure_fp16_program_id(self):
return _hash_with_id(self._infer_pure_fp16_program, self)

@LazyInitialized
def _infer_amp_program_id(self):
return _hash_with_id(self._infer_amp_program, self)

@LazyInitialized
def _train_program_id(self):
program_id = _hash_with_id(self._train_program, self)
Expand Down Expand Up @@ -341,7 +349,7 @@ def _get_end_op_index(self):
elif _in_pure_fp16_guard():
infer_program = self._infer_pure_fp16_program
else:
infer_program = self._infer_program
infer_program = self.infer_program
return infer_program.desc.block(0).op_size()

def __call__(self, inputs):
Expand Down Expand Up @@ -380,14 +388,9 @@ def _cast_fp16_if_pure_fp16(self, in_vars):
@property
def program(self):
if self.training:
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program
return self.train_program
else:
return self._infer_program
return self.infer_program

@property
def program_id(self):
Expand All @@ -399,7 +402,30 @@ def program_id(self):
else:
return self._train_program_id
else:
return self._infer_program_id
if _in_amp_guard():
return self._infer_amp_program_id
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program_id
else:
return self._infer_program_id

@property
def train_program(self):
if _in_amp_guard():
return self._train_amp_program
elif _in_pure_fp16_guard():
return self._train_pure_fp16_program
else:
return self._train_program

@property
def infer_program(self):
if _in_amp_guard():
return self._infer_amp_program
elif _in_pure_fp16_guard():
return self._infer_pure_fp16_program
else:
return self._infer_program

def _prepare(self, inputs):
"""
Expand Down