Skip to content

Commit

Permalink
Merge pull request #78 from GiovanniFyc/patch-1
Browse files Browse the repository at this point in the history
Fit data type error in trt infer.
  • Loading branch information
Peterande authored Nov 25, 2024
2 parents 3a56bb7 + d2f6546 commit 1d1278c
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tools/inference/trt_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def get_bindings(self, engine, context, max_batch_size=32, device=None) -> Order

def run_torch(self, blob):
for n in self.input_names:
if blob[n].dtype is not self.bindings[n].data.dtype:
blob[n] = blob[n].to(dtype=self.bindings[n].data.dtype)
if self.bindings[n].shape != blob[n].shape:
self.context.set_input_shape(n, blob[n].shape)
self.bindings[n] = self.bindings[n]._replace(shape=blob[n].shape)
Expand Down

0 comments on commit 1d1278c

Please sign in to comment.