PyTorch 기본 함수 정리 (View, Squeeze, Unsqueeze)

2022. 6. 18. 12:10PyTorch

View (Numpy의 Reshape)

t = np.array([[[0, 1, 2],
              [3, 4, 5]],
             [[6, 7, 8],
             [9, 10, 11]]])
t
 
array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 6,  7,  8],
        [ 9, 10, 11]]])
 
 

 
t.shape
 
(2, 2, 3)
 

 
ft = torch.FloatTensor(t)
ft
 
tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])
 

 
ft.shape
 
torch.Size([2, 2, 3])
 

 
print(ft.view(-1,3).shape)
ft.view(-1,3)
 
torch.Size([4, 3])
tensor([[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.],
        [ 9., 10., 11.]])
 
 

 
 
 
tensor([[[ 0.,  1.,  2.]],

        [[ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.]],

        [[ 9., 10., 11.]]])
torch.Size([4, 1, 3])

Squeeze

dimension의 element의 개수가 1인 경우 해당 dimension을 없애줌

ft = torch.FloatTensor([[0],[1],[2]])
print(ft)
print(ft.shape)

ft.squeeze()

ft.squeeze().shape

# dimension이 1이 아니므로 효과가 없음
ft.squeeze(dim = 0)

ft.squeeze(dim = 1)

Unsqueeze

dimension을 반드시 명시 내가 원하는 dimension에 1을 넣어줌

ft = torch.FloatTensor([0, 1, 2])
print(ft)
print(ft.shape)
# tensor([0., 1., 2.])
# torch.Size([3])

print(ft.unsqueeze(0))
print(ft.unsqueeze(0).shape)
# tensor([[0., 1., 2.]])
# torch.Size([1, 3])

print(ft.view(1,-1))
print(ft.view(1,-1).shape)
# tensor([[0., 1., 2.]])
# torch.Size([1, 3])

print(ft.unsqueeze(1))
print(ft.unsqueeze(1).shape)
# tensor([[0.],
#         [1.],
#         [2.]])
# torch.Size([3, 1])

print(ft.unsqueeze(-1))
print(ft.unsqueeze(-1).shape)
# tensor([[0.],
#         [1.],
#         [2.]])
# torch.Size([3, 1])