用deepseek学大模型04-模型可视化与数据可视化

deepseek.com: pytorch可视化工具 生成神经网络图

在 PyTorch 中,可视化神经网络结构的常用工具和方法有以下几种,以下将详细介绍它们的用法:


1. TensorBoard (PyTorch 官方集成)

PyTorch 通过 torch.utils.tensorboard 支持 TensorBoard,可以直接可视化模型结构。

步骤:
  1. 安装 TensorBoard

    pip install tensorboard
    
  2. 代码示例

    import torch
    from torch.utils.tensorboard import SummaryWriter
    from torchvision.models import resnet18# 定义模型
    model = resnet18()
    # 创建 SummaryWriter 对象
    writer = SummaryWriter("runs/experiment1")
    # 生成随机输入数据(需与模型输入形状匹配)
    dummy_input = torch.randn(1, 3, 224, 224)
    # 将模型结构写入 TensorBoard
    writer.add_graph(model, dummy_input)
    writer.close()
    
  3. 启动 TensorBoard

    tensorboard --logdir=runs
    

    在浏览器中打开 http://localhost:6006,在 GRAPHS 标签页查看模型结构。


2. torchviz (生成计算图)

torchviz 基于 Graphviz,生成模型的计算图(前向传播路径)。

步骤:
  1. 安装依赖

    pip install torchviz graphviz
    
  2. 代码示例

    import torch
    from torchviz import make_dot
    from torchvision.models import resnet18model = resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    output = model(dummy_input)
    # 生成计算图
    dot = make_dot(output, params=dict(model.named_parameters()))
    # 保存为图片(格式可为 PDF 或 PNG)
    dot.render("resnet18", format="png")
    

3. HiddenLayer (轻量级可视化)

hiddenlayer 是一个轻量级工具,适合快速生成模型结构图。

步骤:
  1. 安装

    pip install hiddenlayer
    
  2. 代码示例

    import torch
    import hiddenlayer as hl
    from torchvision.models import resnet18model = resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    # 构建可视化图表
    hl.build_graph(model, dummy_input).save("resnet18.png")
    

4. Netron (模型结构查看器)

Netron 是一个独立的可视化工具,支持 PyTorch 模型文件(.pt.pth)或 ONNX 格式。

步骤:
  1. 导出模型为 ONNX

    import torch
    from torchvision.models import resnet18model = resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    torch.onnx.export(model, dummy_input, "resnet18.onnx")
    
  2. 使用 Netron 打开
    访问 Netron Web 版 或下载桌面端,上传 .onnx 文件即可查看交互式模型结构。


5. PyTorch 内置打印

通过 print(model) 快速查看模型层级(适合简单结构):

import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super().__init__()self.layers = nn.Sequential(nn.Linear(10, 20),nn.ReLU(),nn.Linear(20, 1))def forward(self, x):return self.layers(x)model = MyModel()
print(model)

工具对比

工具优点缺点
TensorBoard集成训练监控,交互式可视化需要启动服务,略复杂
torchviz显示计算细节,适合调试依赖 Graphviz,图可能较复杂
HiddenLayer轻量级,快速生成功能相对简单
Netron无需代码,支持多格式依赖导出模型文件

根据需求选择工具:调试计算流程用 torchviz,训练集成用 TensorBoard,快速分享用 Netron

在 PyTorch 中,可视化训练集、测试集、预测结果和训练过程可以通过多种工具实现。以下是常用的工具和方法,按场景分类:


1. 可视化数据集(训练集/测试集)

工具:Matplotlib/Seaborn

直接使用 Python 绘图库可视化输入数据或特征分布。

import matplotlib.pyplot as plt
import torchvision# 示例:可视化 CIFAR10 训练集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# 显示前 9 张图片
fig, axes = plt.subplots(3, 3, figsize=(10, 10))
for i, ax in enumerate(axes.flat):img, label = dataset[i]ax.imshow(img)ax.set_title(f"Label: {classes[label]}")ax.axis('off')
plt.show()

2. 可视化训练过程

工具 1:TensorBoard(PyTorch 集成)

监控训练损失、准确率等指标,支持动态更新。

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("runs/experiment1")for epoch in range(num_epochs):# 训练代码...train_loss = ...val_accuracy = ...# 记录标量数据writer.add_scalar('Loss/train', train_loss, epoch)writer.add_scalar('Accuracy/val', val_accuracy, epoch)# 记录模型权重分布for name, param in model.named_parameters():writer.add_histogram(name, param, epoch)# 启动 TensorBoard
# tensorboard --logdir=runs
工具 2:Weights & Biases(第三方协作工具)

云端记录实验,支持超参数跟踪和团队协作。

import wandb# 初始化
wandb.init(project="my-project")# 记录指标
wandb.log({"train_loss": train_loss, "val_acc": val_accuracy})# 记录预测结果(图像示例)
wandb.log({"predictions": [wandb.Image(img, caption=f"Pred:{pred}, True:{true}")]})

3. 可视化预测结果

方法 1:Matplotlib 直接绘制
# 示例:分类结果可视化
import numpy as npmodel.eval()
with torch.no_grad():inputs, labels = next(iter(test_loader))outputs = model(inputs)preds = torch.argmax(outputs, dim=1)# 显示预测结果
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i, ax in enumerate(axes.flat):ax.imshow(inputs[i].permute(1, 2, 0))  # 调整通道顺序ax.set_title(f"Pred: {classes[preds[i]]}\nTrue: {classes[labels[i]]}")ax.axis('off')
plt.tight_layout()
plt.show()
方法 2:混淆矩阵(分类任务)
from sklearn.metrics import confusion_matrix
import seaborn as sns# 计算混淆矩阵
cm = confusion_matrix(true_labels, pred_labels)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=classes, yticklabels=classes)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

4. 高级可视化工具

工具 1:Plotly(交互式可视化)

绘制动态训练曲线:

import plotly.express as px# 假设 logs 是包含训练历史的字典
fig = px.line(logs, x='epoch', y=['train_loss', 'val_loss'], title="Training and Validation Loss")
fig.show()
工具 2:Gradio(快速构建交互式 Demo)

部署模型预测交互界面:

import gradio as grdef classify_image(img):img = preprocess(img)  # 预处理pred = model(img)      # 推理return classes[torch.argmax(pred)]gr.Interface(fn=classify_image, inputs="image", outputs="label").launch()

5. 训练过程可视化(高级)

工具:PyTorch Lightning

自动集成 TensorBoard 和 WandB,简化日志记录:

import pytorch_lightning as plclass MyModel(pl.LightningModule):def training_step(self, batch, batch_idx):x, y = batchy_hat = self(x)loss = F.cross_entropy(y_hat, y)self.log("train_loss", loss)  # 自动记录到日志return loss# 训练时指定 logger
trainer = pl.Trainer(logger=pl.loggers.TensorBoardLogger("logs/"),# 或使用 WandB# logger=pl.loggers.WandbLogger(project="my-project")
)
trainer.fit(model)

工具对比

工具/方法适用场景优点缺点
TensorBoard训练指标跟踪、模型结构可视化官方集成,功能全面需本地启动服务
WandB团队协作、云端实验管理实时同步、超参数跟踪需要注册账号
Matplotlib静态数据可视化灵活、无需额外依赖交互性弱
Plotly交互式动态图表支持网页嵌入、动态更新学习曲线稍陡峭
Gradio快速部署预测 Demo零代码交互界面功能相对简单

关键场景总结

  1. 训练过程监控:优先选择 TensorBoard 或 WandB。
  2. 数据集预览:Matplotlib/Seaborn 快速绘制。
  3. 预测结果分析:混淆矩阵(分类)、BBox 标注(检测)、Matplotlib 对比图(回归)。
  4. 协作与报告:WandB 或 TensorBoard.dev(云端共享)。

可根据需求组合使用工具,例如:TensorBoard + Matplotlib(本地开发)或 WandB + Gradio(团队协作 + 演示)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/70049.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript设计模式 -- 外观模式

在实际开发中,往往会遇到多个子系统协同工作时,直接操作各个子系统不仅接口繁琐,还容易导致客户端与内部实现紧密耦合。**外观模式(Facade Pattern)**通过为多个子系统提供一个统一的高层接口,将复杂性隐藏…

【性能测试】如何理解“10个线程且10次循环“的请求和“100线程且1次循环“的请求

在性能测试中,我们常常会见到不同的并发配置:比如“10个线程且10次循环”与“100线程且1次循环”。乍一看,这两个设置的总请求数都是100次,但它们对系统的压力和测试场景却截然不同。了解其中的区别,能帮助你更精准地模…

Spring Boot 实战:轻松实现文件上传与下载功能

目录 一、引言 二、Spring Boot 文件上传基础 (一)依赖引入 (二)配置文件设置 (三)文件上传接口编写 (一)文件类型限制 (二)文件大小验证 &#xff0…

【Golang】GC探秘/写屏障是什么?

之前写了 一篇【Golang】内存管理 ,有了很多的阅读量,那么我就接着分享一下Golang的GC相关的学习。 由于Golang的GC机制一直在持续迭代,本文叙述的主要是Go1.9版本及以后的GC机制,该版本中Golang引入了 混合写屏障大幅度地优化了S…

DeepSeek教unity------MessagePack-03

数据契约兼容性 你可以使用 [DataContract] 注解代替 [MessagePackObject]。如果类型用 DataContract 进行注解,可以使用 [DataMember] 注解代替 [Key],并使用 [IgnoreDataMember] 代替 [IgnoreMember]。 然后,[DataMember(Order int)] 的…

【对比】Pandas 和 Polars 的区别

Pandas vs Polars 对比表 特性PandasPolars开发语言Python(Cython 实现核心部分)Rust(高性能系统编程语言)性能较慢,尤其在大数据集上(内存占用高,计算效率低)极快,利用…

百度千帆平台对接DeepSeek官方文档

目录 第一步:注册账号,开通千帆服务 第二步:创建应用,获取调用秘钥 第三步:调用模型,开启AI对话 方式一:通过API直接调用 方式二:使用SDK快速调用 方式三:在千帆大模…

49. c++计时器

为了测试某段特定代码的执行时间&#xff0c;体现代码的性能&#xff0c;可以使用计时器对代码段计时。下面使用std::chrono中的api编写简单案例&#xff1a; // // main.cpp // HelloWorld // // Created by on 2024/11/28. //#include <iostream> #include <vec…

Natural Language Processing NLP

NLP 清晰版本查看 Sentence segmentation (split)Tokenisation (split)Named entity recognition (combine) 概念主要內容典型方法Distributional Semantics&#xff08;分佈式語義&#xff09;&#xff08;分銷語義&#xff08;分佈式語義&#xff09;單詞的語義來自於它的…

Linux中线程创建,线程退出,线程接合

线程的简单了解 之前我们了解过 task_struct 是用于描述进程的核心数据结构。它包含了一个进程的所有重要信息&#xff0c;并且在进程的生命周期内保持更新。我们想要获取进程相关信息往往从这里得到。 在Linux中&#xff0c;线程的实现方式与进程类似&#xff0c;每个线程都…

HarmonyOS:使用List实现分组列表(包含粘性标题)

一、支持分组列表 在列表中支持数据的分组展示&#xff0c;可以使列表显示结构清晰&#xff0c;查找方便&#xff0c;从而提高使用效率。分组列表在实际应用中十分常见&#xff0c;如下图所示联系人列表。 联系人分组列表 在List组件中使用ListItemGroup对项目进行分组&#…

django上传文件

1、settings.py配置 # 静态文件配置 STATIC_URL /static/ STATICFILES_DIRS [BASE_DIR /static, ]上传文件 # 定义一个视图函数&#xff0c;该函数接收一个 request 参数 from django.shortcuts import render # 必备引入 import json from django.views.decorators.http i…

【前端知识】浏览器兼容方案polyfill

浏览器兼容方案polyfill 什么是 Polyfill&#xff1f;Polyfill 的作用Polyfill 的工作原理1. **特性检测**2. **加载 Polyfill**3. **模拟实现** Polyfill 的常见场景Polyfill 的使用方式Polyfill 的优缺点优点缺点 常见的 Polyfill 库总结 什么是 Polyfill&#xff1f; Polyf…

C#学习之DateTime 类

目录 一、DateTime 类的常用方法和属性的汇总表格 二、常用方法程序示例 1. 获取当前本地时间 2. 获取当前 UTC 时间 3. 格式化日期和时间 4. 获取特定部分的时间 5. 获取时间戳 6. 获取时区信息 三、总结 一、DateTime 类的常用方法和属性的汇总表格 在 C# 中&#x…

dedecms 开放重定向漏洞(附脚本)(CVE-2024-57241)

免责申明: 本文所描述的漏洞及其复现步骤仅供网络安全研究与教育目的使用。任何人不得将本文提供的信息用于非法目的或未经授权的系统测试。作者不对任何由于使用本文信息而导致的直接或间接损害承担责任。如涉及侵权,请及时与我们联系,我们将尽快处理并删除相关内容。 0x0…

如何选择合适的超参数来训练Bert和TextCNN模型?

选择合适的超参数来训练Bert和TextCNN模型是一个复杂但关键的过程&#xff0c;它会显著影响模型的性能。以下是一些常见的超参数以及选择它们的方法&#xff1a; 1. 与数据处理相关的超参数 最大序列长度&#xff08;max_length&#xff09; 含义&#xff1a;指输入到Bert模…

AWS 前端自动化部署流程指南

本文详细介绍从前端代码开发到 AWS 自动化部署的完整流程。 一、流程概览 1.1 部署流程图 #mermaid-svg-nYg7k6L5IKVBjDtr {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-nYg7k6L5IKVBjDtr .error-icon{fill:#552…

Office word打开加载比较慢处理方法

1.添加safe参数 ,找到word启动项,右击word,选择属性 , 添加/safe , 应用并确定 2.取消加载项,点击文件,点击选项 ,点击加载项,点击转到,取消所有勾选,确定。

大数据SQL调优专题——Spark执行原理

引入 在深入MapReduce中有提到&#xff0c;MapReduce虽然通过“分而治之”的思想&#xff0c;解决了海量数据的计算处理问题&#xff0c;但性能还是不太理想&#xff0c;这体现在两个方面&#xff1a; 每个任务都有比较大的overhead&#xff0c;都需要预先把程序复制到各个 w…

MYSQL下载安装及使用

MYSQL官网下载地址&#xff1a;https://downloads.mysql.com/archives/community/ 也可以直接在服务器执行指令下载&#xff0c;但是下载速度比较慢。还是自己下载好拷贝过来比较快。 wget https://dev.mysql.com/get/Downloads/mysql-5.7.38-linux-glibc2.12-x86_64.tar.gz 1…