Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: galv <[email protected]>
  • Loading branch information
galv committed May 29, 2024
1 parent abc61e6 commit 766ff80
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,29 +691,37 @@ def _partial_graphs_compile(self):
with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._before_outer_loop()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._before_inner_loop_get_decoder_output()
self._before_inner_loop_get_joint_output()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._inner_loop_code()

with (
torch.cuda.stream(stream_for_graph),
torch.inference_mode(),
torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"),
torch.cuda.graph(
self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"
),
):
self._after_inner_loop()

Expand Down

0 comments on commit 766ff80

Please sign in to comment.