In-Depth Walkthrough: Hands-On Notes for Hung-yi Lee's Spring 2025 Machine Learning Assignment ML2025_Spring_HW4 on Kaggle

2025/11/9 15:04:21 / Source: https://www.cnblogs.com/yxysuanfa/p/19204296

Training Transformer

TA’s Slide

Slide

Description

In this assignment, we are tasked with utilizing a transformer decoder-only architecture for pretraining, with a focus on next-token prediction, applied to Pokémon images.
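To make the next-token prediction objective concrete, here is a minimal sketch in plain Python. It computes the average negative log-likelihood of each token given its prefix. The `predict` callable, the uniform toy model, and the four-class vocabulary are hypothetical stand-ins for illustration, not the assignment's GPT-2 model.

```python
import math
from typing import Callable, List, Sequence


def next_token_nll(
    sequence: Sequence[int],
    predict: Callable[[Sequence[int]], List[float]],
) -> float:
    """Average negative log-likelihood of each token given the tokens before it."""
    losses = []
    for t in range(1, len(sequence)):
        probs = predict(sequence[:t])  # distribution over the next token
        losses.append(-math.log(probs[sequence[t]]))
    return sum(losses) / len(losses)


# Hypothetical stand-in for a trained model: uniform over 4 color classes.
NUM_CLASSES = 4


def uniform_model(prefix: Sequence[int]) -> List[float]:
    return [1.0 / NUM_CLASSES] * NUM_CLASSES


toy_pixels = [0, 3, 3, 1, 2]  # made-up pixel color indices
loss = next_token_nll(toy_pixels, uniform_model)
print(round(loss, 4))  # 1.3863, i.e. ln(4) for a uniform guess over 4 classes
```

A model that has learned anything about the data should score below this uniform baseline of ln(num_classes).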

Please feel free to mail us if you have any questions.

ntu-ml-2025-spring-ta@googlegroups.com

Utilities

Download packages

!pip install datasets==3.3.2
Collecting datasets==3.3.2
  Using cached datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Requirement already satisfied: filelock in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.17.0)
Requirement already satisfied: numpy>=1.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.0.1)
Requirement already satisfied: pyarrow>=15.0.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (21.0.0)
Collecting dill<0.3.9,>=0.3.0 (from datasets==3.3.2)
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Requirement already satisfied: pandas in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.3.1)
Requirement already satisfied: requests>=2.32.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (2.32.5)
Requirement already satisfied: tqdm>=4.66.3 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (4.67.1)
Requirement already satisfied: xxhash in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.6.0)
Requirement already satisfied: multiprocess<0.70.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.70.16)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets==3.3.2)
  Using cached fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Requirement already satisfied: aiohttp in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (3.13.0)
Requirement already satisfied: huggingface-hub>=0.24.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (0.35.3)
Requirement already satisfied: packaging in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (25.0)
Requirement already satisfied: pyyaml>=5.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from datasets==3.3.2) (6.0.2)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.4.0)
Requirement already satisfied: async-timeout<6.0,>=4.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (5.0.1)
Requirement already satisfied: attrs>=17.3.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (25.3.0)
Requirement already satisfied: frozenlist>=1.1.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (0.4.0)
Requirement already satisfied: yarl<2.0,>=1.17.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from aiohttp->datasets==3.3.2) (1.22.0)
Requirement already satisfied: typing-extensions>=4.1.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from multidict<7.0,>=4.5->aiohttp->datasets==3.3.2) (4.15.0)
Requirement already satisfied: idna>=2.0 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from yarl<2.0,>=1.17.0->aiohttp->datasets==3.3.2) (3.7)
Requirement already satisfied: charset_normalizer<4,>=2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (3.3.2)
Requirement already satisfied: urllib3<3,>=1.21.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from requests>=2.32.2->datasets==3.3.2) (2025.10.5)
Requirement already satisfied: colorama in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from tqdm>=4.66.3->datasets==3.3.2) (0.4.6)
Requirement already satisfied: python-dateutil>=2.8.2 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from pandas->datasets==3.3.2) (2025.2)
Requirement already satisfied: six>=1.5 in d:\anaconda3\envs\ml2025_hw4\lib\site-packages (from python-dateutil>=2.8.2->pandas->datasets==3.3.2) (1.17.0)
Using cached datasets-3.3.2-py3-none-any.whl (485 kB)
Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Using cached fsspec-2024.12.0-py3-none-any.whl (183 kB)
Installing collected packages: fsspec, dill, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.9.0
    Uninstalling fsspec-2025.9.0:
      Successfully uninstalled fsspec-2025.9.0
  Attempting uninstall: dill
    Found existing installation: dill 0.4.0
    Uninstalling dill-0.4.0:
      Successfully uninstalled dill-0.4.0
  Attempting uninstall: datasets
    Found existing installation: datasets 4.1.1
    Uninstalling datasets-4.1.1:
      Successfully uninstalled datasets-4.1.1
Successfully installed datasets-3.3.2 dill-0.3.8 fsspec-2024.12.0

Import Packages

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2Config, set_seed
from datasets import load_dataset
from typing import Dict, Any, Optional

Check Devices

!nvidia-smi
Wed Oct  8 18:50:06 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.97                 Driver Version: 580.97         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090 Ti   WDDM  |   00000000:07:00.0  On |                  Off |
| 47%   42C    P8             25W /  450W |   12684MiB /  24564MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            2292    C+G   C:\Windows\System32\dwm.exe           N/A      |
|    0   N/A  N/A            5552    C+G   ...8bbwe\PhoneExperienceHost.exe      N/A      |
|    0   N/A  N/A            9928    C+G   C:\Windows\explorer.exe               N/A      |
|    0   N/A  N/A           10036    C+G   ..._cw5n1h2txyewy\SearchHost.exe      N/A      |
|    0   N/A  N/A           10264    C+G   ...y\StartMenuExperienceHost.exe      N/A      |
|    0   N/A  N/A           10632    C+G   ...ogram Files\ToDesk\ToDesk.exe      N/A      |
|    0   N/A  N/A           14304    C+G   ...xyewy\ShellExperienceHost.exe      N/A      |
|    0   N/A  N/A           15600    C+G   ...5n1h2txyewy\TextInputHost.exe      N/A      |
|    0   N/A  N/A           15812    C+G   ...ouryDevice\asus_framework.exe      N/A      |
|    0   N/A  N/A           18660    C+G   ...crosoft\OneDrive\OneDrive.exe      N/A      |
|    0   N/A  N/A           18668    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           21724    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           22748    C+G   ...s\TencentDocs\TencentDocs.exe      N/A      |
|    0   N/A  N/A           25412    C+G   ...ram Files\Tencent\QQNT\QQ.exe      N/A      |
|    0   N/A  N/A           25872    C+G   ...Chrome\Application\chrome.exe      N/A      |
|    0   N/A  N/A           26600    C+G   ...ocal\Programs\Quark\quark.exe      N/A      |
|    0   N/A  N/A           28688    C+G   ...ntrolPanel\SystemSettings.exe      N/A      |
|    0   N/A  N/A           30104    C+G   ...de\Microsoft VS Code\Code.exe      N/A      |
|    0   N/A  N/A           31500    C+G   ....0.3537.57\msedgewebview2.exe      N/A      |
|    0   N/A  N/A           39276    C+G   ...t\Edge\Application\msedge.exe      N/A      |
|    0   N/A  N/A           41696    C+G   ...PotPlayer\PotPlayerMini64.exe      N/A      |
|    0   N/A  N/A           44176    C+G   ...ffice6\promecefpluginhost.exe      N/A      |
|    0   N/A  N/A           72652      C   ...2025-Spring-Hw1\python.exe.c~      N/A      |
|    0   N/A  N/A          115660    C+G   ...ef.win7x64\steamwebhelper.exe      N/A      |
|    0   N/A  N/A          124396    C+G   ...yb3d8bbwe\WindowsTerminal.exe      N/A      |
+-----------------------------------------------------------------------------------------+

Set Random Seed

set_seed(0)
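`set_seed` from `transformers` seeds Python's `random`, NumPy, and PyTorch in one call, so repeated runs draw the same random numbers. The effect can be illustrated with the standard library alone. The helper `sample_pixels` below is hypothetical and only demonstrates the principle.

```python
import random


def sample_pixels(seed: int, n: int = 5, num_classes: int = 10) -> list:
    """Draw n pseudo-random class indices from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.randrange(num_classes) for _ in range(n)]


first = sample_pixels(0)
second = sample_pixels(0)
print(first == second)  # True: the same seed reproduces the same draws
```

Seeding alone may not make GPU training fully deterministic, but it removes most run-to-run variation from initialization and shuffling.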

Prepare Data

Define Dataset

from typing import List, Tuple, Union

import torch
from torch.utils.data import Dataset


class PixelSequenceDataset(Dataset):
    def __init__(self, data: List[List[int]], mode: str = "train"):
        """
        A dataset class for handling pixel sequences.

        Args:
            data (List[List[int]]): A list of sequences, where each sequence is a list of integers.
            mode (str): The mode of operation, either "train", "dev", or "test".
                - "train": Returns (input_ids, labels) where input_ids are sequence[:-1] and labels are sequence[1:].
                - "dev": Returns (input_ids, labels) where input_ids are sequence[:-160] and labels are sequence[-160:].
                - "test": Returns only input_ids, as labels are not available.
        """
        self.data = data
        self.mode = mode

    def __len__(self) -> int:
        """Returns the total number of sequences in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Fetches a sequence from the dataset and processes it based on the mode.

        Args:
            idx (int): The index of the sequence.

        Returns:
            - If mode == "train": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "dev": Tuple[torch.Tensor, torch.Tensor] -> (input_ids, labels)
            - If mode == "test": torch.Tensor -> input_ids
        """
        sequence = self.data[idx]
        if self.mode == "train":
            input_ids = torch.tensor(sequence[:-1], dtype=torch.long)
            labels = torch.tensor(sequence[1:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "dev":
            input_ids = torch.tensor(sequence[:-160], dtype=torch.long)
            labels = torch.tensor(sequence[-160:], dtype=torch.long)
            return input_ids, labels
        elif self.mode == "test":
            input_ids = torch.tensor(sequence, dtype=torch.long)
            return input_ids
        raise ValueError(f"Invalid mode: {self.mode}. Choose from 'train', 'dev', or 'test'.")
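The three slicing modes above can be checked without PyTorch. The sketch below mirrors the same logic on plain lists. The function name `split_sequence` and the 400-token toy sequence are assumptions made for illustration; the real sequence length is whatever the dataset provides.

```python
from typing import List, Optional, Tuple


def split_sequence(
    sequence: List[int], mode: str, n_predict: int = 160
) -> Tuple[List[int], Optional[List[int]]]:
    """Mirror PixelSequenceDataset's slicing on plain lists."""
    if mode == "train":
        # Teacher forcing: inputs and labels are the same sequence shifted by one.
        return sequence[:-1], sequence[1:]
    if mode == "dev":
        # The model sees the prefix and must generate the last n_predict tokens.
        return sequence[:-n_predict], sequence[-n_predict:]
    if mode == "test":
        return sequence, None
    raise ValueError(f"Invalid mode: {mode}")


toy = list(range(400))  # assumed length, for illustration only
train_in, train_lab = split_sequence(toy, "train")
dev_in, dev_lab = split_sequence(toy, "dev")
test_in, test_lab = split_sequence(toy, "test")

print(len(train_in), len(train_lab))  # 399 399
print(len(dev_in), len(dev_lab))      # 240 160
print(train_lab[0] == train_in[1])    # True: labels are inputs shifted left by one
```

Note the asymmetry: in "train" mode every position has a label, while in "dev" mode only the final 160 tokens are scored, which matches a generate-the-rest-of-the-image evaluation.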

Download Dataset & Prepare Dataloader

# Load the pokemon dataset from Hugging Face Hub
pokemon_dataset = load_dataset("lca0503/ml2025-hw4-pokemon")
# Load the colormap from Hugging Face Hub
colormap = list(load_dataset("lca0503/ml2025-hw4-colormap")["train"]["color"])
# Define number of classes
num_classes = len(colormap)
# Define batch size
batch_size = 16
# === Prepare Dataset and DataLoader for Training ===
train_dataset: PixelSequenceDataset = PixelSequenceDataset(
    pokemon_dataset["train"]["pixel_color"], mode="train"
)
train_dataloader: DataLoader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
# === Prepare Dataset and DataLoader for Validation ===
dev_dataset: PixelSequenceDataset = PixelSequenceDataset(
    pokemon_dataset["dev"]["pixel_color"], mode="dev"
)
dev_dataloader: DataLoader = DataLoader(
    dev_dataset, batch_size=batch_size, shuffle=False
)
# === Prepare Dataset and DataLoader for Testing ===
test_dataset: PixelSequenceDataset = PixelSequenceDataset(
    pokemon_dataset["test"][
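Because `num_classes = len(colormap)`, each value in a pixel sequence is presumably an index into the colormap, and the colormap size doubles as the model's vocabulary size. The sketch below shows how such a sequence could be decoded back into rows of colors. It assumes that colormap entries are RGB triples and that pixels are stored row by row; neither is confirmed by the code above. The helper `pixels_to_rows` and the toy palette are hypothetical.

```python
from typing import List, Sequence, Tuple

RGB = Tuple[int, int, int]


def pixels_to_rows(
    pixel_ids: Sequence[int], colormap: Sequence[RGB], width: int
) -> List[List[RGB]]:
    """Map color-class indices to RGB triples and split them into rows of `width` pixels."""
    if len(pixel_ids) % width != 0:
        raise ValueError("sequence length must be a multiple of width")
    colors = [colormap[i] for i in pixel_ids]
    return [colors[r:r + width] for r in range(0, len(colors), width)]


# Hypothetical 3-color palette and a 2x2 image; the real colormap is loaded from the Hub.
toy_colormap = [(255, 255, 255), (0, 0, 0), (255, 0, 0)]
toy_num_classes = len(toy_colormap)
rows = pixels_to_rows([0, 1, 2, 0], toy_colormap, width=2)
print(toy_num_classes)  # 3
print(rows)  # [[(255, 255, 255), (0, 0, 0)], [(255, 0, 0), (255, 255, 255)]]
```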
