PyTorch 기본 함수 정리 (torch.optim, torch.nn)
2022. 6. 18. 23:22ㆍPyTorch
import torch.optim as optim # optimizer import
optimizer = optim.SGD(model.parameters(), lr = 1e-5) # optim.옵티마이저(파라미터, 러닝레이트)
import torch.nn as nn
class 딥러닝모델(nn.Module): # nn.Module을 상속받아 딥러닝 모델 생성 시
def __init__(self):
super().__init__()
self.모델오브젝트 = nn.기정의된모델사용가능(dim0, dim1)
def forward(self, x):
return self.모델오브젝트(x)
model = 딥러닝모델()
optimizer = optim.SGD(model.parameters(), lr=1e-5)
######################################################################
class MultivariateLinearRegressionModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3,1)
def forward(self, x):
return self.linear(x)
model = MultivariateLinearRegressionModel()
optimizer = optim.SGD(model.parameters(), lr=1e-5)
'PyTorch' 카테고리의 다른 글
GoogLeNet - Inception Module (0) | 2022.07.28 |
---|---|
PyTorch 기본 함수 정리 (0) | 2022.06.21 |
PyTorch 기본 함수 정리(In-place Operation) (0) | 2022.06.18 |
PyTorch 기본 함수 정리 (Ones, Zeros) (0) | 2022.06.18 |
PyTorch 기본 함수 정리 (Concatenate, Stack) (0) | 2022.06.18 |