Skip to content

Commit cd6baeb

Browse files
committed
Added support for EXIF orientation transform in read_image for PNG
Description: - added support for EXIF orientation transform in read_image for PNG - restructured exif.h - added tests
1 parent c01f6d0 commit cd6baeb

File tree

8 files changed

+122
-59
lines changed

8 files changed

+122
-59
lines changed

test/test_image.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,15 @@ def test_decode_jpeg(img_path, pil_mode, mode):
100100
assert abs_mean_diff < 2
101101

102102

103+
@pytest.mark.parametrize("codec", [("png", "PNG"), ("jpg", "JPEG")])
103104
@pytest.mark.parametrize("orientation", [1, 2, 3, 4, 5, 6, 7, 8, 0])
104-
def test_decode_jpeg_with_exif_orientation(tmpdir, orientation):
105-
fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.jpg")
105+
def test_decode_with_exif_orientation(tmpdir, codec, orientation):
106+
fp = os.path.join(tmpdir, f"exif_oriented_{orientation}.{codec[0]}")
106107
t = torch.randint(0, 256, size=(3, 256, 257), dtype=torch.uint8)
107108
im = F.to_pil_image(t)
108109
exif = im.getexif()
109110
exif[0x0112] = orientation # set exif orientation
110-
im.save(fp, "JPEG", exif=exif.tobytes())
111+
im.save(fp, codec[1], exif=exif.tobytes())
111112

112113
data = read_file(fp)
113114
output = decode_image(data, apply_exif_orientation=True)

torchvision/csrc/io/image/cpu/decode_image.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ torch::Tensor decode_image(
2727
if (memcmp(jpeg_signature, datap, 3) == 0) {
2828
return decode_jpeg(data, mode, apply_exif_orientation);
2929
} else if (memcmp(png_signature, datap, 4) == 0) {
30-
return decode_png(data, mode);
30+
return decode_png(
31+
data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
3132
} else {
3233
TORCH_CHECK(
3334
false,

torchvision/csrc/io/image/cpu/decode_jpeg.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ torch::Tensor decode_jpeg(
203203

204204
int exif_orientation = -1;
205205
if (apply_exif_orientation) {
206-
exif_orientation = fetch_exif_orientation(&cinfo);
206+
exif_orientation = fetch_jpeg_exif_orientation(&cinfo);
207207
}
208208

209209
jpeg_start_decompress(&cinfo);

torchvision/csrc/io/image/cpu/decode_png.cpp

+19-3
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
#include "decode_png.h"
22
#include "common_png.h"
3+
#include "exif.h"
34

45
namespace vision {
56
namespace image {
67

8+
using namespace exif_private;
9+
710
#if !PNG_FOUND
811
torch::Tensor decode_png(
912
const torch::Tensor& data,
1013
ImageReadMode mode,
11-
bool allow_16_bits) {
14+
bool allow_16_bits,
15+
bool apply_exif_orientation) {
1216
TORCH_CHECK(
1317
false, "decode_png: torchvision not compiled with libPNG support");
1418
}
@@ -22,7 +26,8 @@ bool is_little_endian() {
2226
torch::Tensor decode_png(
2327
const torch::Tensor& data,
2428
ImageReadMode mode,
25-
bool allow_16_bits) {
29+
bool allow_16_bits,
30+
bool apply_exif_orientation) {
2631
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
2732
// Check that the input tensor dtype is uint8
2833
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
@@ -234,8 +239,19 @@ torch::Tensor decode_png(
234239
t_ptr = tensor.accessor<int32_t, 3>().data();
235240
}
236241
}
242+
243+
int exif_orientation = -1;
244+
if (apply_exif_orientation) {
245+
exif_orientation = fetch_png_exif_orientation(png_ptr, info_ptr);
246+
}
247+
237248
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
238-
return tensor.permute({2, 0, 1});
249+
250+
auto output = tensor.permute({2, 0, 1});
251+
if (apply_exif_orientation) {
252+
return exif_orientation_transform(output, exif_orientation);
253+
}
254+
return output;
239255
}
240256
#endif
241257

torchvision/csrc/io/image/cpu/decode_png.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ namespace image {
99
C10_EXPORT torch::Tensor decode_png(
1010
const torch::Tensor& data,
1111
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
12-
bool allow_16_bits = false);
12+
bool allow_16_bits = false,
13+
bool apply_exif_orientation = false);
1314

1415
} // namespace image
1516
} // namespace vision

torchvision/csrc/io/image/cpu/exif.h

+86-47
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,12 @@ direct,
5151
// https://github.com/opencv/opencv/blob/097891e311fae1d8354eb092a0fd0171e630d78c/modules/imgcodecs/src/exif.cpp
5252

5353
#if JPEG_FOUND
54-
5554
#include <jpeglib.h>
55+
#endif
56+
#if PNG_FOUND
57+
#include <png.h>
58+
#endif
59+
5660
#include <torch/types.h>
5761

5862
namespace vision {
@@ -125,8 +129,48 @@ inline uint32_t get_uint32(
125129
(exif_data[offset + 2] << 8) + exif_data[offset + 3];
126130
}
127131

128-
inline int fetch_exif_orientation(j_decompress_ptr cinfo) {
132+
inline int fetch_exif_orientation(unsigned char* exif_data_ptr, size_t size) {
129133
int exif_orientation = -1;
134+
135+
// Exif binary structure looks like this
136+
// First 6 bytes: [E, x, i, f, 0, 0]
137+
// Endianness, 2 bytes : [M, M] or [I, I]
138+
// Tag mark, 2 bytes: [0, 0x2a]
139+
// Offset, 4 bytes
140+
// Num entries, 2 bytes
141+
// Tag entries and data, tag has 2 bytes and its data has 10 bytes
142+
// For more details:
143+
// http://www.media.mit.edu/pia/Research/deepview/exif.html
144+
145+
ExifDataReader exif_data(exif_data_ptr, size);
146+
auto endianness = get_endianness(exif_data);
147+
148+
// Checking whether Tag Mark (0x002A) correspond to one contained in the
149+
// Jpeg file
150+
uint16_t tag_mark = get_uint16(exif_data, endianness, 2);
151+
if (tag_mark == REQ_EXIF_TAG_MARK) {
152+
auto offset = get_uint32(exif_data, endianness, 4);
153+
size_t num_entry = get_uint16(exif_data, endianness, offset);
154+
offset += 2; // go to start of tag fields
155+
constexpr size_t tiff_field_size = 12;
156+
for (size_t entry = 0; entry < num_entry; entry++) {
157+
// Here we just search for orientation tag and parse it
158+
auto tag_num = get_uint16(exif_data, endianness, offset);
159+
if (tag_num == INCORRECT_TAG) {
160+
break;
161+
}
162+
if (tag_num == ORIENTATION_EXIF_TAG) {
163+
exif_orientation = get_uint16(exif_data, endianness, offset + 8);
164+
break;
165+
}
166+
offset += tiff_field_size;
167+
}
168+
}
169+
return exif_orientation;
170+
}
171+
172+
#if JPEG_FOUND
173+
inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) {
130174
// Check for Exif marker APP1
131175
jpeg_saved_marker_ptr exif_marker = 0;
132176
jpeg_saved_marker_ptr cmarker = cinfo->marker_list;
@@ -137,51 +181,48 @@ inline int fetch_exif_orientation(j_decompress_ptr cinfo) {
137181
cmarker = cmarker->next;
138182
}
139183

140-
if (exif_marker) {
141-
// Exif binary structure looks like this
142-
// First 6 bytes: [E, x, i, f, 0, 0]
143-
// Endianness, 2 bytes : [M, M] or [I, I]
144-
// Tag mark, 2 bytes: [0, 0x2a]
145-
// Offset, 4 bytes
146-
// Num entries, 2 bytes
147-
// Tag entries and data, tag has 2 bytes and its data has 10 bytes
148-
// For more details:
149-
// http://www.media.mit.edu/pia/Research/deepview/exif.html
150-
151-
// Bytes from Exif size field to the first TIFF header
152-
constexpr size_t start_offset = 6;
153-
if (exif_marker->data_length > start_offset) {
154-
auto* exif_data_ptr = exif_marker->data + start_offset;
155-
auto size = exif_marker->data_length - start_offset;
156-
157-
ExifDataReader exif_data(exif_data_ptr, size);
158-
auto endianness = get_endianness(exif_data);
159-
160-
// Checking whether Tag Mark (0x002A) correspond to one contained in the
161-
// Jpeg file
162-
uint16_t tag_mark = get_uint16(exif_data, endianness, 2);
163-
if (tag_mark == REQ_EXIF_TAG_MARK) {
164-
auto offset = get_uint32(exif_data, endianness, 4);
165-
size_t num_entry = get_uint16(exif_data, endianness, offset);
166-
offset += 2; // go to start of tag fields
167-
constexpr size_t tiff_field_size = 12;
168-
for (size_t entry = 0; entry < num_entry; entry++) {
169-
// Here we just search for orientation tag and parse it
170-
auto tag_num = get_uint16(exif_data, endianness, offset);
171-
if (tag_num == INCORRECT_TAG) {
172-
break;
173-
}
174-
if (tag_num == ORIENTATION_EXIF_TAG) {
175-
exif_orientation = get_uint16(exif_data, endianness, offset + 8);
176-
break;
177-
}
178-
offset += tiff_field_size;
179-
}
180-
}
181-
}
184+
if (!exif_marker) {
185+
return -1;
182186
}
183-
return exif_orientation;
187+
188+
constexpr size_t start_offset = 6;
189+
if (exif_marker->data_length <= start_offset) {
190+
return -1;
191+
}
192+
193+
auto* exif_data_ptr = exif_marker->data + start_offset;
194+
auto size = exif_marker->data_length - start_offset;
195+
196+
return fetch_exif_orientation(exif_data_ptr, size);
197+
}
198+
#else
199+
inline int fetch_jpeg_exif_orientation(j_decompress_ptr cinfo) {
200+
return -1;
201+
}
202+
#endif // #if JPEG_FOUND
203+
204+
#if PNG_FOUND
205+
inline int fetch_png_exif_orientation(png_structp png_ptr, png_infop info_ptr) {
206+
#ifdef PNG_eXIf_SUPPORTED
207+
png_uint_32 num_exif = 0;
208+
png_bytep exif = 0;
209+
210+
// Exif info could be in info_ptr
211+
if (png_get_valid(png_ptr, info_ptr, PNG_INFO_eXIf)) {
212+
png_get_eXIf_1(png_ptr, info_ptr, &num_exif, &exif);
213+
}
214+
215+
if (exif && num_exif > 0) {
216+
return fetch_exif_orientation(exif, num_exif);
217+
}
218+
#endif // #ifdef PNG_eXIf_SUPPORTED
219+
return -1;
220+
}
221+
#else
222+
inline int fetch_png_exif_orientation(j_decompress_ptr cinfo) {
223+
return -1;
184224
}
225+
#endif // #if PNG_FOUND
185226

186227
constexpr uint16_t IMAGE_ORIENTATION_TL = 1; // normal orientation
187228
constexpr uint16_t IMAGE_ORIENTATION_TR = 2; // needs horizontal flip
@@ -222,5 +263,3 @@ inline torch::Tensor exif_orientation_transform(
222263
} // namespace exif_private
223264
} // namespace image
224265
} // namespace vision
225-
226-
#endif

torchvision/csrc/io/image/image.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace image {
2121

2222
static auto registry =
2323
torch::RegisterOperators()
24-
.op("image::decode_png", &decode_png)
24+
.op("image::decode_png(Tensor data, int mode, bool allow_16_bits = False, bool apply_exif_orientation=False) -> Tensor",
25+
&decode_png)
2526
.op("image::encode_png", &encode_png)
2627
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
2728
&decode_jpeg)

torchvision/io/image.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def write_file(filename: str, data: torch.Tensor) -> None:
6767
torch.ops.image.write_file(filename, data)
6868

6969

70-
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
70+
def decode_png(
71+
input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False
72+
) -> torch.Tensor:
7173
"""
7274
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
7375
Optionally converts the image to the desired format.
@@ -80,13 +82,15 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
8082
converting the image. Default: ``ImageReadMode.UNCHANGED``.
8183
See `ImageReadMode` class for more information on various
8284
available modes.
85+
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
86+
Default: False.
8387
8488
Returns:
8589
output (Tensor[image_channels, image_height, image_width])
8690
"""
8791
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
8892
_log_api_usage_once(decode_png)
89-
output = torch.ops.image.decode_png(input, mode.value, False)
93+
output = torch.ops.image.decode_png(input, mode.value, False, apply_exif_orientation)
9094
return output
9195

9296

0 commit comments

Comments
 (0)