Commit d0ebeb5

Allow decode_image to support paths (#8624)
1 parent c36025a commit d0ebeb5

9 files changed: +66 -59 lines changed
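
With this change, decode_image accepts a file path (str or pathlib.Path) in addition to a one-dimensional uint8 tensor of raw bytes, and read_image becomes an obsolete alias kept for backward compatibility. A minimal sketch of the resulting call patterns ("dog.jpg" below is a placeholder path, not an asset from this repo):

# Sketch only: assumes a local image file at "dog.jpg".
from pathlib import Path

import torch
from torchvision.io import decode_image, read_file

img_a = decode_image("dog.jpg")             # new: str path
img_b = decode_image(Path("dog.jpg"))       # new: pathlib.Path
img_c = decode_image(read_file("dog.jpg"))  # unchanged: raw-bytes uint8 tensor

assert torch.equal(img_a, img_b) and torch.equal(img_a, img_c)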

docs/source/io.rst

Lines changed: 7 additions & 1 deletion
@@ -19,7 +19,6 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
     :toctree: generated/
     :template: function.rst

-    read_image
     decode_image
     encode_jpeg
     decode_jpeg
@@ -38,6 +37,13 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.

     ImageReadMode

+Obsolete decoding function:
+
+.. autosummary::
+    :toctree: generated/
+    :template: class.rst
+
+    read_image


 Video

docs/source/models.rst

Lines changed: 8 additions & 8 deletions
@@ -226,10 +226,10 @@ Here is an example of how to use the pre-trained image classification models:

 .. code:: python

-    from torchvision.io import read_image
+    from torchvision.io import decode_image
     from torchvision.models import resnet50, ResNet50_Weights

-    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+    img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

     # Step 1: Initialize model with the best available weights
     weights = ResNet50_Weights.DEFAULT
@@ -283,10 +283,10 @@ Here is an example of how to use the pre-trained quantized image classification

 .. code:: python

-    from torchvision.io import read_image
+    from torchvision.io import decode_image
     from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights

-    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+    img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

     # Step 1: Initialize model with the best available weights
     weights = ResNet50_QuantizedWeights.DEFAULT
@@ -339,11 +339,11 @@ Here is an example of how to use the pre-trained semantic segmentation models:

 .. code:: python

-    from torchvision.io.image import read_image
+    from torchvision.io.image import decode_image
     from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
     from torchvision.transforms.functional import to_pil_image

-    img = read_image("gallery/assets/dog1.jpg")
+    img = decode_image("gallery/assets/dog1.jpg")

     # Step 1: Initialize model with the best available weights
     weights = FCN_ResNet50_Weights.DEFAULT
@@ -411,12 +411,12 @@ Here is an example of how to use the pre-trained object detection models:
 .. code:: python


-    from torchvision.io.image import read_image
+    from torchvision.io.image import decode_image
     from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
     from torchvision.utils import draw_bounding_boxes
     from torchvision.transforms.functional import to_pil_image

-    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+    img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")

     # Step 1: Initialize model with the best available weights
     weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT

gallery/others/plot_repurposing_annotations.py

Lines changed: 5 additions & 5 deletions
@@ -66,12 +66,12 @@ def show(imgs):
 # We will take images and masks from the `PenFudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.


-from torchvision.io import read_image
+from torchvision.io import decode_image

 img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
 mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
-img = read_image(img_path)
-mask = read_image(mask_path)
+img = decode_image(img_path)
+mask = decode_image(mask_path)


 # %%
@@ -181,8 +181,8 @@ def __getitem__(self, idx):
         img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
         mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])

-        img = read_image(img_path)
-        mask = read_image(mask_path)
+        img = decode_image(img_path)
+        mask = decode_image(mask_path)

         img = F.convert_image_dtype(img, dtype=torch.float)
         mask = F.convert_image_dtype(mask, dtype=torch.float)

gallery/others/plot_scripted_tensor_transforms.py

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@
 import torch.nn as nn

 import torchvision.transforms as v1
-from torchvision.io import read_image
+from torchvision.io import decode_image

 plt.rcParams["savefig.bbox"] = 'tight'
 torch.manual_seed(1)
@@ -39,8 +39,8 @@
 # :class:`torch.nn.Sequential` instead of
 # :class:`~torchvision.transforms.v2.Compose`:

-dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
-dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))
+dog1 = decode_image(str(ASSETS_PATH / 'dog1.jpg'))
+dog2 = decode_image(str(ASSETS_PATH / 'dog2.jpg'))

 transforms = torch.nn.Sequential(
     v1.RandomCrop(224),

gallery/others/plot_visualization_utils.py

Lines changed: 5 additions & 5 deletions
@@ -42,11 +42,11 @@ def show(imgs):
 # image of dtype ``uint8`` as input.

 from torchvision.utils import make_grid
-from torchvision.io import read_image
+from torchvision.io import decode_image
 from pathlib import Path

-dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
-dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
+dog1_int = decode_image(str(Path('../assets') / 'dog1.jpg'))
+dog2_int = decode_image(str(Path('../assets') / 'dog2.jpg'))
 dog_list = [dog1_int, dog2_int]

 grid = make_grid(dog_list)
@@ -362,9 +362,9 @@ def show(imgs):
 #

 from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
-from torchvision.io import read_image
+from torchvision.io import decode_image

-person_int = read_image(str(Path("../assets") / "person1.jpg"))
+person_int = decode_image(str(Path("../assets") / "person1.jpg"))

 weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()

gallery/transforms/plot_transforms_getting_started.py

Lines changed: 2 additions & 2 deletions
@@ -21,14 +21,14 @@
 plt.rcParams["savefig.bbox"] = 'tight'

 from torchvision.transforms import v2
-from torchvision.io import read_image
+from torchvision.io import decode_image

 torch.manual_seed(1)

 # If you're trying to run that on Colab, you can download the assets and the
 # helpers from https://github.com/pytorch/vision/tree/main/gallery/
 from helpers import plot
-img = read_image(str(Path('../assets') / 'astronaut.jpg'))
+img = decode_image(str(Path('../assets') / 'astronaut.jpg'))
 print(f"{type(img) = }, {img.dtype = }, {img.shape = }")

 # %%

test/smoke_test.py

Lines changed: 5 additions & 5 deletions
@@ -6,7 +6,7 @@

 import torch
 import torchvision
-from torchvision.io import decode_jpeg, decode_webp, read_file, read_image
+from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file
 from torchvision.models import resnet50, ResNet50_Weights


@@ -21,13 +21,13 @@ def smoke_test_torchvision() -> None:


 def smoke_test_torchvision_read_decode() -> None:
-    img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
+    img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
     if img_jpg.shape != (3, 606, 517):
         raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
-    img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
+    img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
     if img_png.shape != (4, 471, 354):
         raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
-    img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
+    img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
     if img_webp.shape != (3, 100, 100):
         raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")

@@ -54,7 +54,7 @@ def smoke_test_compile() -> None:


 def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
-    img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
+    img = decode_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)

     # Step 1: Initialize model with the best available weights
     weights = ResNet50_Weights.DEFAULT

test/test_image.py

Lines changed: 21 additions & 0 deletions
@@ -1044,5 +1044,26 @@ def test_decode_heic(decode_fun, scripted):
     img += 123  # make sure image buffer wasn't freed by underlying decoding lib


+@pytest.mark.parametrize("input_type", ("Path", "str", "tensor"))
+@pytest.mark.parametrize("scripted", (False, True))
+def test_decode_image_path(input_type, scripted):
+    # Check that decode_image can support not just tensors as input
+    path = next(get_images(IMAGE_ROOT, ".jpg"))
+    if input_type == "Path":
+        input = Path(path)
+    elif input_type == "str":
+        input = path
+    elif input_type == "tensor":
+        input = read_file(path)
+    else:
+        raise ValueError("Oops")
+
+    if scripted and input_type == "Path":
+        pytest.xfail(reason="Can't pass a Path when scripting")
+
+    decode_fun = torch.jit.script(decode_image) if scripted else decode_image
+    decode_fun(input)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

torchvision/io/image.py

Lines changed: 10 additions & 30 deletions
@@ -277,13 +277,13 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):


 def decode_image(
-    input: torch.Tensor,
+    input: Union[torch.Tensor, str],
     mode: ImageReadMode = ImageReadMode.UNCHANGED,
     apply_exif_orientation: bool = False,
 ) -> torch.Tensor:
-    """
-    Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the
-    appropriate operation to decode the image into a Tensor.
+    """Decode an image into a tensor.
+
+    Currently supported image formats are jpeg, png, gif and webp.

     The values of the output tensor are in uint8 in [0, 255] for most cases.

@@ -295,8 +295,9 @@ def decode_image(
     tensor.

     Args:
-        input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
-            image.
+        input (Tensor or str or ``pathlib.Path``): The image to decode. If a
+            tensor is passed, it must be one dimensional uint8 tensor containing
+            the raw bytes of the image. Otherwise, this must be a path to the image file.
         mode (ImageReadMode): the read mode used for optionally converting the image.
             Default: ``ImageReadMode.UNCHANGED``.
             See ``ImageReadMode`` class for more information on various
@@ -309,6 +310,8 @@ def decode_image(
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(decode_image)
+    if not isinstance(input, torch.Tensor):
+        input = read_file(str(input))
     output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
     return output

@@ -318,30 +321,7 @@ def read_image(
     mode: ImageReadMode = ImageReadMode.UNCHANGED,
     apply_exif_orientation: bool = False,
 ) -> torch.Tensor:
-    """
-    Reads a JPEG, PNG, WEBP, or GIF image into a Tensor.
-
-    The values of the output tensor are in uint8 in [0, 255] for most cases.
-
-    If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
-    (supported from torchvision ``0.21``. Since uint16 support is limited in
-    pytorch, we recommend calling
-    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
-    after this function to convert the decoded image into a uint8 or float
-    tensor.
-
-    Args:
-        path (str or ``pathlib.Path``): path of the image.
-        mode (ImageReadMode): the read mode used for optionally converting the image.
-            Default: ``ImageReadMode.UNCHANGED``.
-            See ``ImageReadMode`` class for more information on various
-            available modes. Only applies to JPEG and PNG images.
-        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
-            Only applies to JPEG and PNG images. Default: False.
-
-    Returns:
-        output (Tensor[image_channels, image_height, image_width])
-    """
+    """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(read_image)
     data = read_file(path)
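
The dispatch added above is deliberately small: any non-tensor input is converted with read_file(str(input)), so a path decodes exactly like the raw-bytes tensor read from that path. The annotation stays Union[torch.Tensor, str] because TorchScript has no pathlib.Path type; Path inputs therefore work in eager mode only, which is why the new test xfails the scripted + Path combination. A short sketch of the resulting equivalence (the asset path is the one used elsewhere in this diff; run from a torchvision checkout so it resolves):

# Sketch only: demonstrates that path and raw-bytes inputs decode identically.
import torch
from torchvision.io import decode_image, read_file

path = "test/assets/encode_jpeg/grace_hopper_517x606.jpg"

img_from_path = decode_image(path)              # read_file happens internally
img_from_bytes = decode_image(read_file(path))  # tensor input, decoded directly
assert torch.equal(img_from_path, img_from_bytes)

# str inputs (but not pathlib.Path) also survive scripting:
scripted = torch.jit.script(decode_image)
assert torch.equal(scripted(path), img_from_path)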
