Notice
Recent Posts
Recent Comments
Link
| 일 | 월 | 화 | 수 | 목 | 금 | 토 |
|---|---|---|---|---|---|---|
| 1 | ||||||
| 2 | 3 | 4 | 5 | 6 | 7 | 8 |
| 9 | 10 | 11 | 12 | 13 | 14 | 15 |
| 16 | 17 | 18 | 19 | 20 | 21 | 22 |
| 23 | 24 | 25 | 26 | 27 | 28 | 29 |
| 30 |
Tags
- vllmmcp
- langchain
- AI
- langgraph
- 타입스크립트상태관리
- langchain react agent
- 크로마DB
- react
- 상태관리
- VectorDB
- expo 51 버전
- expo go 오류
- langchain tools
- jotai
- expo 아이폰 오류
- 랭체인 툴
- 리액트 네이티브 오류
- 리액트 네이티브
- 이미지처리
- expo 아이폰
- expo 안드로이드
- 네스트시큐리티
- 자바공부
- nestjs시큐리티
- langgraph mcp
- comfyui
- expo 버전 오류
- expo 51 오류
- rnn gnsfus
- 스프링 공부
Archives
- Today
- Total
영리의 테크블로그
마리오카트 RNN 학습 본문
서론
학교 과제인 Super Mario Kart 에이전트를 학습하기 위해 구현한 순환 신경망(RNN) 기반 모델을 구현했다.
https://github.com/YoonJae00/gym-SuperMarioKart-Snes
GitHub - YoonJae00/gym-SuperMarioKart-Snes: Super Mario Kart Snes integration with OpenAI Retro Gym
Super Mario Kart Snes integration with OpenAI Retro Gym - GitHub - YoonJae00/gym-SuperMarioKart-Snes: Super Mario Kart Snes integration with OpenAI Retro Gym
github.com
데이터 수집 및 전처리
- Gym-Retro 환경에 SuperMarioKart-Snes 폴더를 등록했다.
- Pyglet 키 이벤트를 바인딩해 Z 키와 방향키 입력을 raw 액션 벡터로 기록했다.
- 한 에피소드가 끝나면 별도
.npz파일로 저장했다. - 데이터 로딩 속도 병목을 해소하기 위해 모든 에피소드 파일을 메모리에 미리 로드하는
MarioFastDataset을 구현했다.
class MarioFastDataset(Dataset):
def __init__(self, npz_dir, seq_len=10):
files = sorted(glob.glob(f"{npz_dir}/*.npz"))
self.seq_len = seq_len
self.frames_list = []
self.actions_list = []
for f in files:
arr = np.load(f)
frames = arr['frames'].astype(np.float32) / 255.0 # 정규화
actions = arr['actions']
self.frames_list.append(frames)
self.actions_list.append(actions)
self.lengths = [len(a) - seq_len for a in self.actions_list]
self.cum_lengths = np.cumsum(self.lengths)
def __len__(self):
return int(self.cum_lengths[-1])
def __getitem__(self, idx):
ep = np.searchsorted(self.cum_lengths, idx, side='right')
start = idx - (self.cum_lengths[ep-1] if ep>0 else 0)
frames = self.frames_list[ep]
actions = self.actions_list[ep]
x = frames[start:start+self.seq_len]
y = actions[start+self.seq_len]
x = np.expand_dims(x, 1) # (seq_len,1,84,84)
return torch.from_numpy(x), torch.tensor(y, dtype=torch.long)
모델 구조
CNN 특징 추출기와 LSTM 순환층을 결합한 MarioRNN 모델을 사용한다.
각 프레임에서 특징 벡터를 추출하고 시퀀스 전체를 LSTM으로 처리해 마지막 시점 출력을 분류한다.
import torch.nn as nn
class MarioRNN(nn.Module):
def __init__(self, hidden_size=128, n_layers=1, n_actions=3):
super().__init__()
self.cnn = nn.Sequential(
nn.Conv2d(1,16,3,stride=2,padding=1),
nn.ReLU(),
nn.Conv2d(16,32,3,stride=2,padding=1),
nn.ReLU(),
nn.Flatten() # (32*21*21)
)
feat_size = 32 * 21 * 21
self.lstm = nn.LSTM(
input_size=feat_size,
hidden_size=hidden_size,
num_layers=n_layers,
batch_first=True
)
self.fc = nn.Linear(hidden_size, n_actions)
def forward(self, x):
B,S,C,H,W = x.shape
x = x.view(B*S, C, H, W)
feat = self.cnn(x) # (B*S, feat_size)
feat = feat.view(B, S, -1) # (B,S,feat_size)
out, _ = self.lstm(feat) # (B,S,hidden)
return self.fc(out[:, -1, :]) # (B,n_actions)
코드 설명
nn.Conv2d두 층으로 입력 프레임 특징 벡터 생성nn.LSTM으로 시퀀스를 순차 처리out[:, -1, :]로 마지막 시점 은닉 상태 사용nn.Linear로 행동(직진·좌회전+가속·우회전+가속) 확률 예측
학습 스크립트
Colab T4 환경에서 빠른 학습을 위해 메모리 캐시 데이터셋과 num_workers를 활용한다.
각 epoch마다 loss와 accuracy를 출력해 학습 경과를 모니터링한다.
from tqdm.notebook import tqdm
import time
def train(data_dir, epochs=20, lr=1e-3, batch_size=32, seq_len=15):
loader = get_fast_loader(data_dir, batch_size, seq_len, num_workers=4)
print("📦 총 배치 수:", len(loader))
print("🔌 Using device:", DEVICE)
xb,yb = next(iter(loader))
print("샘플 배치 shapes → x:", xb.shape, "y:", yb.shape)
model = MarioRNN().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(1, epochs+1):
print(f"\n▶ Epoch {epoch}/{epochs}")
total_loss, correct, total = 0,0,0
for x,y in tqdm(loader, leave=False):
x,y = x.to(DEVICE), y.to(DEVICE)
logits = model(x)
loss = criterion(logits,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
preds = logits.argmax(dim=1)
correct += (preds==y).sum().item()
total += y.size(0)
avg_loss = total_loss/len(loader)
acc = correct/total*100
print(f"✅ loss={avg_loss:.4f} acc={acc:.2f}%")
torch.save(model.state_dict(), MODEL_PATH)
테스트 코드
자동화된 단위 테스트로 데이터셋과 모델의 입출력 형태를 검증한다.
pip install pytest
pytest -q
test_dataset.py는MarioFastDataset와get_fast_loader의 배치 형태를 확인test_model.py는MarioRNNforward 출력 차원 검증
'dev > AI' 카테고리의 다른 글
| [Agent 연습] Claude MCP 알아보기 (0) | 2025.04.13 |
|---|---|
| 소설 캐릭터 정보 추출 & 대화 토이프로젝트 (Book Buddy) (0) | 2024.11.24 |
| Agent 만들기 [8] - 분신 성격 업데이트: LLM과 상호작용 분석을 활용한 페르소나 발전 (1) | 2024.11.21 |