영리의 테크블로그

마리오카트 RNN 학습 본문

dev/AI

마리오카트 RNN 학습

영리0 2025. 5. 18. 21:45

서론

학교 과제인 Super Mario Kart 에이전트를 학습하기 위해 구현한 순환 신경망(RNN) 기반 모델을 구현했다.

 

https://github.com/YoonJae00/gym-SuperMarioKart-Snes

 

GitHub - YoonJae00/gym-SuperMarioKart-Snes: Super Mario Kart Snes integration with OpenAI Retro Gym

Super Mario Kart Snes integration with OpenAI Retro Gym - GitHub - YoonJae00/gym-SuperMarioKart-Snes: Super Mario Kart Snes integration with OpenAI Retro Gym

github.com

 

 

 

 

4번 학습데이터 20에포크 훈련

 

데이터 수집 및 전처리

  1. Gym-Retro 환경에 SuperMarioKart-Snes 폴더를 등록했다.
  2. Pyglet 키 이벤트를 바인딩해 Z 키와 방향키 입력을 raw 액션 벡터로 기록했다.
  3. 한 에피소드가 끝나면 별도 .npz 파일로 저장했다.
  4. 데이터 로딩 속도 병목을 해소하기 위해 모든 에피소드 파일을 메모리에 미리 로드하는 MarioFastDataset을 구현했다.
class MarioFastDataset(Dataset):
    def __init__(self, npz_dir, seq_len=10):
        files = sorted(glob.glob(f"{npz_dir}/*.npz"))
        self.seq_len = seq_len
        self.frames_list = []
        self.actions_list = []
        for f in files:
            arr = np.load(f)
            frames  = arr['frames'].astype(np.float32) / 255.0  # 정규화
            actions = arr['actions']
            self.frames_list.append(frames)
            self.actions_list.append(actions)
        self.lengths = [len(a) - seq_len for a in self.actions_list]
        self.cum_lengths = np.cumsum(self.lengths)
    def __len__(self):
        return int(self.cum_lengths[-1])
    def __getitem__(self, idx):
        ep = np.searchsorted(self.cum_lengths, idx, side='right')
        start = idx - (self.cum_lengths[ep-1] if ep>0 else 0)
        frames  = self.frames_list[ep]
        actions = self.actions_list[ep]
        x = frames[start:start+self.seq_len]
        y = actions[start+self.seq_len]
        x = np.expand_dims(x, 1)  # (seq_len,1,84,84)
        return torch.from_numpy(x), torch.tensor(y, dtype=torch.long)

모델 구조

CNN 특징 추출기와 LSTM 순환층을 결합한 MarioRNN 모델을 사용한다.
각 프레임에서 특징 벡터를 추출하고 시퀀스 전체를 LSTM으로 처리해 마지막 시점 출력을 분류한다.

import torch.nn as nn

class MarioRNN(nn.Module):
    def __init__(self, hidden_size=128, n_layers=1, n_actions=3):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1,16,3,stride=2,padding=1),
            nn.ReLU(),
            nn.Conv2d(16,32,3,stride=2,padding=1),
            nn.ReLU(),
            nn.Flatten()  # (32*21*21)
        )
        feat_size = 32 * 21 * 21
        self.lstm = nn.LSTM(
            input_size=feat_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, n_actions)
    def forward(self, x):
        B,S,C,H,W = x.shape
        x = x.view(B*S, C, H, W)
        feat = self.cnn(x)            # (B*S, feat_size)
        feat = feat.view(B, S, -1)    # (B,S,feat_size)
        out, _ = self.lstm(feat)      # (B,S,hidden)
        return self.fc(out[:, -1, :]) # (B,n_actions)

코드 설명

  • nn.Conv2d 두 층으로 입력 프레임 특징 벡터 생성
  • nn.LSTM으로 시퀀스를 순차 처리
  • out[:, -1, :]로 마지막 시점 은닉 상태 사용
  • nn.Linear로 행동(직진·좌회전+가속·우회전+가속) 확률 예측

학습 스크립트

Colab T4 환경에서 빠른 학습을 위해 메모리 캐시 데이터셋과 num_workers를 활용한다.
각 epoch마다 loss와 accuracy를 출력해 학습 경과를 모니터링한다.

from tqdm.notebook import tqdm
import time

def train(data_dir, epochs=20, lr=1e-3, batch_size=32, seq_len=15):
    loader = get_fast_loader(data_dir, batch_size, seq_len, num_workers=4)
    print("📦 총 배치 수:", len(loader))
    print("🔌 Using device:", DEVICE)
    xb,yb = next(iter(loader))
    print("샘플 배치 shapes → x:", xb.shape, "y:", yb.shape)

    model = MarioRNN().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs+1):
        print(f"\n▶ Epoch {epoch}/{epochs}")
        total_loss, correct, total = 0,0,0
        for x,y in tqdm(loader, leave=False):
            x,y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            loss = criterion(logits,y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds==y).sum().item()
            total += y.size(0)
        avg_loss = total_loss/len(loader)
        acc = correct/total*100
        print(f"✅ loss={avg_loss:.4f} acc={acc:.2f}%")
    torch.save(model.state_dict(), MODEL_PATH)

테스트 코드

자동화된 단위 테스트로 데이터셋과 모델의 입출력 형태를 검증한다.

pip install pytest
pytest -q
  • test_dataset.pyMarioFastDatasetget_fast_loader의 배치 형태를 확인
  • test_model.pyMarioRNN forward 출력 차원 검증