PyTorch 기본 함수 정리 (Type Casting)

2022. 6. 18. 12:18PyTorch

lt = torch.LongTensor([1, 2, 3, 4])
print(lt)
# tensor([1, 2, 3, 4])

print(lt.float())
# tensor([1., 2., 3., 4.])

# ByteTensor: Boolean을 저장
bt = torch.ByteTensor([True,False,False,True])
print(bt)
print(bt.long())
print(bt.float())
# tensor([1, 0, 0, 1], dtype=torch.uint8)
# tensor([1, 0, 0, 1])
# tensor([1., 0., 0., 1.])

lt == 3
# tensor([False, False,  True, False])