Code
r_test_1.py
import torch
import time
import argparse
import os
from datetime import datetime
from diffusers import FluxPipeline


def setup_environment():
    """Set environment variables that optimize V100 performance."""
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.backends.cudnn.benchmark = True
    if torch.cuda.is_available():
        print(f"📊 GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB")
        print(f"🔧 PyTorch version: {torch.__version__}")
        print(f"🔧 CUDA version: {torch.version.cuda}")


def setup_xformers(pipeline):
    """Enable xformers safely (with a fallback)."""
    try:
        # Try to enable xformers
        pipeline.enable_xformers_memory_efficient_attention()
        print("✅ Enabled xformers memory-efficient attention")
    except (ImportError, AttributeError) as e:
        print(f"⚠️ xformers unavailable ({e}), falling back to the default attention implementation")
        # Optional: attention slicing (enabled during model loading) acts as a substitute
        pass


def load_flux_model_mini(precision="fp16", device="cuda"):
    """Ultra-lightweight loading, for V100 OOM situations."""
    # NOTE: the precision argument is currently unused; fp16 is forced for the V100
    print("🚀 Loading the FLUX.1-schnell model (ultra-lightweight mode)...")
    start_time = time.time()
    torch_dtype = torch.float16  # V100 only
    try:
        pipeline = FluxPipeline.from_pretrained(
            "/models/flux-schnell",
            torch_dtype=torch_dtype,
            use_safetensors=True,
        )
        # Key optimizations
        print("✅ Enabling sequential CPU offload...")
        pipeline.enable_sequential_cpu_offload()
        print("✅ Enabling VAE slicing and tiling...")
        pipeline.vae.enable_slicing()
        pipeline.vae.enable_tiling()
        print("✅ Enabling attention slicing...")
        pipeline.enable_attention_slicing(slice_size="max")
        # ✅ Enable xformers (or fall back)
        setup_xformers(pipeline)
    except Exception as e:
        print(f"❌ Loading failed: {e}")
        import traceback
        traceback.print_exc()
        raise
    load_time = time.time() - start_time
    print(f"✅ Model loaded! Took {load_time:.2f}s")
    if device == "cuda":
        print(f"💾 Current VRAM: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
        print(f"💾 Peak VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.2f}GB")
    return pipeline, load_time


def load_lora_weights(pipeline, lora_path, lora_weight=1.0):
    """Load LoRA weights - VRAM-optimized, compatible with the newer API."""
    if not lora_path:
        return pipeline, 0
    print(f"\n🔌 Loading LoRA weights: {os.path.basename(lora_path)}")
    print(f"⚖️ LoRA weight strength: {lora_weight}")
    start_time = time.time()
    try:
        torch.cuda.empty_cache()
        # Recommended approach in newer Diffusers (set_adapter was deprecated/renamed)
        pipeline.load_lora_weights(lora_path, adapter_name="default")
        # ✅ Set the adapter properly (newer API: set_adapters, or fuse directly)
        # Option 1: dynamic switching (recommended when using multiple LoRAs)
        pipeline.set_adapters(["default"], adapter_weights=[lora_weight])
        # Option 2: fuse temporarily (faster inference, less flexible) -> commented out
        # pipeline.fuse_lora(lora_scale=lora_weight)
        # pipeline.unfuse_lora()  # call when you need the base weights back
        load_time = time.time() - start_time
        print(f"✅ LoRA loaded! Took {load_time:.2f}s")
        if torch.cuda.is_available():
            print(f"💾 VRAM after LoRA load: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
        return pipeline, load_time
    except Exception as e:
        print(f"❌ Failed to load LoRA: {e}")
        import traceback
        traceback.print_exc()
        return pipeline, 0


def generate_image_optimized(pipeline, prompt, **kwargs):
    """VRAM-optimized generation."""
    torch.cuda.empty_cache()
    # Drop max_sequence_length if a caller passed it (not used here; avoids warnings)
    kwargs.pop("max_sequence_length", None)
    print("\n🎨 Starting generation...")
    print(f"Prompt: {prompt[:50]}{'...' if len(prompt) > 50 else ''}")
    start_time = time.time()
    try:
        result = pipeline(
            prompt=prompt,
            **kwargs,
        )
        inference_time = time.time() - start_time
        print(f"✅ Generation complete! Took {inference_time:.2f}s")
        return result.images[0], inference_time
    except torch.cuda.OutOfMemoryError:
        print("❌ Still OOM! Suggestions: 1) lower the resolution 2) check that the LoRA is FLUX-compatible")
        raise


def main():
    setup_environment()
    parser = argparse.ArgumentParser(description='FLUX.1-schnell inference + LoRA support - V100 OOM-optimized')
    parser.add_argument('--prompt', type=str, required=True, help='Generation prompt')
    parser.add_argument('--negative_prompt', type=str, default="", help='Negative prompt')
    parser.add_argument('--lora_path', type=str, default=None, help='Path to LoRA weights (.safetensors)')
    parser.add_argument('--lora_weight', type=float, default=1.0, help='LoRA weight strength (0.0-1.0)')
    parser.add_argument('--steps', type=int, default=4, help='Inference steps (4 is recommended for schnell)')
    parser.add_argument('--guidance', type=float, default=0.0, help='Guidance scale (must be 0 for schnell)')
    parser.add_argument('--height', type=int, default=512, help='Image height')
    parser.add_argument('--width', type=int, default=512, help='Image width')
    parser.add_argument('--seed', type=int, default=None, help='Random seed')
    parser.add_argument('--output_dir', type=str, default='outputs', help='Output directory')
    args = parser.parse_args()

    # Validate arguments
    if args.guidance != 0.0:
        print("⚠️ Warning: FLUX.1-schnell requires guidance=0.0; corrected automatically")
        args.guidance = 0.0
    if args.height > 1024 or args.width > 1024:
        print("⚠️ Warning: sizes above 1024 may cause OOM; consider reducing them")

    print("\n" + "=" * 60)
    print("📋 Run configuration")
    print("=" * 60)
    print(f"Prompt: {args.prompt}")
    print(f"LoRA: {os.path.basename(args.lora_path) if args.lora_path else 'none'}")
    print(f"Size: {args.width}x{args.height}")
    print(f"Steps: {args.steps}")
    print(f"Seed: {args.seed if args.seed is not None else 'random'}")
    print("=" * 60)

    total_start = time.time()

    # 1. Load the model
    pipeline, model_time = load_flux_model_mini("fp16", "cuda")

    # 2. Load the LoRA
    lora_time = 0
    if args.lora_path:
        pipeline, lora_time = load_lora_weights(pipeline, args.lora_path, args.lora_weight)
    else:
        print("\n⏭️ Skipping LoRA loading")

    # 3. Generate
    image, gen_time = generate_image_optimized(
        pipeline,
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,
        height=args.height,
        width=args.width,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance,
        # `is not None` so that --seed 0 is honored
        generator=torch.Generator().manual_seed(args.seed) if args.seed is not None else None,
    )

    # 4. Save (✅ shortened filename)
    os.makedirs(args.output_dir, exist_ok=True)
    timestamp_short = datetime.now().strftime("%H%M%S")  # hour/minute/second only, e.g. 143045
    lora_tag = "lora" if args.lora_path else "none"
    output_path = f"{args.output_dir}/flux_{lora_tag}_{timestamp_short}.png"
    image.save(output_path)
    print(f"\n💾 Image saved to: {output_path}")

    # 5. Clean up
    print("\n🧹 Cleaning up memory...")
    del pipeline
    torch.cuda.empty_cache()

    # 6. Statistics
    total_time = time.time() - total_start
    print("\n" + "=" * 60)
    print("📊 Final statistics")
    print("=" * 60)
    print(f"Model load: {model_time:.2f}s")
    if args.lora_path:
        print(f"LoRA load: {lora_time:.2f}s")
    print(f"Image generation: {gen_time:.2f}s")
    print(f"Total: {total_time:.2f}s")
    if torch.cuda.is_available():
        print(f"VRAM still allocated: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
    print("=" * 60)


if __name__ == "__main__":
    main()
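The LoRA loader above keeps the adapter unfused (Option 1) so its strength can be changed without reloading anything. A minimal sketch of the two options side by side, assuming a FluxPipeline named pipeline has already been loaded as in load_flux_model_mini(); the two .safetensors paths are hypothetical placeholders, not files from this post:

# Minimal sketch: the two adapter strategies mentioned in load_lora_weights().
# Assumes `pipeline` is an already-loaded diffusers FluxPipeline; the two
# .safetensors paths below are placeholders.

# Option 1: keep adapters separate. Scales can be changed per call,
# and several LoRAs can be mixed together.
pipeline.load_lora_weights("style_a.safetensors", adapter_name="style_a")
pipeline.load_lora_weights("style_b.safetensors", adapter_name="style_b")
pipeline.set_adapters(["style_a", "style_b"], adapter_weights=[0.8, 0.4])

# Option 2: fuse one adapter into the base weights for slightly faster
# inference; unfuse_lora() restores the originals when flexibility is needed.
pipeline.fuse_lora(lora_scale=0.8)
# ... run inference ...
pipeline.unfuse_lora()

Option 1 is exactly what the script does with its single adapter named "default".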
Run

python r_test_1.py --prompt "a girl, 8k" --lora_path /opt/ai-lora_v1/lora_v1.safetensors --seed 42 --height 512 --width 512

Make a note of this command; it will be used later.
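Since the saved filename only carries hour/minute/second, it also helps to record each run's full parameters alongside the note above. A minimal sketch of one way to do that; the log_run helper and the run_log.jsonl filename are illustrative additions, not part of r_test_1.py:

import json
from datetime import datetime

def log_run(output_path, args, gen_time, log_file="run_log.jsonl"):
    """Append one JSON line per run so a result can be matched to its settings later."""
    record = {
        "timestamp": datetime.now().isoformat(),
        "output": output_path,
        "prompt": args.prompt,
        "lora_path": args.lora_path,
        "lora_weight": args.lora_weight,
        "seed": args.seed,
        "steps": args.steps,
        "size": f"{args.width}x{args.height}",
        "gen_time_s": round(gen_time, 2),
    }
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

Calling log_run(output_path, args, gen_time) at the end of step 4 in main() would make the --seed 42 run above easy to find again.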