@@ -40,6 +40,8 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
40
40
" torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda" );
41
41
42
42
std::lock_guard<std::mutex> lock (decoderMutex);
43
+ std::vector<torch::Tensor> contig_images;
44
+ contig_images.reserve (encoded_images.size ());
43
45
44
46
for (auto & encoded_image : encoded_images) {
45
47
TORCH_CHECK (
@@ -52,6 +54,13 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
52
54
TORCH_CHECK (
53
55
encoded_image.dim () == 1 && encoded_image.numel () > 0 ,
54
56
" Expected a non empty 1-dimensional tensor" );
57
+
58
+ // nvjpeg requires images to be contiguous
59
+ if (encoded_image.is_contiguous ()) {
60
+ contig_images.push_back (encoded_image);
61
+ } else {
62
+ contig_images.push_back (encoded_image.contiguous ());
63
+ }
55
64
}
56
65
57
66
TORCH_CHECK (device.is_cuda (), " Expected a cuda device" );
@@ -81,9 +90,11 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
81
90
82
91
if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device ) {
83
92
if (cudaJpegDecoder != nullptr )
84
- delete cudaJpegDecoder.release ();
85
- cudaJpegDecoder = std::make_unique<CUDAJpegDecoder>(device);
86
- std::atexit ([]() { delete cudaJpegDecoder.release (); });
93
+ cudaJpegDecoder.reset (new CUDAJpegDecoder (device));
94
+ else {
95
+ cudaJpegDecoder = std::make_unique<CUDAJpegDecoder>(device);
96
+ std::atexit ([]() { cudaJpegDecoder.reset (); });
97
+ }
87
98
}
88
99
89
100
nvjpegOutputFormat_t output_format;
@@ -109,14 +120,13 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
109
120
110
121
try {
111
122
at::cuda::CUDAEvent event;
123
+ auto result = cudaJpegDecoder->decode_images (contig_images, output_format);
124
+ auto current_stream{
125
+ device.has_index () ? at::cuda::getCurrentCUDAStream (
126
+ cudaJpegDecoder->original_device .index ())
127
+ : at::cuda::getCurrentCUDAStream ()};
112
128
event.record (cudaJpegDecoder->stream );
113
- auto result = cudaJpegDecoder->decode_images (encoded_images, output_format);
114
- if (device.has_index ())
115
- event.block (at::cuda::getCurrentCUDAStream (
116
- cudaJpegDecoder->original_device .index ()));
117
- else
118
- event.block (at::cuda::getCurrentCUDAStream ());
119
- return result;
129
+ event.block (current_stream); return result;
120
130
} catch (const std::exception & e) {
121
131
if (typeid (e) != typeid (std::runtime_error)) {
122
132
TORCH_CHECK (false , " Error while decoding JPEG images: " , e.what ());
0 commit comments