PyTorch
PyTorch 기본 함수 정리 (Concatenate, Stack)
수터디
2022. 6. 18. 12:26
Concatenate
x = torch.FloatTensor([[1,2], [3,4]])
y = torch.FloatTensor([[5,6], [7,8]])
print(torch.cat([x,y], dim=0))
print(torch.cat([x,y], dim=1))
# tensor([[1., 2.],
# [3., 4.],
# [5., 6.],
# [7., 8.]])
# tensor([[1., 2., 5., 6.],
# [3., 4., 7., 8.]])
Stacking
x = torch.FloatTensor([1, 4])
y = torch.FloatTensor([2, 5])
z = torch.FloatTensor([3, 6])
print(x.shape)
# torch.Size([2])
# 3개가 쌓임
# 3개가 쌓이므로 dim 3이 생김
# 3이 생기는 방향을 dim option을 통해 결정
print(torch.stack([x,y,z]))
print(torch.stack([x,y,z], dim=1))
# tensor([[1., 4.],
# [2., 5.],
# [3., 6.]])
# tensor([[1., 2., 3.],
# [4., 5., 6.]])
print(torch.cat([x.unsqueeze(0), y.unsqueeze(0), z.unsqueeze(0)], dim=0))
# tensor([[1., 4.],
# [2., 5.],
# [3., 6.]])