PyTorch 기본 함수 정리

2022. 6. 21. 21:52PyTorch

torch.zeros_iike(텐서)

: 모든 entry가 0이면서 입력된 텐서와 동일한 shape를 갖는 텐서 생성

 

scatter(dimension, 지정한 위치, scatter할 value
: 지정한 dimension에 대해 지정한 위치에 특정 value를 scatter 하는 함수
# class 개, sample 3개
z = torch.rand(3, 5, requires_grad=True) # uniform random
hypothesis = F.softmax(z, dim=1) # prediction y_hat
y = torch.randint(5, (3,)).long()
print(y)
print(y.shape)
print(y.unsqueeze(1).shape)
# tensor([0, 2, 2])
# torch.Size([3])
# torch.Size([3, 1])


y_one_hot = torch.zeros_like(hypothesis)
print(y_one_hot)
print(y.unsqueeze(1))
print(y_one_hot.scatter_(1, y.unsqueeze(1), 1)) # inplace. dimension에 대해 지정한 위치에 1을 scatter 하는 함수

# tensor([[0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.],
#         [0., 0., 0., 0., 0.]])
# tensor([[0],
#         [2],
#         [2]])
# tensor([[1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0.],
#         [0., 0., 1., 0., 0.]])

 

Tensor의 값 확인하기

# size(n): n번째 위치의 값을 return


out.size()
# 결과
# torch.Size([1, 32, 28, 28])


out.size(0)
# 결과
# 1