I am trying to display an image stored as a PyTorch tensor.

trainset = datasets.ImageFolder('data/Cat_Dog_data/train/', transform=transforms)
trainload = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

images, labels = next(iter(trainload))
image = images[0]
image.shape 

>>> torch.Size([3, 224, 224]) # pyplot doesn't like this, so reshape

image = image.reshape(224,224,3)
plt.imshow(image.numpy())

This displays a 3 by 3 grid of the same image, always in greyscale.

How do I fix this so that the single color image is displayed correctly?

  • I have the exact same problem. The solution works (thanks Nicolas!), but I would like to understand why I get 9 images when I use reshape or view. Commented Apr 6, 2022 at 15:06
  • @Jan Pisl I actually have that question open this very second for my own investigation: stackoverflow.com/questions/51143206/… Commented Apr 6, 2022 at 17:19

1 Answer

That's very odd. Try putting the channels last by permuting rather than reshaping:

image.permute(1, 2, 0)
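
A likely explanation for the grid, offered as a sketch rather than a definitive account: `reshape` keeps the elements in the same flat memory order and only reinterprets the shape, whereas `permute` reorders the axes so that the value at `[c, h, w]` ends up at `[h, w, c]`. With a reshape to (224, 224, 3), each output pixel takes three horizontally adjacent values from one channel as its "RGB" (hence the near-greyscale look), each output row packs three source rows side by side, and the three channels are stacked vertically, which would produce the 3 by 3 tiling. The small tensor below is a made-up stand-in for the image; its 3x4x5 shape and the variable names are assumptions chosen for illustration.

```python
import torch

# Stand-in for a channels-first (C, H, W) image; the shape is illustrative only.
chw = torch.arange(3 * 4 * 5, dtype=torch.float32).reshape(3, 4, 5)

reshaped = chw.reshape(4, 5, 3)   # same flat element order, only the shape changes
permuted = chw.permute(1, 2, 0)   # axes reordered to (H, W, C)

# After permute, channel c of the result is exactly channel c of the original.
permute_ok = all(torch.equal(permuted[:, :, c], chw[c]) for c in range(3))

# After reshape, the "channels" are interleaved neighbouring values instead.
reshape_ok = all(torch.equal(reshaped[:, :, c], chw[c]) for c in range(3))

print(permuted.shape, permute_ok, reshape_ok)
```

For the image in the question, the permuted tensor can then be passed on as before, e.g. `plt.imshow(image.permute(1, 2, 0).numpy())`.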
1 Comment

This method also works for a list of frames (video): just pass four arguments to permute.
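
A minimal sketch of the four-argument case, assuming the frames are stacked into a single tensor of shape (N, C, H, W); the random data, the sizes, and the variable names are hypothetical and only serve the illustration. The batch axis stays first and the channel axis moves to the end.

```python
import torch

# Hypothetical stack of 8 channels-first frames: (N, C, H, W).
frames = torch.rand(8, 3, 16, 16)

# Keep the frame axis in place and move channels last: (N, H, W, C).
frames_last = frames.permute(0, 2, 3, 1)

# Each frame now has the (H, W, C) layout that plt.imshow expects.
first_frame = frames_last[0]
print(frames_last.shape, first_frame.shape)
```

Each `frames_last[i]` is identical to applying `permute(1, 2, 0)` to the corresponding single frame.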
