Training Transformer
TA’s Slide
Slide
Description
In this assignment, we pretrain a decoder-only Transformer on Pokémon images with a next-token prediction objective: each image is represented as a sequence of pixel tokens, and the model learns to predict each token from the tokens before it.
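To make the objective concrete, here is a minimal sketch. It assumes each image has already been reduced to a grid of palette (colormap) indices and flattened row by row (raster order). The helpers `flatten_image`, `next_token_pairs`, `generate`, and the toy predictor `repeat_last` are hypothetical and are not part of the assignment code. Training uses the sequence shifted by one token; generation appends one predicted token at a time and feeds it back in.

```python
from typing import Callable, List, Tuple

def flatten_image(grid: List[List[int]]) -> List[int]:
    """Flatten a 2D grid of color indices in raster (row-major) order."""
    return [pixel for row in grid for pixel in row]

def next_token_pairs(sequence: List[int]) -> Tuple[List[int], List[int]]:
    """Training view: inputs are tokens 0..n-2, targets are tokens 1..n-1."""
    return sequence[:-1], sequence[1:]

def generate(prefix: List[int], n_new: int,
             predict_next: Callable[[List[int]], int]) -> List[int]:
    """Inference view: append one predicted token at a time, feeding it back in."""
    tokens = list(prefix)
    for _ in range(n_new):
        tokens.append(predict_next(tokens))
    return tokens[len(prefix):]  # only the newly generated tokens

def repeat_last(tokens: List[int]) -> int:
    """Toy stand-in for a trained model (hypothetical): repeats the previous token."""
    return tokens[-1]

grid = [[0, 1, 1],
        [2, 2, 0]]  # a toy 2x3 "image" of colormap indices
seq = flatten_image(grid)
inputs, targets = next_token_pairs(seq)
print(seq)                                # [0, 1, 1, 2, 2, 0]
print(inputs)                             # [0, 1, 1, 2, 2]
print(targets)                            # [1, 1, 2, 2, 0]
print(generate(seq[:4], 2, repeat_last))  # [2, 2]
```

Each target is the token immediately after the corresponding input, so output position `i` is scored against `targets[i]`; the `PixelSequenceDataset` class below applies the same one-token shift in its "train" mode.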
Please feel free to email us if you have any questions:
ntu-ml-2025-spring-ta@googlegroups.com
Utilities
Download packages
!pip install datasets==3.3.2
Collecting datasets==3.3.2
  Using cached datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: filelock in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.17.0)
Requirement already satisfied: numpy>=1.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.0.1)
Requirement already satisfied: pyarrow>=15.0.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (21.0.0)
Collecting dill<0.3.9,>=0.3.0 (from datasets==3.3.2)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: pandas in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.3.1)
Requirement already satisfied: requests>=2.32.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.32.5)
Requirement already satisfied: tqdm>=4.66.3 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (4.67.1)
Requirement already satisfied: xxhash in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.6.0)
Requirement already satisfied: multiprocess<0.70.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.70.16)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Using cached fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: aiohttp in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.13.0)
Requirement already satisfied: huggingface-hub>=0.24.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.35.3)
Requirement already satisfied: packaging in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (25.0)
Requirement already satisfied: pyyaml>=5.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (6.0.2)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.4.0)
Requirement already satisfied: async-timeout<6.0,>=4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (5.0.1)
Requirement already satisfied: attrs>=17.3.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (25.3.0)
Requirement already satisfied: frozenlist>=1.1.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (0.4.0)
Requirement already satisfied: yarl<2.0,>=1.17.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.22.0)
Requirement already satisfied: typing-extensions>=4.1.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from multidict<7.0,>=4.5->aiohttp->datasets==3.3.2) (4.15.0)
Requirement already satisfied: idna>=2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from yarl<2.0,>=1.17.0->aiohttp->datasets==3.3.2) (3.7)
Requirement already satisfied: charset_normalizer<4,>=2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2025.10.5)
Requirement already satisfied: colorama in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from tqdm>=4.66.3->datasets==3.3.2) (0.4.6)
Requirement already satisfied: python-dateutil>=2.8.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: six>=1.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from python-dateutil>=2.8.2->pandas->datasets==3.3.2) (1.17.0)
Using cached datasets-3.3.2-py3-none-any.whl (485 kB)
Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Using cached fsspec-2024.12.0-py3-none-any.whl (183 kB)
Installing collected packages: fsspec, dill, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.9.0
    Uninstalling fsspec-2025.9.0:
      Successfully uninstalled fsspec-2025.9.0
  Attempting uninstall: dill
    Found existing installation: dill 0.4.0
    Uninstalling dill-0.4.0:
      Successfully uninstalled dill-0.4.0
  Attempting uninstall: datasets
    Found existing installation: datasets 4.1.1
    Uninstalling datasets-4.1.1:
      Successfully uninstalled datasets-4.1.1
Successfully installed datasets-3.3.2 dill-0.3.8 fsspec-2024.12.0
Import Packages
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2Config, set_seed
from datasets import load_dataset
from typing import Dict, Any, Optional
Check Devices
!nvidia-smi
Wed Oct 8 18:50:06 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.97 Driver Version: 580.97 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Driver-Model | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Ti WDDM | 00000000:07:00.0 On | Off |
| 47% 42C P8 25W / 450W | 12684MiB / 24564MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 2292 C+G C:\Windows\System32\dwm.exe N/A |
| 0 N/A N/A 5552 C+G ...8bbwe\PhoneExperienceHost.exe N/A |
| 0 N/A N/A 9928 C+G C:\Windows\explorer.exe N/A |
| 0 N/A N/A 10036 C+G ..._cw5n1h2txyewy\SearchHost.exe N/A |
| 0 N/A N/A 10264 C+G ...y\StartMenuExperienceHost.exe N/A |
| 0 N/A N/A 10632 C+G ...ogram Files\ToDesk\ToDesk.exe N/A |
| 0 N/A N/A 14304 C+G ...xyewy\ShellExperienceHost.exe N/A |
| 0 N/A N/A 15600 C+G ...5n1h2txyewy\TextInputHost.exe N/A |
| 0 N/A N/A 15812 C+G ...ouryDevice\asus_framework.exe N/A |
| 0 N/A N/A 18660 C+G ...crosoft\OneDrive\OneDrive.exe N/A |
| 0 N/A N/A 18668 C+G ...Chrome\Application\chrome.exe N/A |
| 0 N/A N/A 21724 C+G ....0.3537.57\msedgewebview2.exe N/A |
| 0 N/A N/A 22748 C+G ...s\TencentDocs\TencentDocs.exe N/A |
| 0 N/A N/A 25412 C+G ...ram Files\Tencent\QQNT\QQ.exe N/A |
| 0 N/A N/A 25872 C+G ...Chrome\Application\chrome.exe N/A |
| 0 N/A N/A 26600 C+G ...ocal\Programs\Quark\quark.exe N/A |
| 0 N/A N/A 28688 C+G ...ntrolPanel\SystemSettings.exe N/A |
| 0 N/A N/A 30104 C+G ...de\Microsoft VS Code\Code.exe N/A |
| 0 N/A N/A 31500 C+G ....0.3537.57\msedgewebview2.exe N/A |
| 0 N/A N/A 39276 C+G ...t\Edge\Application\msedge.exe N/A |
| 0 N/A N/A 41696 C+G ...PotPlayer\PotPlayerMini64.exe N/A |
| 0 N/A N/A 44176 C+G ...ffice6\promecefpluginhost.exe N/A |
| 0 N/A N/A 72652 C ...2025-Spring-Hw1\python.exe.c~ N/A |
| 0 N/A N/A 115660 C+G ...ef.win7x64\steamwebhelper.exe N/A |
| 0 N/A N/A 124396 C+G ...yb3d8bbwe\WindowsTerminal.exe N/A |
+-----------------------------------------------------------------------------------------+
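`nvidia-smi` reports what the driver sees (here an RTX 3090 Ti with 12684MiB of its 24564MiB already in use by other processes). From inside Python, a common pattern is to ask PyTorch whether CUDA is usable and fall back to the CPU otherwise. This is a sketch, not part of the original notebook; `pick_device` is a hypothetical name.

```python
def pick_device() -> str:
    """Return "cuda" if PyTorch can see a CUDA GPU, else "cpu" (hypothetical helper).

    Falls back to "cpu" when PyTorch is not installed, so it is safe to call anywhere.
    """
    try:
        import torch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"

device = pick_device()
print(f"Using device: {device}")
```

On the machine shown above this should return "cuda"; the model and each batch would then be moved with `.to(device)` before the forward pass.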
Set Random Seed
set_seed(0)
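`set_seed` from Hugging Face Transformers seeds Python's `random`, NumPy, and PyTorch in one call, because each library keeps its own random number generator. The sketch below spells out roughly what such a helper does; `seed_everything` is a hypothetical name, and only the stdlib part is exercised in the demo.

```python
import random

def seed_everything(seed: int) -> None:
    """Seed every RNG that is available (hypothetical helper, similar in spirit to set_seed)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices, CPU and CUDA
    except ImportError:
        pass

seed_everything(0)
first = [random.randint(0, 9) for _ in range(5)]
seed_everything(0)
second = [random.randint(0, 9) for _ in range(5)]
print(first == second)  # True: same seed, same draws
```

Seeding alone does not make every CUDA kernel deterministic; fully reproducible GPU runs may additionally need `torch.use_deterministic_algorithms(True)`.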
Prepare Data
Define Dataset
from typing import List, Tuple, Union
import torch
from torch.utils.data import Dataset
class PixelSequenceDataset(Dataset):
    def __init__(self, data: List[List[int]], mode: str = "train"):
        """
        A dataset class for handling pixel sequences.
        Args:
            data (List[List[int]]): A list of sequences, where each sequence is a list of integers.
            mode (str): The mode of operation, either "train", "dev", or "test".
                - "train": Returns (input_ids, labels) where input_ids are sequence[:-1] and labels are sequence[1:].
                - "dev": Returns (input_ids, labels) where input_ids are sequence[:-160] and labels are sequence[-160:].
                - "test": Returns only input_ids, as labels are not available.
        """
        self.data = data
        self.mode = mode
    def __len__(self) -> int:
        """Returns the total number of sequences in the dataset."""
        return len(self.data)
    def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Fetches a sequence from the dataset and processes it based on the mode.
        Args:
            idx (int): The index of the sequence.
        Returns:
            - If mode == "train": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "dev": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "test": torch.Tensor -> input_ids
        """
        sequence = self.data[idx]
        if self.mode == "train":
            # Shift by one token: labels[t] = sequence[t + 1] is predicted from sequence[:t + 1].
            input_ids = torch.tensor(sequence[:-1], dtype=torch.long)
            labels = torch.tensor(sequence[1:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "dev":
            # Condition on the prefix; the last 160 tokens are the targets to generate.
            input_ids = torch.tensor(sequence[:-160], dtype=torch.long)
            labels = torch.tensor(sequence[-160:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "test":
            # Test sequences are used as-is; no labels are available.
            input_ids = torch.tensor(sequence, dtype=torch.long)
            return input_ids
        raise ValueError(f"Invalid mode: {self.mode}. Choose from 'train', 'dev', or 'test'.")
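The three modes differ only in how a single sequence is sliced. The sketch below mirrors that slicing with plain Python lists so the resulting lengths are easy to check. `split_for_mode` is a hypothetical helper, and the sequence length of 400 is an assumed value chosen for illustration; only the 160-token dev suffix comes from the class above.

```python
from typing import List, Optional, Tuple

DEV_SUFFIX = 160  # trailing tokens held out as labels in "dev" mode

def split_for_mode(sequence: List[int], mode: str) -> Tuple[List[int], Optional[List[int]]]:
    """Mirror PixelSequenceDataset.__getitem__ without converting to tensors."""
    if mode == "train":
        # Predict every token from all tokens before it.
        return sequence[:-1], sequence[1:]
    if mode == "dev":
        # Condition on the prefix; the last 160 tokens must be generated.
        return sequence[:-DEV_SUFFIX], sequence[-DEV_SUFFIX:]
    if mode == "test":
        # Used as-is; no labels.
        return sequence, None
    raise ValueError(f"Invalid mode: {mode}. Choose from 'train', 'dev', or 'test'.")

seq = list(range(400))  # assumed length, for illustration only
for mode in ("train", "dev", "test"):
    inputs, labels = split_for_mode(seq, mode)
    print(mode, len(inputs), None if labels is None else len(labels))
# train 399 399
# dev 240 160
# test 400 None
```

In "train" mode the labels line up position by position with the inputs, so a single forward pass scores every position. In "dev" mode the inputs and labels have different lengths, so scoring the 160 held-out tokens requires generating them one at a time. PyTorch's default collate function stacks the samples of a batch, so batching with `DataLoader` also assumes every sequence within a split has the same length.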
Download Dataset & Prepare Dataloader
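In the cell that follows, `num_classes` is defined as `len(colormap)`, so each pixel token in `pixel_color` is presumably an index into `colormap`, and `num_classes` is the size of the model's output vocabulary. Turning a sequence back into a picture is then a table lookup followed by a reshape. The sketch below illustrates that decoding step under two assumptions not stated in this notebook: each `colormap` entry is an (R, G, B) triple, and images are square with pixels stored in row-major order. `sequence_to_rgb_rows` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple

RGB = Tuple[int, ...]

def sequence_to_rgb_rows(sequence: Sequence[int],
                         colormap: Sequence[Sequence[int]]) -> List[List[RGB]]:
    """Map color indices to RGB and reshape a flat row-major sequence into a square grid."""
    side = math.isqrt(len(sequence))
    if side * side != len(sequence):
        raise ValueError(f"Sequence length {len(sequence)} is not a perfect square")
    pixels = [tuple(colormap[i]) for i in sequence]  # index -> (R, G, B)
    return [pixels[r * side:(r + 1) * side] for r in range(side)]

toy_colormap = [(0, 0, 0), (255, 255, 255), (255, 0, 0)]  # assumed RGB entries
rows = sequence_to_rgb_rows([0, 1, 2, 1], toy_colormap)
print(rows)  # [[(0, 0, 0), (255, 255, 255)], [(255, 0, 0), (255, 255, 255)]]
```

The nested list can then be converted with `np.array(rows, dtype=np.uint8)` and passed to `Image.fromarray` (NumPy and PIL are both imported above) for display or saving.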
# Load the pokemon dataset from Hugging Face Hub
pokemon_dataset = load_dataset("lca0503/ml2025-hw4-pokemon")
# Load the colormap from Hugging Face Hub
colormap = list(load_dataset("lca0503/ml2025-hw4-colormap")["train"]["color"])
# Define number of classes
num_classes = len(colormap)
# Define batch size
batch_size = 16
# === Prepare Dataset and DataLoader for Training ===
train_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["train"]["pixel_color"], mode="train"
)
train_dataloader: DataLoader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
# === Prepare Dataset and DataLoader for Validation ===
dev_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["dev"]["pixel_color"], mode="dev"
)
dev_dataloader: DataLoader = DataLoader(
dev_dataset, batch_size=batch_size, shuffle=False
)
# === Prepare Dataset and DataLoader for Testing ===
test_dataset: PixelSequenceDataset = PixelSequenceDataset(
pokemon_dataset["test"][