PyTorch 기본 함수 정리 (View, Squeeze, Unsqueeze)
2022. 6. 18. 12:10ㆍPyTorch
View (Numpy의 Reshape)
t = np.array([[[0, 1, 2],
[3, 4, 5]],
[[6, 7, 8],
[9, 10, 11]]])
t
array([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
t.shape
(2, 2, 3)
ft = torch.FloatTensor(t)
ft
tensor([[[ 0., 1., 2.],
[ 3., 4., 5.]],
[[ 6., 7., 8.],
[ 9., 10., 11.]]])
ft.shape
torch.Size([2, 2, 3])
print(ft.view(-1,3).shape)
ft.view(-1,3)
torch.Size([4, 3])
tensor([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]])
tensor([[[ 0., 1., 2.]],
[[ 3., 4., 5.]],
[[ 6., 7., 8.]],
[[ 9., 10., 11.]]])
torch.Size([4, 1, 3])
Squeeze
dimension의 element의 개수가 1인 경우 해당 dimension을 없애줌
ft = torch.FloatTensor([[0],[1],[2]])
print(ft)
print(ft.shape)
ft.squeeze()
ft.squeeze().shape
# dimension이 1이 아니므로 효과가 없음
ft.squeeze(dim = 0)
ft.squeeze(dim = 1)
Unsqueeze
dimension을 반드시 명시 내가 원하는 dimension에 1을 넣어줌
ft = torch.FloatTensor([0, 1, 2])
print(ft)
print(ft.shape)
# tensor([0., 1., 2.])
# torch.Size([3])
print(ft.unsqueeze(0))
print(ft.unsqueeze(0).shape)
# tensor([[0., 1., 2.]])
# torch.Size([1, 3])
print(ft.view(1,-1))
print(ft.view(1,-1).shape)
# tensor([[0., 1., 2.]])
# torch.Size([1, 3])
print(ft.unsqueeze(1))
print(ft.unsqueeze(1).shape)
# tensor([[0.],
# [1.],
# [2.]])
# torch.Size([3, 1])
print(ft.unsqueeze(-1))
print(ft.unsqueeze(-1).shape)
# tensor([[0.],
# [1.],
# [2.]])
# torch.Size([3, 1])
'PyTorch' 카테고리의 다른 글
PyTorch 기본 함수 정리(In-place Operation) (0) | 2022.06.18 |
---|---|
PyTorch 기본 함수 정리 (Ones, Zeros) (0) | 2022.06.18 |
PyTorch 기본 함수 정리 (Concatenate, Stack) (0) | 2022.06.18 |
PyTorch 기본 함수 정리 (Type Casting) (0) | 2022.06.18 |
PyTorch 기본 함수 정리 (mean, max) (0) | 2022.06.18 |