libtorch的c++,加载*.pth

一、转换模型为TorchScript

前提:python只保存了参数,没存结构

要在C++中使用libtorch(PyTorch的C++接口),读取和加载通过torch.save保存的模型(    torch.save(pdn.state_dict()这种方式,只保存了参数,没存结构),需要转换模型为TorchScript。在python下实现。

def get_pdn_small(out_channels=384, padding=False):pad_mult = 1 if padding else 0return nn.Sequential(nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3,padding=1 * pad_mult),nn.ReLU(inplace=True),nn.Conv2d(in_channels=256, out_channels=out_channels, kernel_size=4))def get_pdn_medium(out_channels=384, padding=False):pad_mult = 1 if padding else 0return nn.Sequential(nn.Conv2d(in_channels=3, out_channels=256, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4,padding=3 * pad_mult),nn.ReLU(inplace=True),nn.AvgPool2d(kernel_size=2, stride=2, padding=1 * pad_mult),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3,padding=1 * pad_mult),nn.ReLU(inplace=True),nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=4),nn.ReLU(inplace=True),nn.Conv2d(in_channels=out_channels, out_channels=out_channels,kernel_size=1))
import torch# 假设你有一个已训练的模型
model = get_pdn_small()# 加载模型的state_dict
model.load_state_dict(torch.load('teacher_small.pth'))
model.eval()  # 设置模型为评估模式# 将模型转化为TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save('teacher_small.pt')

二、在C++中加载TorchScript模型

在C++中,你可以使用torch::jit::load来加载.pt文件,如下所示:

#include <torch/script.h>  // One-stop header for loading TorchScript models
#include <iostream>
#include <memory>int main() {// 加载TorchScript模型try {// 加载模型std::shared_ptr<torch::jit::Module> model = std::make_shared<torch::jit::Module>(torch::jit::load("teacher_small.pt"));std::cout << "Model loaded successfully!" << std::endl;// 你可以在这里使用模型进行推理,比如输入一个张量// 例如,如果输入是一个3x224x224的图像,你需要创建一个相应的Tensortorch::Tensor input = torch::randn({1, 3, 224, 224});  // 示例输入std::vector<torch::jit::IValue> inputs;inputs.push_back(input);// 执行模型推理at::Tensor output = model->forward(inputs).toTensor();std::cout << "Output tensor: " << output << std::endl;}catch (const c10::Error& e) {std::cerr << "Error loading the model: " << e.what() << std::endl;return -1;}
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/69405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

postgresql 游标(cursor)的使用

概述 PostgreSQL游标可以封装查询并对其中每一行记录进行单独处理。当我们想对大量结果集进行分批处理时可以使用游标&#xff0c;因为一次性处理可能造成内存溢出。 另外我们可以定义函数返回游标类型变量&#xff0c;这是函数返回大数据集的有效方式&#xff0c;函数调用者…

Linux 快速对比两个文件的差异值

Linux 快速对比两个文件的差异值&#xff08;无需排序、直接输出&#xff09; 在日常开发或数据处理中&#xff0c;若需快速对比两个文本文件中的差异值&#xff08;仅保留第一个文件中的独有内容&#xff09;&#xff0c;Linux 系统提供了两种高效方法。以下是具体操作及适用…

【Pytorch实战教程】PyTorch中的Dataset用法详解

PyTorch中的Dataset用法详解 在深度学习中,数据是模型训练的基石。PyTorch作为一个强大的深度学习框架,提供了丰富的工具来处理和加载数据。其中,Dataset类是PyTorch中用于处理数据的重要工具之一。本文将详细介绍Dataset的用法,帮助你更好地理解和使用它。 1. 什么是Dat…

python:面向对象之魔法方法

概念&#xff1a;主要是提供一些特殊的功能。 1.__init__方法&#xff1a; 一.不带参数&#xff1a; python中类似__xx__() __init__():初始化对象class Car():def __init__(self):self.color blueself.type suvdef info(self):print(f车的颜色是&#xff1a;{self.color})p…

python两段多线程的例子

记录瞬间 其一 # coding:UTF-8 import os import threading from time import ctimedef loop(loops, list): # list存放着每个线程需要处理的文本文件名print(线程 %d 处理的文件列表 %s \n % (loops 1, list))list_len len(list)for i in range(list_len):f open(list[i…

基于蒙特卡洛思想生成电动汽车充电负荷曲线

本程序基于蒙特卡洛思想生成电动汽车充电负荷曲线&#xff0c;利用第十一届电工杯所提供的数据&#xff08;充电开始时间&#xff0c;充电电量&#xff0c;充电功率&#xff09;得到一万台电动汽车充电负荷曲线。蒙特卡洛只是解决问题的一种思想&#xff0c;本程序可为其他利用…

语言月赛 202308【小粉兔做麻辣兔头】题解(AC)

》》》点我查看「视频」详解》》》 [语言月赛 202308] 小粉兔做麻辣兔头 题目描述 粉兔喜欢吃麻辣兔头&#xff0c;麻辣兔头的辣度分为若干级&#xff0c;用数字表示&#xff0c;数字越大&#xff0c;兔头越辣。为了庆祝粉兔专题赛 #1 的顺利举行&#xff0c;粉兔要做一些麻…

C++20导出模块及使用

1.模块声明 .ixx文件为导入模块文件 math_operations.ixx export module math_operations;//模块导出 //导出命名空间 export namespace math_ {//导出命名空间中函数int add(int a, int b);int sub(int a, int b);int mul(int a, int b);int div(int a, int b); } .cppm文件…

使用redis实现 令牌桶算法 漏桶算法

流量控制算法&#xff0c;用于限制请求的速率。 可以应对缓存雪崩 令牌桶算法 核心思想是&#xff1a; 有一个固定容量的桶&#xff0c;里面存放着令牌&#xff08;token&#xff09;。每过一定时间&#xff08;如 1 秒&#xff09;&#xff0c;桶中会自动增加一定数量的令牌…

活动预告 | 解锁 Excel 新境界 —— AI 技术赋能下的数据分析超级引擎!

课程介绍 在 AI 技术的浪潮中&#xff0c;Microsoft Excel 已经焕然一新&#xff0c;它不再仅仅是海量复杂数据的处理中心&#xff0c;更是未来趋势的预测大师。智能 Copilot 副驾驶的加入&#xff0c;让 Excel 如虎添翼&#xff0c;成为每一位数据探索者梦寐以求的超级引擎。在…

在阿里云ECS上一键部署DeepSeek-R1

DeepSeek-R1 是一款开源模型&#xff0c;也提供了 API(接口)调用方式。据 DeepSeek介绍&#xff0c;DeepSeek-R1 后训练阶段大规模使用了强化学习技术&#xff0c;在只有极少标注数据的情况下提升了模型推理能力&#xff0c;该模型性能对标 OpenAl o1 正式版。DeepSeek-R1 推出…

Python分享20个Excel自动化脚本

在数据处理和分析的过程中&#xff0c;Excel文件是我们日常工作中常见的格式。通过Python&#xff0c;我们可以实现对Excel文件的各种自动化操作&#xff0c;提高工作效率。 本文将分享20个实用的Excel自动化脚本&#xff0c;以帮助新手小白更轻松地掌握这些技能。 1. Excel单…

使用requestAnimationFrame减少浏览器重绘

文章目录 介绍使用使用rAF前使用rAF后 介绍 在屏幕中&#xff0c;浏览器通常都以60FPS&#xff08;1/60 s&#xff09;每帧更新屏幕&#xff0c;但是当前端绑定了一些高频事件&#xff0c;如鼠标移动&#xff0c;屏幕滚动、触摸滑动等时&#xff0c;在一帧的周期内&#xff0c;…

Android的MQTT客户端实现

在 Android 平台上实现 MQTT 客户端的完整技术方案&#xff0c;涵盖基础实现、安全连接、性能优化和最佳实践&#xff1a; 一、技术选型与依赖配置 推荐库 Eclipse Paho Android Service&#xff08;官方维护&#xff0c;支持后台运行&#xff09; gradle 复制 // build.gradl…

SQL LEFT JOIN 详解

SQL LEFT JOIN 详解 引言 在SQL数据库查询中,LEFT JOIN 是一种强大的联接操作符,它允许我们从两个或多个表中检索数据。本文将详细介绍 LEFT JOIN 的概念、用法以及在实际应用中的注意事项。 一、什么是 LEFT JOIN? LEFT JOIN 是一种 SQL 联接操作符,用于返回左表(Lef…

理解UML中的四种关系:依赖、关联、泛化和实现

在软件工程中&#xff0c;统一建模语言&#xff08;UML&#xff09;是一种广泛使用的工具&#xff0c;用于可视化、设计、构造和文档化软件系统。UML提供了多种图表类型&#xff0c;如类图、用例图、序列图等&#xff0c;帮助开发者和设计师更好地理解系统的结构和行为。在UML中…

es match 可查 而 term 查不到 问题分析

版本信息 elasticsearch-8.13.0 es 匹配逻辑 根本&#xff1a;es 的匹配是基于token 的。检索的query和目标字段在token 层级上有交集才能检索成功。对同样的文本&#xff0c;使用不同的分词器&#xff0c;所得token 不同。es 默认的analyzer(分词器)是standard模式&#xf…

如何通过Deepseek的API进行开发和使用(适合开发者和小白的学习使用教程)

目录 一,API创建与获取 二,直接进行API的调用 2.1 安装第三方库 2.2 官方支持的接口调用方式 2.3 编写的小tips 2.4 AI助手工具代码 三, 配置方面的说明 3.1 token价格和字符用量 3.2 响应错误码 最近在休息的时候也是一直会刷到关于deepseek,简单使用了一下,发现这…

C#+halcon机器视觉九点标定算法

在机器视觉中&#xff0c;九点标定&#xff08;也称为九点标定法&#xff09;是一种常用的方法&#xff0c;用于将图像坐标系与物理坐标系进行映射。通过标定&#xff0c;可以将图像中的像素坐标转换为实际物理坐标&#xff0c;或者反之。下面是一个使用C#和Halcon进行九点标定…

Stream API 进阶:筛选、映射、查找、归约

文章目录 1. 引言 (Introduction)2. 筛选和切片 (Filtering and Slicing)2.1 使用谓词筛选 filter2.2 筛选各异的元素 distinct2.3 截短流 limit2.4 跳过元素 skip 3. 映射 (Mapping)3.1 对流中每一个元素应用函数 map3.2 流的扁平化 flatMap 4. 查找和匹配 (Finding and Match…