프로젝트

3d resnet 도전기

기계학습점쟁이 2024. 7. 18. 10:09

가장 기본적인 3d resnet의 구조를 살펴본다

import torch
import torch.nn as nn
import torchvision.models as models


class ResNet3D(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super(ResNet3D, self).__init__()
        
        # 미리 정의된 3D ResNet-18 모델 로드 (사전 학습된 가중치 포함)
        self.model = models.video.r3d_18(pretrained=pretrained)
        
        # 마지막 Fully Connected 레이어 변경 (num_classes에 맞게 출력 차원 변경)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

# 모델 인스턴스 생성 (클래스 수를 설정)
num_classes = 10  # 예: 10개의 행동 클래스
model = ResNet3D(num_classes=num_classes, pretrained=True)

# 손실 함수 및 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# DataLoader 설정 (데이터셋 예시 사용)
data_loader = ...  # Your DataLoader here

# 학습 루프
num_epochs = 25
for epoch in range(num_epochs):
    model.train()
    for inputs, labels in data_loader:
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # 검증 루프 작성 (필요에 따라)
    # model.eval()
    # with torch.no_grad():
    #     # Validation logic here

# 모델 저장
torch.save(model.state_dict(), "resnet3d_model.pth")

# 모델 로드
model = ResNet3D(num_classes=num_classes, pretrained=False)
model.load_state_dict(torch.load("resnet3d_model.pth"))

# 예측 수행
model.eval()
with torch.no_grad():
    for inputs, _ in data_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        print(predicted)