PyTorch 기본 함수 정리 (torch.optim, torch.nn)

2022. 6. 18. 23:22PyTorch

import torch.optim as optim # optimizer import
optimizer = optim.SGD(model.parameters(), lr = 1e-5) # optim.옵티마이저(파라미터, 러닝레이트)
import torch.nn as nn

class 딥러닝모델(nn.Module): # nn.Module을 상속받아 딥러닝 모델 생성 시
    def __init__(self):
        super().__init__()
        self.모델오브젝트 = nn.기정의된모델사용가능(dim0, dim1)
    
    def forward(self, x):
    	return self.모델오브젝트(x)
        
model = 딥러닝모델()
optimizer = optim.SGD(model.parameters(), lr=1e-5)
  
######################################################################
  
class MultivariateLinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3,1)

    def forward(self, x):
        return self.linear(x)

model = MultivariateLinearRegressionModel()
optimizer = optim.SGD(model.parameters(), lr=1e-5)