가장 기본적인 3d resnet의 구조를 살펴본다
import torch
import torch.nn as nn
import torchvision.models as models
class ResNet3D(nn.Module):
def __init__(self, num_classes, pretrained=True):
super(ResNet3D, self).__init__()
# 미리 정의된 3D ResNet-18 모델 로드 (사전 학습된 가중치 포함)
self.model = models.video.r3d_18(pretrained=pretrained)
# 마지막 Fully Connected 레이어 변경 (num_classes에 맞게 출력 차원 변경)
self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
def forward(self, x):
return self.model(x)
# 모델 인스턴스 생성 (클래스 수를 설정)
num_classes = 10 # 예: 10개의 행동 클래스
model = ResNet3D(num_classes=num_classes, pretrained=True)
# 손실 함수 및 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# DataLoader 설정 (데이터셋 예시 사용)
data_loader = ... # Your DataLoader here
# 학습 루프
num_epochs = 25
for epoch in range(num_epochs):
model.train()
for inputs, labels in data_loader:
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 검증 루프 작성 (필요에 따라)
# model.eval()
# with torch.no_grad():
# # Validation logic here
# 모델 저장
torch.save(model.state_dict(), "resnet3d_model.pth")
# 모델 로드
model = ResNet3D(num_classes=num_classes, pretrained=False)
model.load_state_dict(torch.load("resnet3d_model.pth"))
# 예측 수행
model.eval()
with torch.no_grad():
for inputs, _ in data_loader:
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
print(predicted)
'프로젝트' 카테고리의 다른 글
[java,안드로이드 스튜디오] 날짜와 시간 예약하는 캘린더 (1) | 2023.12.08 |
---|---|
[kotlin, 앱 만들기] 타이머 (0) | 2023.11.23 |
[TensorFlow] 개 고양이 분류기 만들기 (1) | 2023.11.14 |
(TensorFlow) 의류 이미지 분류 (0) | 2023.10.31 |
주택 노후도와 범죄율간 상관관계 분석 및 사업제안 (0) | 2023.09.18 |