From deb8d28c622d0898aa31f5a1c3dd799109b85afd Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Wed, 29 May 2024 10:38:32 -0700 Subject: [PATCH] Add relevant changs to TDT cuda graphs decoding as well. I didn't test this because I'm not sure how. But it seems low risk. Signed-off-by: Daniel Galvez --- .../asr/parts/submodules/tdt_loop_labels_computer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index 7ad7065e019c1..65bbe6aa608eb 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -691,14 +691,14 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph), + torch.cuda.graph(self.separate_graphs.before_outer_loop, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._before_outer_loop() with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph), + torch.cuda.graph(self.separate_graphs.before_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._before_inner_loop_get_decoder_output() self._before_inner_loop_get_joint_output() @@ -706,14 +706,14 @@ def _partial_graphs_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph), + torch.cuda.graph(self.separate_graphs.inner_loop_code, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._inner_loop_code() with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph), + torch.cuda.graph(self.separate_graphs.after_inner_loop, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._after_inner_loop() @@ -726,7 +726,7 @@ def _full_graph_compile(self): with ( torch.cuda.stream(stream_for_graph), torch.inference_mode(), - torch.cuda.graph(self.full_graph, stream=stream_for_graph), + torch.cuda.graph(self.full_graph, stream=stream_for_graph, capture_error_mode="thread_local"), ): self._before_outer_loop()