基于Stable Diffusion XL模型进行文本生成图像的训练

基于Stable Diffusion XL模型进行文本生成图像的训练

flyfish

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--report_to="wandb" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model" \--push_to_hub

环境变量部分

export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
  • MODEL_NAME:指定预训练模型的名称或路径。这里使用的是 stabilityai/stable-diffusion-xl-base-1.0,也就是Stable Diffusion XL的基础版本1.0。
  • VAE_NAME:指定变分自编码器(VAE)的名称或路径。madebyollin/sdxl-vae-fp16-fix 是针对Stable Diffusion XL的一个经过修复的VAE模型,适用于半精度(FP16)计算。
  • DATASET_NAME:指定训练所使用的数据集名称或路径。这里使用的是 lambdalabs/naruto-blip-captions,是一个包含火影忍者相关图像及其描述的数据集。

accelerate launch 命令参数部分

accelerate launch train_text_to_image_sdxl.py \

这行代码使用 accelerate 工具来启动 train_text_to_image_sdxl.py 脚本,accelerate 可以帮助我们在多GPU、TPU等环境下进行分布式训练。

脚本参数部分

  • --pretrained_model_name_or_path=$MODEL_NAME:指定预训练模型的名称或路径,这里使用前面定义的 MODEL_NAME 环境变量。
  • --pretrained_vae_model_name_or_path=$VAE_NAME:指定预训练VAE模型的名称或路径,使用前面定义的 VAE_NAME 环境变量。
  • --dataset_name=$DATASET_NAME:指定训练数据集的名称或路径,使用前面定义的 DATASET_NAME 环境变量。
  • --enable_xformers_memory_efficient_attention:启用 xformers 库的内存高效注意力机制,能减少训练过程中的内存占用。
  • --resolution=512 --center_crop --random_flip
    • --resolution=512:将输入图像的分辨率统一调整为512x512像素。
    • --center_crop:对图像进行中心裁剪,使其达到指定的分辨率。
    • --random_flip:在训练过程中随机对图像进行水平翻转,以增加数据的多样性。
  • --proportion_empty_prompts=0.2:设置空提示(没有文本描述)的样本在训练数据中的比例为20%。
  • --train_batch_size=1:每个训练批次包含的样本数量为1。
  • --gradient_accumulation_steps=4 --gradient_checkpointing
    • --gradient_accumulation_steps=4:梯度累积步数为4,即每4个批次的梯度进行一次更新,这样可以在有限的内存下模拟更大的批次大小。
    • --gradient_checkpointing:启用梯度检查点机制,通过减少内存使用来支持更大的模型和批次大小。
  • --max_train_steps=10000:最大训练步数为10000步。
  • --use_8bit_adam:使用8位Adam优化器,能减少内存占用。
  • --learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
    • --learning_rate=1e-06:学习率设置为1e-6。
    • --lr_scheduler="constant":学习率调度器设置为常数,即训练过程中学习率保持不变。
    • --lr_warmup_steps=0:学习率预热步数为0,即不进行学习率预热。
  • --mixed_precision="fp16":使用半精度(FP16)混合精度训练,能减少内存使用并加快训练速度。
  • --report_to="wandb":将训练过程中的指标报告到Weights & Biases(WandB)平台,方便进行可视化和监控。
  • --validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
    • --validation_prompt="a cute Sundar Pichai creature":指定验证时使用的文本提示,这里是“一个可爱的桑达尔·皮查伊形象”。
    • --validation_epochs 5:每5个训练轮次进行一次验证。
  • --checkpointing_steps=5000:每5000步保存一次模型的检查点。
  • --output_dir="sdxl-naruto-model":指定训练好的模型的输出目录为 sdxl-naruto-model
  • --push_to_hub:将训练好的模型推送到Hugging Face模型库。

离线环境运行

# 假设已经把模型、VAE和数据集下载到本地了
# 这里假设模型在当前目录下的 sdxl-base-1.0 文件夹
# VAE 在 sdxl-vae-fp16-fix 文件夹
# 数据集在 naruto-blip-captions 文件夹# 定义本地路径
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"# 移除需要外网连接的参数
accelerate launch train_text_to_image_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME \--enable_xformers_memory_efficient_attention \--resolution=512 --center_crop --random_flip \--proportion_empty_prompts=0.2 \--train_batch_size=1 \--gradient_accumulation_steps=4 --gradient_checkpointing \--max_train_steps=10000 \--use_8bit_adam \--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \--checkpointing_steps=5000 \--output_dir="sdxl-naruto-model"

移除需要外网连接的参数:去掉 --report_to="wandb"--push_to_hub 参数,因为 wandb 需要外网连接来上传训练指标,--push_to_hub 则需要外网连接把模型推送到Hugging Face模型库。

推理

from diffusers import DiffusionPipeline
import torchmodel_path = "you-model-id-goes-here" # <-- 替换为你的模型路径
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

训练后的文件夹结构

.
├── checkpoint-10000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── checkpoint-5000
│   ├── optimizer.bin
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   ├── scheduler.bin
│   └── unet
│       ├── config.json
│       ├── diffusion_pytorch_model-00001-of-00002.safetensors
│       ├── diffusion_pytorch_model-00002-of-00002.safetensors
│       └── diffusion_pytorch_model.safetensors.index.json
├── model_index.json
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── model.safetensors
├── text_encoder_2
│   ├── config.json
│   └── model.safetensors
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── tokenizer_2
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet
│   ├── config.json
│   ├── diffusion_pytorch_model-00001-of-00002.safetensors
│   ├── diffusion_pytorch_model-00002-of-00002.safetensors
│   └── diffusion_pytorch_model.safetensors.index.json
└── vae├── config.json└── diffusion_pytorch_model.safetensors

LoRA训练

accelerate launch train_text_to_image_lora_sdxl.py \--pretrained_model_name_or_path=$MODEL_NAME \--pretrained_vae_model_name_or_path=$VAE_NAME \--dataset_name=$DATASET_NAME --caption_column="text" \--resolution=1024 --random_flip \--train_batch_size=1 \--num_train_epochs=2 --checkpointing_steps=500 \--learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \--mixed_precision="fp16" \--seed=42 \--output_dir="sd-naruto-model-lora-sdxl" \--validation_prompt="cute dragon creature"

推理

from diffusers import DiffusionPipeline
import torchsdxl_model_path="/media/models/AI-ModelScope/stable-diffusion-xl-base-1___0/"
lora_model_path = "/media/text_to_image/sd-naruto-model-lora-sdxl/"pipe = DiffusionPipeline.from_pretrained(sdxl_model_path, torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(lora_model_path)prompt = "A naruto with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("naruto.png")

LoRA训练后的文件夹结构

├── checkpoint-1000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-1500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-2000
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
├── checkpoint-500
│   ├── optimizer.bin
│   ├── pytorch_lora_weights.safetensors
│   ├── random_states_0.pkl
│   ├── scaler.pt
│   └── scheduler.bin
└── pytorch_lora_weights.safetensors

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/80763.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于React的高德地图api教程001:初始化地图

文章目录 1、初始化地图1.1 创建react项目1.2 可视化地图1.3 设置卫星地图1.4 添加开关开启3D地图1.5 代码下载1、初始化地图 1.1 创建react项目 创建geodeapi项目: npx create-react-app gaodeapi安装高德地图包: npm install @amap/amap-jsapi-loader1.2 可视化地图 在…

uniapp使用npm下载

uniapp的项目在使用HBuilder X创建时是不会有node_modules文件夹的&#xff0c;如下图所示&#xff1a; 但是uni-app不管基于哪个框架&#xff0c;它内部一定是有node.js的&#xff0c;否则没有办法去实现框架层面的一些东西&#xff0c;只是说它略微有点差异。具体差异表现在…

轻量在线工具箱系统源码 附教程

源码介绍 轻量在线工具箱系统源码,直接扔服务器 修改config/config.php文件里面的数据库 后台账号admin 密码admin123 本工具是AI写的 所以工具均是第三方接口直接写的。 需要加工具直接自己找接口写好扔到goju目录 后台自动读取 效果预览 源码获取 轻量在线工具箱系统源…

图解gpt之Seq2Seq架构与序列到序列模型

今天深入探讨如何构建更强大的序列到序列模型&#xff0c;特别是Seq2Seq架构。序列到序列模型&#xff0c;顾名思义&#xff0c;它的核心任务就是将一个序列映射到另一个序列。这个序列可以是文本&#xff0c;也可以是其他符号序列。最早&#xff0c;人们尝试用一个单一的RNN来…

mac M2能安装的虚拟机和linux系统系统

能适配MAC M2芯片的虚拟机下Linux系统的搭建全是深坑&#xff0c;目前网上的资料能搜到的都是错误的&#xff0c;自己整理并分享给坑友们~ 网上搜索到的推荐安装的改造过的centos7也无法进行yum操作&#xff0c;我这边建议安装centos8 VMware Fusion下载地址&#xff1a; htt…

「国产嵌入式仿真平台:高精度虚实融合如何终结Proteus时代?」——从教学实验到低空经济,揭秘新一代AI赋能的产业级教学工具

引言&#xff1a;从Proteus到国产平台的范式革新 在高校嵌入式实验教学中&#xff0c;仿真工具的选择直接影响学生的工程能力培养与创新思维发展。长期以来&#xff0c;Proteus作为经典工具占据主导地位&#xff0c;但其设计理念已难以满足现代复杂系统教学与国产化技术需求。…

【Linux】在Arm服务器源码编译onnxruntime-gpu的whl

服务器信息&#xff1a; aarch64架构 ubuntu20.04 nvidia T4卡 编译onnxruntime-gpu前置条件&#xff1a; 已经安装合适的cuda已经安装合适的cudnn已经安装合适的cmake 源码编译onnxruntime-gpu的步骤 1. 下载源码 git clone --recursive https://github.com/microsoft/o…

前端上传el-upload、原生input本地文件pdf格式(纯前端预览本地文件不走后端接口)

前端实现本地文件上传与预览&#xff08;PDF格式展示&#xff09;不走后端接口 实现步骤 第一步&#xff1a;文件选择 使用前端原生input上传本地文件&#xff0c;或者是el-upload组件实现文件选择功能&#xff0c;核心在于文件渲染处理。&#xff08;input只不过可以自定义样…

Python 数据分析与可视化:开启数据洞察之旅(5/10)

一、Python 数据分析与可视化简介 在当今数字化时代&#xff0c;数据就像一座蕴藏无限价值的宝藏&#xff0c;等待着我们去挖掘和探索。而 Python&#xff0c;作为数据科学领域的明星语言&#xff0c;凭借其丰富的库和强大的功能&#xff0c;成为了开启这座宝藏的关键钥匙&…

C语言学习记录——深入理解指针(4)

OK&#xff0c;这一篇主要是讲我学习的3种指针类型。 正文开始&#xff1a; 一.字符指针 所谓字符指针&#xff0c;顾名思义就是指向字符的指针。一般写作 " char* " 直接来说说它的使用方法吧&#xff1a; &#xff08;1&#xff09;一般使用情况&#xff1a; i…

springboot3+vue3融合项目实战-大事件文章管理系统获取用户详细信息-ThreadLocal优化

一句话本质 为每个线程创建独立的变量副本&#xff0c;实现多线程环境下数据的安全隔离&#xff08;线程操作自己的副本&#xff0c;互不影响&#xff09;。 关键解读&#xff1a; 核心机制 • 同一个 ThreadLocal 对象&#xff08;如示意图中的红色区域 tl&#xff09;被多个线…

Nacos源码—8.Nacos升级gRPC分析六

大纲 7.服务端对服务实例进行健康检查 8.服务下线如何注销注册表和客户端等信息 9.事件驱动架构源码分析 一.处理ClientChangedEvent事件 也就是同步数据到集群节点&#xff1a; public class DistroClientDataProcessor extends SmartSubscriber implements DistroDataSt…

设计杂谈-工厂模式

“工厂”模式在各种框架中非常常见&#xff0c;包括 MyBatis&#xff0c;它是一种创建对象的设计模式。使用工厂模式有很多好处&#xff0c;尤其是在复杂的框架中&#xff0c;它可以带来更好的灵活性、可维护性和可配置性。 让我们以 MyBatis 为例&#xff0c;来理解工厂模式及…

AI与IoT携手,精准农业未来已来

AIoT&#xff1a;农业领域的变革先锋 在科技飞速发展的当下&#xff0c;人工智能&#xff08;AI&#xff09;与物联网&#xff08;IoT&#xff09;的融合 ——AIoT&#xff0c;正逐渐成为推动各行业变革的关键力量&#xff0c;农业领域也不例外。AIoT 技术通过将 AI 的智能分析…

排错-harbor-db容器异常重启

排错-harbor-db容器异常重启 环境&#xff1a; docker 19.03 , harbor-db(postgresql) goharbor/harbor-db:v2.5.6 现象&#xff1a; harbor-db 容器一直restart&#xff0c;查看日志发现 报错 initdb: error: directory "/var/lib/postgresql/data/pg13" exists…

Docker容器启动失败?无法启动?

Docker容器无法启动的疑难杂症解析与解决方案 一、问题现象 Docker容器无法启动是开发者在容器化部署中最常见的故障之一。尽管Docker提供了丰富的调试工具&#xff0c;但问题的根源往往隐藏在复杂的配置、环境依赖或资源限制中。本文将从环境变量配置错误这一细节问题入手&am…

查看购物车

一.查看购物车 查看购物车使用get请求。我们要查看当前用户的购物车&#xff0c;就要获取当前用户的userId字段进行条件查询。因为在用户登录时就已经将userId封装在token中了&#xff0c;因此我们只需要解析token获取userId即可&#xff0c;不需要前端再传入参数了。 Control…

C/C++ 内存管理深度解析:从内存分布到实践应用(malloc和new,free和delete的对比与使用,定位 new )

一、引言&#xff1a;理解内存管理的核心价值 在系统级编程领域&#xff0c;内存管理是决定程序性能、稳定性和安全性的关键因素。C/C 作为底层开发的主流语言&#xff0c;赋予开发者直接操作内存的能力&#xff0c;却也要求开发者深入理解内存布局与生命周期管理。本文将从内…

使用Stable Diffusion(SD)中CFG参数指的是什么?该怎么用!

1.定义&#xff1a; CFG参数控制模型在生成图像时&#xff0c;对提示词&#xff08;Prompt&#xff09;的“服从程度”。 它衡量模型在“完全根据提示词生成图像”和“自由生成图像”&#xff08;不参考提示词&#xff09;之间的权衡程度。 数值范围&#xff1a;常见范围是 1 …

【GESP】C++三级练习 luogu-B2156 最长单词 2

GESP三级练习&#xff0c;字符串练习&#xff08;C三级大纲中6号知识点&#xff0c;字符串&#xff09;&#xff0c;难度★★☆☆☆。 题目题解详见&#xff1a;https://www.coderli.com/gesp-3-luogu-b2156/ 【GESP】C三级练习 luogu-B2156 最长单词 2 | OneCoderGESP三级练…