前言
在AI驱动开发的时代,提示词(Prompt)是连接人类需求与AI能力的核心桥梁。尤其对于爬虫+神经网络这类技术密集型项目,优秀的提示词能让AI精准输出可用代码、高效解决调试问题,甚至缩短50%以上的开发周期。本文将先拆解优秀提示词的核心条件,再以“图片爬虫+ResNet18分类模型”项目为案例,复盘从初始需求到最终落地的全流程提示词设计与迭代,带你掌握技术类项目的提示词工程技巧。
一、优秀提示词的核心条件(技术开发场景专属)
结合提示词工程理论与实操经验,优秀的技术类提示词需满足「五大核心条件」,既兼顾通用逻辑,又适配代码开发、调试优化的特殊需求,具体如下:
1. 目标明确性:精准锁定“做什么+要什么”
拒绝模糊表述,需明确任务边界、技术目标与输出要求,让AI无需猜测。核心是“不遗漏关键约束”,比如技术栈、功能范围、输出格式。
- 反例:“帮我写个爬虫爬图片,再做个分类模型”(无技术栈、无目标对象、无输出标准);
- 正例:“用Python+Selenium(Edge浏览器)爬取百度图片中‘少年带土’前20张图片,按人物名称创建文件夹存储,代码需包含异常处理和详细注释”。
实操要点:技术场景中需额外明确“输出物形态”(如可运行代码、模块拆分、注释要求),避免AI输出伪代码或残缺逻辑。
2. 场景具象化:绑定“角色+环境+约束”
给AI设定清晰角色与开发环境,同时明确技术约束(必用/禁用方案),让输出更贴合实际开发场景。技术项目中,“环境前提”和“技术限制”是具象化的核心。
- 示例:“作为资深PyTorch工程师,基于我已整理的火影人物图片数据集(6类,每类20张,按文件夹分类),用ResNet18迁移学习实现分类模型,要求不使用TensorFlow,适配Python3.7版本,显存限制8G”。
实操要点:角色设定需匹配任务难度(如“爬虫工程师”“神经网络专家”),环境信息需包含版本、硬件限制、数据集情况,技术约束需明确“必选工具”和“排除项”。
3. 逻辑结构性:拆解“步骤+优先级”
复杂项目需按开发流程拆分需求,明确步骤顺序与优先级,帮助AI建立清晰的执行框架。尤其适合“爬虫→数据处理→模型开发→优化”这类多环节任务。
- 示例:“分两步实现项目:第一步,开发百度图片爬虫,需支持多人物批量爬取、图片去重;第二步,基于PyTorch搭建分类模型,优先保证代码可运行,再优化准确率。每步代码需独立成模块,附带调用说明”。
实操要点:技术项目可按“流程阶段”或“功能模块”拆分,同时标注优先级(如“先保证运行,再优化性能”),避免AI本末倒置。
4. 细节完整性:补充“参数+反馈机制”
技术开发的卡壳多源于细节缺失,优秀提示词需提供关键参数(如路径、接口信息),并预留动态反馈通道,方便迭代优化。
- 示例:“爬虫代码中,EdgeDriver路径为r"C:\Users\86139\AppData\Local\Programs\Python\Python37\Scripts\msedgedriver.exe",爬取关键词包括‘少年带土’‘鹰佐’等6个角色;若运行报错,先分析错误原因,再提供修改后的完整代码”。
实操要点:关键参数(路径、版本、硬件参数)需精准提供,同时明确“错误反馈方式”,让AI能基于报错信息定向修正。
5. 动态可调性:预留“迭代优化空间”
首次提示词无需追求完美,需设计“追加指令”的接口,根据AI输出结果微调需求,尤其适合调试阶段。
- 示例:“先按上述需求开发爬虫代码,运行后若出现‘提取不到图片链接’的问题,再优化定位逻辑,优先使用XPath定位,不依赖类名”。
实操要点:技术项目中可预设常见问题场景,明确优化方向,减少反复沟通成本。
二、项目实战:从提示词到完整落地(爬虫+ResNet18分类)
以下以“火影人物图片爬虫+PyTorch分类模型”项目为核心,按开发流程复盘提示词设计、AI响应、迭代优化的全过程,所有提示词均为实战原版,附设计思路与效果分析。
阶段1:需求初始化——从模糊到明确的首次提示词
用户初始提示词
“帮我爬取百度图片中6个火影人物的图片,每个人物前20张,分别存入对应名称的文件夹;之后用卷积神经网络实现分类模型,能预测单张图片的人物类别。用Edge浏览器,Python语言。”
提示词分析
- 优点:明确了核心任务(爬虫+分类)、工具(Edge+Python)、基础需求(6类人物、前20张);
- 不足:缺少关键细节(驱动路径、图片定位方式、神经网络框架、数据集结构),易导致AI输出通用代码,需二次调试。
豆包响应与迭代
豆包首次输出基础爬虫代码(Selenium+Edge)和简单CNN模型,但未包含具体路径配置、图片定位逻辑模糊。
用户优化提示词(补充细节+结构):“基于之前的需求,补充以下信息:1. EdgeDriver路径:r"C:\Users\86139\AppData\Local\Programs\Python\Python37\Scripts\msedgedriver.exe";2. 爬取关键词:少年带土、白面具、破面带土、青年水门、鹰佐、九喇嘛鸣人;3. 神经网络用PyTorch框架,优先考虑迁移学习(如ResNet18);4. 分两步输出代码:先爬虫,再分类模型,每步代码需可直接运行,包含异常处理和注释。”
阶段2:爬虫开发与调试——用精准提示词解决报错
问题场景1:爬虫运行报错“未找到图片容器元素”
用户提示词
“运行你给的爬虫代码后报错:Exception: ❌ 未找到任何图片容器元素!错误位置在第74行。我用Edge浏览器检查了百度图片页面,img标签的class是img_7rSL,src属性可直接打开完整图片。请基于这个信息修改代码,优先用src属性提取图片链接,定位逻辑改用XPath,避免依赖易变的类名,确保能爬取到图片。”
提示词设计思路
- 满足“细节完整性”:提供报错信息、页面元素检查结果(class、src属性);
- 满足“动态可调性”:明确优化方向(XPath定位、用src属性),直指问题核心。
豆包响应与效果
豆包修改代码,用XPath定位“src属性有效”的img标签,补充页面滚动逻辑触发懒加载,代码运行成功,顺利爬取6类人物各20张图片。
用户二次优化提示词(交互逻辑定制)
“代码已成功运行,帮我修改最后的单张图片预测部分,实现循环输入图片路径预测,输入‘q’可退出,若路径不存在提示错误信息,保留置信度输出。”
阶段3:神经网络模型开发——技术选型与代码落地
用户提示词
“我已手动筛选完图片,数据集按类别分文件夹存储(路径:D:\无畏时刻\机器学习\vibe coding火影人物分类\人物分类)。作为擅长神经网络的专家,用PyTorch+ResNet18迁移学习实现分类模型,要求:1. 自动划分训练集/验证集(比例2:8);2. 包含数据增强、早停逻辑,缓解过拟合;3. 训练完成后可视化损失和准确率曲线;4. 保留之前的循环预测逻辑,输入q退出。不使用混淆矩阵分析,代码可直接运行。”
提示词设计思路
- 满足“角色具象化”:设定“神经网络专家”角色,明确技术栈(PyTorch+ResNet18);
- 满足“逻辑结构性”:拆分数据处理、模型构建、训练可视化、预测交互四大需求;
- 满足“技术约束”:排除混淆矩阵,明确数据集路径和拆分比例。
豆包响应与效果
输出完整模块化代码,包含数据增强、ResNet18微调(解冻layer3/layer4)、早停逻辑、训练曲线可视化,代码一次运行成功,初始验证准确率30%。
阶段4:模型优化——针对性提示词提升准确率
用户提示词
“模型运行成功,但验证准确率只有30%,请分析原因并给出优化方案,基于原代码修改:1. 强化数据增强(增加裁剪、高斯模糊等);2. 优化ResNet18微调策略,调整学习率(分层设置);3. 保持循环预测逻辑和早停机制,不使用混淆矩阵;4. 每个类别我会补充到30-50张图片,代码需适配更大数据集。”
提示词设计思路
- 满足“目标明确性”:核心目标是提升准确率,明确3个优化方向;
- 满足“动态可调性”:结合数据补充计划,让代码适配后续数据集扩容。
豆包响应与效果
修改代码:强化数据增强(随机裁剪、高斯模糊、颜色抖动)、用AdamW分层设置学习率(卷积层1e-5,分类头1e-4)、提升Dropout比例至0.6,优化后准确率提升至78%。
阶段5:最终迭代——完善交互与稳定性
用户提示词
“代码已优化,帮我做最后调整:1. 修复循环预测中‘q’退出无效的问题;2. 补充模型保存路径的注释,明确最优模型保存逻辑;3. 训练日志中增加早停计数器显示,方便监控训练进度;4. 确保所有路径适配Windows系统,代码可直接运行无需修改。”
最终效果
代码完全适配需求,支持批量爬取、模型训练、循环预测全流程,验证准确率稳定在78%-82%,交互逻辑流畅,可直接用于火影人物图片分类任务。
三、提示词工程复盘与可复用模板
1. 项目提示词核心经验
- 技术项目中,“报错信息+页面/数据细节”是调试类提示词的黄金组合,能让AI跳过猜测,直接定位问题;
- 迁移学习、爬虫这类专业任务,需在提示词中“锁定技术细节”(如ResNet18微调策略、XPath定位规则),避免AI输出通用方案;
- 迭代式提示词更高效:首次搭框架,后续按“报错→补充细节→优化需求”的节奏微调,比一次性写长篇提示词更精准。
2. 技术类项目提示词可复用模板
【环境前提】已安装:Python3.7、PyTorch2.x、Edge浏览器;已配置:EdgeDriver路径XXX;数据集情况:XXX(类别数、数量、存储结构);硬件限制:XXX(显存、CPU)
【核心需求】XXX(如“爬虫+ResNet18分类”“批量爬取+单张预测”)
【技术约束】必须用:XXX(工具/框架/策略,如Edge+Selenium、ResNet18迁移学习);禁止用:XXX(如TensorFlow、混淆矩阵)
【步骤与优先级】1. 第一步XXX(如爬虫开发),优先保证XXX(如图片提取成功);2. 第二步XXX(如模型训练),优先保证XXX(如可运行性)
【关键细节】XXX(如路径、页面元素属性、学习率设置)
【输出要求】代码需模块化、带注释、可直接运行;包含XXX(如异常处理、可视化、交互逻辑);报错后提供修改后的完整代码
四,代码与运行结果展示
1、文件夹结构:
2、代码:
爬虫:
import os import time import requests import traceback from selenium import webdriver from selenium.webdriver.edge.service import Service from selenium.webdriver.common.by import By from selenium.webdriver.common.keys import Keys from selenium.webdriver.support.wait import WebDriverWait from selenium.webdriver.support import expected_conditions as EC from PIL import Image # 需安装:pip install pillow import io # -------------------------- 核心配置(路径隔离关键) -------------------------- IMG_COUNT = 50 EDGE_DRIVER_PATH = r"C:\Users\86139\AppData\Local\Programs\Python\Python37\Scripts\msedgedriver.exe" TIMEOUT = 25 # 图片根目录:代码在「代码」文件夹,图片保存到上级的「人物分类」文件夹 IMG_ROOT_DIR = "../人物分类" # 关键:所有图片都存在这里,和代码隔离 HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0", "Referer": "https://image.baidu.com/", "Accept": "image/webp,image/png,image/jpeg,*/*;q=0.8" } # -------------------------- 核心函数 -------------------------- def init_edge_browser(): """初始化Edge浏览器,规避反爬""" if not os.path.exists(EDGE_DRIVER_PATH): raise FileNotFoundError(f"EdgeDriver文件不存在!路径:{EDGE_DRIVER_PATH}") options = webdriver.EdgeOptions() options.add_experimental_option("excludeSwitches", ["enable-automation", "enable-logging"]) options.add_experimental_option("useAutomationExtension", False) options.add_argument("--disable-blink-features=AutomationControlled") options.add_argument("--disable-web-security") options.add_argument("--ignore-certificate-errors") options.add_argument(f"--user-agent={HEADERS['User-Agent']}") options.add_experimental_option("prefs", {"profile.managed_default_content_settings.images": 2}) service = Service(executable_path=EDGE_DRIVER_PATH) browser = webdriver.Edge(service=service, options=options) browser.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})") browser.maximize_window() browser.set_page_load_timeout(TIMEOUT) return browser def create_folder(keyword): """创建人物命名的文件夹(保存在「人物分类」下)""" # 拼接路径:../人物分类/关键词(比如 ../人物分类/少年带土) folder_path = os.path.join(IMG_ROOT_DIR, keyword) if not os.path.exists(folder_path): os.makedirs(folder_path) print(f"📁 已在「人物分类」创建文件夹:{folder_path}") return folder_path def get_img_urls(browser, keyword, count): """提取图片链接(跳过第一张)""" img_urls = [] try: browser.get("https://image.baidu.com/search") time.sleep(3) search_box = WebDriverWait(browser, TIMEOUT).until( EC.element_to_be_clickable((By.ID, "kw")) ) search_box.clear() search_box.send_keys(keyword) search_box.send_keys(Keys.ENTER) print(f"🔍 搜索「{keyword}」完成,等待结果加载...") time.sleep(5) for _ in range(3): browser.execute_script("window.scrollBy(0, 800);") time.sleep(2) img_tags = WebDriverWait(browser, TIMEOUT).until( EC.presence_of_all_elements_located( (By.XPATH, "//img[starts-with(@src, 'http') and not(contains(@src, 'ss0.bdstatic.com'))]") ) ) if not img_tags: print(f"❌ 「{keyword}」未找到任何有效图片标签") return img_urls # 跳过第一张,取指定数量 target_img_tags = img_tags[1:count + 1] for img in target_img_tags: try: src_url = img.get_attribute("src") if src_url: img_urls.append(src_url) except Exception as e: print(f"⚠️ 提取单张图片链接失败:{e}") continue print(f"✅ 「{keyword}」提取到 {len(img_urls)} 张有效图片链接(已跳过第一张)") except Exception as e: print(f"❌ 提取「{keyword}」链接失败:{str(e)}") traceback.print_exc() return img_urls def download_img(img_url, save_path): """下载图片,强制保存为标准RGB格式JPG""" try: response = requests.get( img_url, headers=HEADERS, timeout=15, stream=True, allow_redirects=True ) response.raise_for_status() # 验证1:是否为图片类型 content_type = response.headers.get("Content-Type", "") if "image" not in content_type: print(f"❌ 非图片链接,跳过:{img_url[:50]}...") return # 验证2:文件大小(过滤小于1KB的无效文件) img_data = response.content if len(img_data) < 1024: print(f"❌ 文件过小,跳过:{img_url[:50]}...(大小:{len(img_data)}字节)") return # 核心优化:用PIL解析并重新编码为标准JPG try: # 打开图片 img = Image.open(io.BytesIO(img_data)) # 转换为RGB格式(JPG不支持透明通道,避免编码异常) img_rgb = img.convert("RGB") # 保存为标准JPG(覆盖原后缀的异常文件) img_rgb.save(save_path, "JPEG", quality=95, optimize=True) print(f"✅ 下载并标准化成功:{save_path}") except Exception as e: print(f"❌ 图片编码异常,跳过:{img_url[:50]}...(错误:{e})") return except Exception as e: print(f"❌ 下载失败 {img_url[:50]}...:{str(e)}") def crawl_baidu_images(): """主爬虫函数""" browser = None try: print("🚀 百度图片爬虫启动(路径隔离版)...") # 先确保「人物分类」根文件夹存在 if not os.path.exists(IMG_ROOT_DIR): os.makedirs(IMG_ROOT_DIR) print(f"📁 已创建图片根目录:{IMG_ROOT_DIR}") input_str = input("请输入需要爬取的目标关键词(多个关键词用空格分隔):") PERSONS = [keyword.strip() for keyword in input_str.split() if keyword.strip()] if not PERSONS: print("❌ 未输入有效关键词,爬虫退出!") return browser = init_edge_browser() for person in PERSONS: print(f"\n========== 开始爬取:{person} ==========") folder = create_folder(person) # 文件夹在「人物分类」下 img_urls = get_img_urls(browser, person, IMG_COUNT) if not img_urls: print(f"⚠️ 「{person}」无有效图片链接,跳过下载") continue for idx, img_url in enumerate(img_urls, 1): # 保存路径:../人物分类/关键词/关键词_1.jpg save_path = os.path.join(folder, f"{person}_{idx}.jpg") download_img(img_url, save_path) time.sleep(1) print("\n🎉 所有关键词图片爬取完成!图片均保存在「人物分类」文件夹下") except FileNotFoundError as e: print(f"❌ 爬虫启动失败:{e}") except Exception as e: print(f"\n❌ 爬虫执行异常:{str(e)}") traceback.print_exc() finally: if browser: browser.quit() print("\n🔚 浏览器已关闭,爬虫程序结束!") # -------------------------- 执行入口 -------------------------- if __name__ == "__main__": crawl_baidu_images()训练,预测模型:
# '''调用爬虫''' # # 分类代码顶部的导入区新增 # import pachong # 导入爬虫脚本(pachong.py) # import sys # # """调用爬虫脚本的核心函数爬取图片""" # print("\n🐜 开始执行爬虫爬取图片...") # try: # # 直接调用pachong.py中的核心爬虫函数 # pachong.crawl_baidu_images() # print("\n✅ 爬虫爬取完成,图片已保存至「人物分类」文件夹") # except Exception as e: # print(f"\n❌ 爬虫执行失败:{str(e)}") # # 爬虫失败时可选是否继续执行分类 # confirm = input("是否跳过爬虫,继续执行分类训练?(y/n):") # if confirm.lower() != "y": # sys.exit(1) # 退出程序 import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, Dataset from torchvision import datasets, transforms, models from torchvision.models import ResNet18_Weights import matplotlib.pyplot as plt import numpy as np from tqdm import tqdm import warnings from PIL import Image from imblearn.over_sampling import RandomOverSampler import random warnings.filterwarnings("ignore") # ======================== 1. 核心配置(路径隔离关键:图片路径指向「人物分类」,代码在「代码」文件夹) ======================== # 图片目录:代码在「代码」文件夹,上一级是「vibe coding火影人物分类」,再进入「人物分类」(相对路径) DATA_DIR = "../人物分类" # 关键:只操作这个文件夹里的图片,和代码文件隔离 IMAGE_SIZE = (224, 224) BATCH_SIZE = 8 EPOCHS = 30 LEARNING_RATE = 1e-4 VAL_SPLIT = 0.2 SAVE_MODEL = True MODEL_SAVE_PATH = "./naruto_classifier_resnet18_data_opt.pth" # 模型保存在「代码」文件夹内 CUTMIX_PROB = 0.5 CUTMIX_BETA = 1.0 # ======================== 2. 设备配置(不变) ======================== def get_device(): if torch.cuda.is_available(): device = torch.device("cuda:0") print(f"✅ 使用GPU训练:{torch.cuda.get_device_name(0)}") elif torch.backends.mps.is_available(): device = torch.device("mps") print("✅ 使用MPS训练(Mac M系列芯片)") else: device = torch.device("cpu") print("⚠️ 未检测到GPU,使用CPU训练(速度较慢)") return device device = get_device() # ======================== 3. 数据清洗(仅操作「人物分类」下的图片,不会碰代码文件) ======================== def clean_invalid_images(data_dir): """仅清洗「人物分类」文件夹内的图片,彻底隔离代码文件""" print("\n🧹 开始清洗「人物分类」下的无效图片...") invalid_count = 0 # 只遍历「人物分类」下的子文件夹(图片类别) for root, _, files in os.walk(data_dir): # 跳过非图片文件(只处理.jpg/.png) for file in files: if not file.lower().endswith((".jpg", ".png", ".jpeg")): continue # 不是图片格式,直接跳过(避免误删其他文件) file_path = os.path.join(root, file) try: img = Image.open(file_path).convert("RGB") img.verify() if img.size[0] < 64 or img.size[1] < 64: os.remove(file_path) invalid_count += 1 print(f"删除过小图片:{file_path}") except: os.remove(file_path) invalid_count += 1 print(f"删除损坏图片:{file_path}") print(f"✅ 数据清洗完成,共删除{invalid_count}张无效图片") # ======================== 4. 其他数据函数(路径均限定在DATA_DIR) ======================== def calculate_dataset_stats(data_dir): print("\n📊 计算「人物分类」数据集的RGB均值和标准差...") transform = transforms.Compose([ transforms.Resize(IMAGE_SIZE), transforms.ToTensor() ]) dataset = datasets.ImageFolder(root=data_dir, transform=transform) loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=0) mean = torch.zeros(3) std = torch.zeros(3) total_images = 0 for images, _ in tqdm(loader, desc="计算统计量"): batch_size = images.size(0) images = images.view(batch_size, 3, -1) mean += images.mean(2).sum(0) std += images.std(2).sum(0) total_images += batch_size mean /= total_images std /= total_images print(f"✅ 自定义归一化参数:") print(f"均值:{mean.numpy().round(4)} | 标准差:{std.numpy().round(4)}") return mean.numpy(), std.numpy() def cutmix(data, targets, alpha=1.0): if alpha > 0: lam = np.random.beta(alpha, alpha) else: lam = 1 batch_size = data.size(0) index = torch.randperm(batch_size).to(device) bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam) data[:, :, bbx1:bbx2, bby1:bby2] = data[index, :, bbx1:bbx2, bby1:bby2] lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2])) targets_a, targets_b = targets, targets[index] return data, targets_a, targets_b, lam def rand_bbox(size, lam): W = size[2] H = size[3] cut_rat = np.sqrt(1. - lam) cut_w = int(W * cut_rat) cut_h = int(H * cut_rat) cx = np.random.randint(W) cy = np.random.randint(H) bbx1 = np.clip(cx - cut_w // 2, 0, W) bby1 = np.clip(cy - cut_h // 2, 0, H) bbx2 = np.clip(cx + cut_w // 2, 0, W) bby2 = np.clip(cy + cut_h // 2, 0, H) return bbx1, bby1, bbx2, bby2 class BalancedDataset(Dataset): def __init__(self, dataset): self.dataset = dataset self.classes = dataset.dataset.classes self.class_to_idx = dataset.dataset.class_to_idx self.samples = [] self.labels = [] for idx in range(len(dataset)): img_path, label = dataset.dataset.samples[dataset.indices[idx]] self.samples.append((img_path, label)) self.labels.append(label) ros = RandomOverSampler(random_state=42) X_resampled, y_resampled = ros.fit_resample(np.arange(len(self.samples)).reshape(-1, 1), self.labels) self.resampled_indices = X_resampled.flatten() print(f"\n⚖️ 类别均衡处理:") print(f"原样本数:{len(self.samples)} | 均衡后样本数:{len(self.resampled_indices)}") for cls in self.classes: cls_idx = self.class_to_idx[cls] count = np.sum(np.array(y_resampled) == cls_idx) print(f"{cls}:{count}张") def __len__(self): return len(self.resampled_indices) def __getitem__(self, idx): orig_idx = self.resampled_indices[idx] img_path, label = self.samples[orig_idx] img = Image.open(img_path).convert("RGB") img = self.dataset.dataset.transform(img) return img, label # ======================== 5. 数据加载(路径隔离) ======================== def create_data_loaders(): clean_invalid_images(DATA_DIR) # 只清洗「人物分类」的图片 mean, std = calculate_dataset_stats(DATA_DIR) # 只统计「人物分类」的图片 train_transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomCrop(IMAGE_SIZE), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(20), transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), transforms.RandomGrayscale(p=0.1), transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ]) val_transform = transforms.Compose([ transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ]) full_dataset = datasets.ImageFolder(root=DATA_DIR, transform=train_transform) dataset_size = len(full_dataset) val_size = int(VAL_SPLIT * dataset_size) train_size = dataset_size - val_size train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size]) val_dataset.dataset.transform = val_transform train_dataset_balanced = BalancedDataset(train_dataset) train_loader = DataLoader( train_dataset_balanced, batch_size=BATCH_SIZE, shuffle=True, num_workers=0 ) val_loader = DataLoader( val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0 ) print(f"\n📊 最终数据集信息:") print(f"原始总样本数:{dataset_size} | 训练集(均衡后):{len(train_dataset_balanced)} | 验证集:{len(val_dataset)}") print(f"类别数:{len(full_dataset.classes)} | 类别映射:{full_dataset.class_to_idx}") return train_loader, val_loader, full_dataset.classes, full_dataset.class_to_idx, mean, std # ======================== 6. 模型/训练/可视化/预测(路径均隔离) ======================== def build_model(num_classes): model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1) for name, param in model.named_parameters(): if "layer1" in name or "layer2" in name or "conv1" in name or "bn1" in name: param.requires_grad = False else: param.requires_grad = True in_features = model.fc.in_features model.fc = nn.Sequential(nn.Dropout(0.6), nn.Linear(in_features, num_classes)) model = model.to(device) return model def train_model(model, train_loader, val_loader, criterion, optimizer, epochs, mean, std): train_losses = [] val_losses = [] train_accs = [] val_accs = [] best_val_acc = 0.0 best_model_weights = None early_stop_count = 0 early_stop_patience = 5 print("\n🚀 开始训练(路径隔离版):") for epoch in range(epochs): model.train() train_loss = 0.0 train_correct = 0 train_total = 0 pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [训练]") for images, labels in pbar: images, labels = images.to(device), labels.to(device) if random.random() < CUTMIX_PROB: images, labels_a, labels_b, lam = cutmix(images, labels, CUTMIX_BETA) optimizer.zero_grad() outputs = model(images) loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b) loss.backward() optimizer.step() _, predicted = torch.max(outputs.data, 1) train_total += labels.size(0) train_correct += (lam * predicted.eq(labels_a.data).cpu().sum() + (1 - lam) * predicted.eq( labels_b.data).cpu().sum()).item() else: optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() train_loss += loss.item() * images.size(0) _, predicted = torch.max(outputs.data, 1) train_total += labels.size(0) train_correct += (predicted == labels).sum().item() pbar.set_postfix({"loss": loss.item(), "acc": train_correct / train_total}) epoch_train_loss = train_loss / len(train_loader.dataset) if train_loss > 0 else 0.0 epoch_train_acc = train_correct / train_total train_losses.append(epoch_train_loss) train_accs.append(epoch_train_acc) model.eval() val_loss = 0.0 val_correct = 0 val_total = 0 with torch.no_grad(): for images, labels in val_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) loss = criterion(outputs, labels) val_loss += loss.item() * images.size(0) _, predicted = torch.max(outputs.data, 1) val_total += labels.size(0) val_correct += (predicted == labels).sum().item() epoch_val_loss = val_loss / len(val_loader.dataset) epoch_val_acc = val_correct / val_total val_losses.append(epoch_val_loss) val_accs.append(epoch_val_acc) if epoch_val_acc > best_val_acc: best_val_acc = epoch_val_acc best_model_weights = model.state_dict() early_stop_count = 0 else: early_stop_count += 1 print(f"\nEpoch {epoch + 1} 总结:") print(f"训练损失:{epoch_train_loss:.4f} | 训练准确率:{epoch_train_acc:.4f}") print(f"验证损失:{epoch_val_loss:.4f} | 验证准确率:{epoch_val_acc:.4f}") print(f"当前最优验证准确率:{best_val_acc:.4f}") if early_stop_count > 0: print(f"早停计数器:{early_stop_count}/{early_stop_patience}") print("-" * 50) if early_stop_count >= early_stop_patience: print(f"\n⚠️ 验证集准确率连续{early_stop_patience}轮未提升,触发早停") break model.load_state_dict(best_model_weights) if SAVE_MODEL: torch.save({ 'model_state_dict': model.state_dict(), 'mean': mean, 'std': std, 'class_to_idx': train_loader.dataset.class_to_idx }, MODEL_SAVE_PATH) print(f"✅ 最优模型已保存至「代码」文件夹:{MODEL_SAVE_PATH}") return model, train_losses, val_losses, train_accs, val_accs def plot_training_curves(train_losses, val_losses, train_accs, val_accs): plt.figure(figsize=(12, 4)) plt.subplot(1, 2, 1) plt.plot(train_losses, label="训练损失", color="blue") plt.plot(val_losses, label="验证损失", color="red") plt.title("训练/验证损失曲线(路径隔离版)") plt.xlabel("Epoch") plt.ylabel("Loss") plt.legend() plt.grid(True) plt.subplot(1, 2, 2) plt.plot(train_accs, label="训练准确率", color="blue") plt.plot(val_accs, label="验证准确率", color="red") plt.title("训练/验证准确率曲线(路径隔离版)") plt.xlabel("Epoch") plt.ylabel("Accuracy") plt.legend() plt.grid(True) plt.tight_layout() plt.savefig("./training_curves_data_opt.png") # 曲线保存在「代码」文件夹 plt.show() print("📈 训练曲线已保存至「代码」文件夹:training_curves_data_opt.png") def predict_single_image(model, image_path, class_names, class_to_idx, mean, std): idx_to_class = {v: k for k, v in class_to_idx.items()} transform = transforms.Compose([ transforms.Resize(IMAGE_SIZE), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ]) try: image = Image.open(image_path).convert("RGB") image_tensor = transform(image).unsqueeze(0) image_tensor = image_tensor.to(device) model.eval() with torch.no_grad(): outputs = model(image_tensor) probabilities = torch.nn.functional.softmax(outputs, dim=1) _, predicted_idx = torch.max(probabilities, 1) predicted_class = idx_to_class[predicted_idx.item()] confidence = probabilities[0][predicted_idx.item()].item() print(f"\n🔍 预测结果:") print(f"图片路径:{image_path}") print(f"预测类别:{predicted_class} | 置信度:{confidence:.4f}") return predicted_class, confidence except Exception as e: print(f"\n❌ 预测失败:{str(e)}") return None, None # ======================== 7. 主函数 ======================== if __name__ == "__main__": train_loader, val_loader, class_names, class_to_idx, mean, std = create_data_loaders() num_classes = len(class_names) model = build_model(num_classes) criterion = nn.CrossEntropyLoss() optimizer = optim.AdamW( [ {"params": model.layer3.parameters(), "lr": 1e-5}, {"params": model.layer4.parameters(), "lr": 1e-5}, {"params": model.fc.parameters(), "lr": LEARNING_RATE} ], weight_decay=1e-4 ) model, train_losses, val_losses, train_accs, val_accs = train_model( model, train_loader, val_loader, criterion, optimizer, EPOCHS, mean, std ) plot_training_curves(train_losses, val_losses, train_accs, val_accs) print("\n========== 开始预测 ==========") while True: test_image_path = input("\n请输入测试图片路径(输入q退出):") if test_image_path.lower() == "q": print("👋 退出预测") break if os.path.exists(test_image_path): predict_single_image(model, test_image_path, class_names, class_to_idx, mean, std) else: print(f"\n⚠️ 测试图片不存在:{test_image_path},请重新输入")3、运行结果
结语
提示词工程不是“堆砌需求”,而是“精准传递意图”的艺术。对于爬虫+神经网络这类技术项目,优秀的提示词需兼顾“理论框架”与“实操细节”,既符合目标明确、逻辑清晰的通用原则,又能适配技术开发的特殊性(如报错调试、参数配置、版本兼容)。
通过本文的案例复盘可见,从模糊需求到完整项目,提示词的迭代过程也是需求逐步清晰、问题逐个解决的过程。掌握本文的提示词条件与模板,能让你在AI驱动开发中更高效地落地技术项目,将更多精力放在核心逻辑设计上,而非重复调试与沟通。