Skip to content

Commit 98b9267

Browse files
committed
ahmad's comments
1 parent 01a5621 commit 98b9267

File tree

4 files changed

+69
-13
lines changed

4 files changed

+69
-13
lines changed

torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp

+20-10
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
4040
"torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda");
4141

4242
std::lock_guard<std::mutex> lock(decoderMutex);
43+
std::vector<torch::Tensor> contig_images;
44+
contig_images.reserve(encoded_images.size());
4345

4446
for (auto& encoded_image : encoded_images) {
4547
TORCH_CHECK(
@@ -52,6 +54,13 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
5254
TORCH_CHECK(
5355
encoded_image.dim() == 1 && encoded_image.numel() > 0,
5456
"Expected a non empty 1-dimensional tensor");
57+
58+
// nvjpeg requires images to be contiguous
59+
if (encoded_image.is_contiguous()) {
60+
contig_images.push_back(encoded_image);
61+
} else {
62+
contig_images.push_back(encoded_image.contiguous());
63+
}
5564
}
5665

5766
TORCH_CHECK(device.is_cuda(), "Expected a cuda device");
@@ -81,9 +90,11 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
8190

8291
if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) {
8392
if (cudaJpegDecoder != nullptr)
84-
delete cudaJpegDecoder.release();
85-
cudaJpegDecoder = std::make_unique<CUDAJpegDecoder>(device);
86-
std::atexit([]() { delete cudaJpegDecoder.release(); });
93+
cudaJpegDecoder.reset(new CUDAJpegDecoder(device));
94+
else {
95+
cudaJpegDecoder = std::make_unique<CUDAJpegDecoder>(device);
96+
std::atexit([]() { cudaJpegDecoder.reset(); });
97+
}
8798
}
8899

89100
nvjpegOutputFormat_t output_format;
@@ -109,14 +120,13 @@ std::vector<torch::Tensor> decode_jpegs_cuda(
109120

110121
try {
111122
at::cuda::CUDAEvent event;
123+
auto result = cudaJpegDecoder->decode_images(contig_images, output_format);
124+
auto current_stream{
125+
device.has_index() ? at::cuda::getCurrentCUDAStream(
126+
cudaJpegDecoder->original_device.index())
127+
: at::cuda::getCurrentCUDAStream()};
112128
event.record(cudaJpegDecoder->stream);
113-
auto result = cudaJpegDecoder->decode_images(encoded_images, output_format);
114-
if (device.has_index())
115-
event.block(at::cuda::getCurrentCUDAStream(
116-
cudaJpegDecoder->original_device.index()));
117-
else
118-
event.block(at::cuda::getCurrentCUDAStream());
119-
return result;
129+
event.block(current_stream) return result;
120130
} catch (const std::exception& e) {
121131
if (typeid(e) != typeid(std::runtime_error)) {
122132
TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what());

torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class CUDAJpegDecoder {
2222
const torch::Device target_device;
2323
const c10::cuda::CUDAStream stream;
2424

25-
protected:
25+
private:
2626
std::tuple<
2727
std::vector<nvjpegImage_t>,
2828
std::vector<torch::Tensor>,

torchvision/csrc/io/image/cuda/encode_decode_jpegs_cuda.h

+38
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,49 @@
88
namespace vision {
99
namespace image {
1010

11+
/*
12+
13+
Fast jpeg decoding with CUDA.
14+
A100+ GPUs have dedicated hardware support for jpeg decoding.
15+
16+
Args:
17+
- encoded_images (const std::vector<torch::Tensor>&): a vector of tensors
18+
containing the jpeg bitstreams to be decoded. Each tensor must have dtype
19+
torch.uint8 and device cpu
20+
- mode (ImageReadMode): IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY and
21+
IMAGE_READ_MODE_RGB are supported
22+
- device (torch::Device): The desired CUDA device to run the decoding on and
23+
which will contain the output tensors
24+
25+
Returns:
26+
- decoded_images (std::vector<torch::Tensor>): a vector of torch::Tensors of
27+
dtype torch.uint8 on the specified <device> containing the decoded images
28+
29+
Notes:
30+
- If a single image fails, the whole batch fails.
31+
- This function is thread-safe
32+
*/
1133
C10_EXPORT std::vector<torch::Tensor> decode_jpegs_cuda(
1234
const std::vector<torch::Tensor>& encoded_images,
1335
vision::image::ImageReadMode mode,
1436
torch::Device device);
1537

38+
/*
39+
Fast jpeg encoding with CUDA.
40+
41+
Args:
42+
- decoded_images (const std::vector<torch::Tensor>&): a vector of contiguous
43+
CUDA tensors of dtype torch.uint8 to be encoded.
44+
- quality (int64_t): 0-100, 75 is the default
45+
46+
Returns:
47+
- encoded_images (std::vector<torch::Tensor>): a vector of CUDA
48+
torch::Tensors of dtype torch.uint8 containing the encoded images
49+
50+
Notes:
51+
- If a single image fails, the whole batch fails.
52+
- This function is thread-safe
53+
*/
1654
C10_EXPORT std::vector<torch::Tensor> encode_jpegs_cuda(
1755
const std::vector<torch::Tensor>& decoded_images,
1856
const int64_t quality);

torchvision/io/image.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def decode_jpeg(
145145
"""
146146
Decodes a (list of) JPEG image(s) into a (list of) 3 dimensional RGB or grayscale Tensor(s).
147147
Optionally converts the image(s) to the desired format.
148-
The values of the output tensor(s) are uint8 between 0 and 255.
148+
149149
150150
.. note::
151151
When using a CUDA device, passing a list of tensors is more efficient than repeated individual calls to ``decode_jpeg``.
@@ -175,7 +175,15 @@ def decode_jpeg(
175175
Default: False. Only implemented for JPEG format on CPU.
176176
177177
Returns:
178-
output (Tensor[image_channels, image_height, image_width])
178+
output (Tensor[image_channels, image_height, image_width] or list[Tensor[image_channels, image_height, image_width]]):
179+
The values of the output tensor(s) are uint8 between 0 and 255. output.device will be set to the specified ``device``
180+
181+
Notes:
182+
The cuda version of this function been designed with thread-safety in mind.
183+
The CPU version seems to work fine as well but using it in a multithreaded environment is at your own risk.
184+
This function does not return partial results in case of an error.
185+
186+
179187
"""
180188
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
181189
_log_api_usage_once(decode_jpeg)

0 commit comments

Comments
 (0)