【笔记 Pytorch 08】深度学习模板 (未完)

文章目录

  • 一、声明
  • 二、工程结构
  • 三、文件内容
      • main.py
      • model.py
      • dataset.py
      • utils.py
  • 四、问题汇总

一、声明

非常感谢这些资料的作者:
【参考1】、【PyTorch速成教程 (by Sung Kim)】

二、工程结构

├── main.py:实现训练 (train) 、验证(validation)和测试(test)
│ ├── model.py:实现的模型
│ ├── dataset.py:加载的数据
│ ├── utils.py:常用功能

三、文件内容

main.py

from torch.utils.data import Dataset, DataLoader
from torch import from_numpy, tensor
from torch.autograd import Variable
import numpy as np
import model
import utils# load data
dataset = MyDataset()
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)# model
model=Model()# define loss and optimizer
criterion=torch.nn.BCELoss(size_average=True)
optimizer=torch.optim.SGD(model.parameters(),lr=0.1)# train
for epoch in range(2):for i, data in enumerate(train_loader, 0):# get the inputsinputs, labels = data# wrap them in Variableinputs, labels = Variable(inputs), Variable(labels)# Forward passy_pred=model(inputs)# Compute and print lossloss=criterion(y_pred,labels)accuracy= ultis.accuracy(y_pred,labels)print("[{:05d}/{:05d}] train_loss:{:.4f} accuracy: {:.4f}]".format(i,epoch,loss.data[0],accuracy))# updateoptimizer.zero_grad()	# zero gradientsloss.backward()			# perform a backward passoptimizer.step() 		# update weight or parameters

model.py

import torch
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.l1=torch.nn.Linear(8,6)self.l2=torch.nn.Linear(6,4)self.l3=torch.nn.Linear(4,1)self.sigmoid=torch.nn.Sigmoid()# 数据流def forward(self,x):out1=self.sigmoid(self.l1(x))out2=self.sigmoid(self.l2(out1))y_pred=self.sigmoid(self.l3(out2))return y_pred

dataset.py

要点:
(1)必须重载 __getitem____len__
(2)

import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):def __init__(self):  # Initialize your data, download, etc.xy = np.loadtxt('./data/diabetes.csv.gz', delimiter=',', dtype=np.float32)self.len = xy.shape[0]self.x_data = torch.from_numpy(xy[:, 0:-1])self.y_data = torch.from_numpy(xy[:, [-1]])def __getitem__(self, index):return self.x_data[index], self.y_data[index]def __len__(self):return self.len

utils.py

import numpy as np
import scipy.sparse as sp
import torch
import osdef encode_onehot(labels):classes = set(labels)classes_dict = {c: np.identity(len(classes))[i, :] for i, c inenumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)),dtype=np.int32)return labels_onehotdef accuracy(output, labels):preds = output.max(1)[1].type_as(labels)correct = preds.eq(labels).double()correct = correct.sum()return correct / len(labels)def list_all_files(rootdir):_files = []#列出文件夹下所有的目录与文件list_file = os.listdir(rootdir)for i in range(0,len(list_file)):# 构造路径path = os.path.join(rootdir,list_file[i])# 判断路径是否是一个文件目录或者文件# 如果是文件目录,继续递归        if os.path.isdir(path):_files.extend(list_all_files(path))if os.path.isfile(path):_files.append(path)return _filesdef mkdir(path):# 去除首位空格path=path.strip()# 去除尾部 \ 符号path=path.rstrip("\\")# 判断路径是否存在# 存在     True# 不存在   FalseisExists=os.path.exists(path)# 判断结果if not isExists:# 如果不存在则创建目录# 创建目录操作函数os.makedirs(path) print(path+' create sucess')return Trueelse:# 如果目录存在则不创建,并提示目录已存在print(path+' path exist !')return False

四、问题汇总

:dataset.py中__getitem__返回的是一个元素,还是一个batch数据?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/172202.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python pip安装第三方包时报错 error: Microsoft Visual C++ 14.0 or greater is required.

文章目录 1.问题2.原因3.解决办法 1.问题 pip install 的时候报错一大堆,其中有这么一段话 👇 error: Microsoft Visual C 14.0 or greater is required. Get it with "Microsoft C Build Tools": https://visualstudio.microsoft.com/visua…

rdf-file:读和写

<dependency><groupId>com.alipay.rdf.file</groupId><artifactId>rdf-file-core</artifactId><version>2.2.10</version> </dependency>一&#xff1a;读 一&#xff1a;写 写文件之正常写 协议布局模板 使用内置的布局文…

二分 模板

好久没更新博客了&#xff0c;之前一直在准备比赛&#xff0c;忙着学算法和写题&#xff0c;今天写了一道二分答案的题&#xff0c;发现之前那种二分写法有一丢丢的问题&#xff0c;导致有道题只能过97%的点。 emmm,还是把最经典的二分的板子写在这记录下&#xff08;这里参考…

python每日一题——8无重复字符的最长子串

题目 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长子串 的长度。 示例 1: 输入: s “abcabcbb” 输出: 3 解释: 因为无重复字符的最长子串是 “abc”&#xff0c;所以其长度为 3。 示例 2: 输入: s “bbbbb” 输出: 1 解释: 因为无重复字符的最长子串…

精进Beautiful Soup 小技巧(二)---处理多种页面结构

前言: 为了处理多种不同结构的页面&#xff0c;一个灵活的代码基础是至关重要的。一些针对性的技巧和方法&#xff0c;让你能够优雅地解决遇到的页面结构多元化的问题。 使用条件语句适配不同布局 当面对页面布局差异时&#xff0c;选择合适的条件语句至关重要。 认识布局类型…

正则表达式例题-PTA

PTA-7-55 判断指定字符串是否合法-CSDN博客 7-54 StringBuffer-拼接字符串 题目&#xff1a; 输入3个整数n、begin、end。 将从0到n-1的数字拼接为字符串str。如&#xff0c;n12&#xff0c;则拼接出来的字符串为&#xff1a;01234567891011 最后截取字符串str从begin到end(包…

【2023 年终盘点】今年用的最多的 10 款浏览器插件

分享顺哥今年用的最多的 10 款浏览器插件。 排名不分先后,涉及各个方面的应用。 大家有好用的插件也欢迎在评论区留言分享! 视频 YouTube:https://youtu.be/ZpTydUSBwCA 顺哥博客 浏览器扩展篇 注意: 1、以下介绍的均为在 Google Chrome 浏览器适用的小插件,部分插件…

2018年11月8日 Go生态洞察:参与2018年Go用户调查

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

一个通用的分页实体对象的思考

背景 用得非常多的一个分页实体对象 说明 只是一种抽象的思路, 可能不一定能够直接使用, 慎用. 只是一种抽象的思路, 可能不一定能够直接使用, 慎用. 只是一种抽象的思路, 可能不一定能够直接使用, 慎用. 分页实体 Data public class PageEntity<T> {/*** 分页后的结…

文件解析工具

前言 对Excel & CSV 文件解析 package com.wind.bird.Utils;import com.opencsv.CSVReader; import com.opencsv.CSVReaderBuilder; import org.apache.commons.validator.Var; import org.apache.poi.hssf.usermodel.HSSFCell; import org.apache.poi.hssf.usermodel.HS…

基于springboot学籍管理系统

一、设计目的 1. 复习、巩固Java语言的基础知识&#xff0c;进一步加深对Java语言的理解和掌握&#xff1b; 2. 课程设计为学生提供了一个既动手又动脑&#xff0c;独立实践的机会&#xff0c;将课本上的理论知识和实际有机的结合起来&#xff0c;锻炼学生的分析解决实际问题…

2016年五一杯数学建模B题能源总量控制下的城市工业企业协调发展问题解题全过程文档及程序

2016年五一杯数学建模 B题 能源总量控制下的城市工业企业协调发展问题 原题再现 能源是国民经济的重要物质基础,是工业企业发展的动力&#xff0c;但是过度的能源消耗&#xff0c;会破坏资源和环境&#xff0c;不利于经济的可持续发展。目前我国正处于经济转型的关键时期&…

关于 raw 图像的理解

1、问题背景 在图像调试过程&#xff0c;当发现一个问题时&#xff0c;很多时候都要通过 dump raw图像来分析&#xff0c;如果raw图像上有&#xff0c;那就排除了是 ISP的处理导致。 下一步就是排查 sensor 或者镜头&#xff0c;这样可以有效的帮我们定位问题所在。 但遇到过…

IDEA出现cannot download sources解决方案

IDEA出现cannot download sources解决方案 问题描述 当我想看第三方库的源码的注释时需要下载源码。 点击Dodnload Sources后可能会出现cannot download sources的问题。 解决方案 这时我们只需在根目录下打开Terminal后执行下面一行代码 mvn dependency:resolve -Dclassi…

王者荣耀Java

代码 package com.sxt;import javax.swing.*; import java.awt.*;public class Background extends GameObject {public Background(GameFrame gameFrame) {super(gameFrame);// TODO Auto-generated constructor stub}Image bg Toolkit.getDefaultToolkit().getImage("…

notion 3.0.0 版本最新桌面端汉化教程,支持MAC和WIN版本

notion客户端汉化&#xff08;目前版本3.0.0&#xff09; 最近notion桌面端更新了3.0.0版本后会导致老版本汉化失效&#xff0c;本项目实现了最新版Notion桌面端的汉化。 文件下载地址&#xff1a;汉化文件下载地址 项目说明 本项目针对新的客户端做了汉化文化&#xff0c;依…

超实用!Spring Boot 常用注解详解与应用场景

目录 一、Web MVC 开发时&#xff0c;对于三层的类注解 1.1 Controller 1.2 Service 1.3 Repository 1.4 Component 二、依赖注入的注解 2.1 Autowired 2.2 Resource 2.3 Resource 与 Autowired 的区别 2.3.1 实例讲解 2.4 Value 2.5 Data 三、Web 常用的注解 3.1…

【Linux】Linux 系统 grep 命令超详细讲解

文章目录 grep补充说明选项规则表达式grep命令常见用法 grep grep 命令是一项非常有用的工具。grep&#xff08;全称&#xff1a;Global Regular Expression Print&#xff09;命令用于根据给定的正则表达式搜索文本&#xff0c;并将匹配的行打印出来。 补充说明 grep &…

可以在Playgrounds或Xcode Command Line Tool开始学习Swift

一、用Playgrounds 1. App Store搜索并安装Swift Playgrounds 2. 打开Playgrounds&#xff0c;点击 文件-新建图书。然后就可以编程了&#xff0c;如下&#xff1a; 二、用Xcode 1. 安装Xcode 2. 打开Xcode&#xff0c;选择Creat New Project 3. 选择macOS 4. 选择Comman…

【面经八股】搜广推方向:常见面试题(三)

【面经&八股】搜广推方向:常见面试题(三) 文章目录 【面经&八股】搜广推方向:常见面试题(三)1. 如何解决数据不平衡2. 假设检验的两类错误3. 为什么快排比堆排快4. RMSE、MSE、MAE5. 双塔模型的应用6. XGBoost如果损失函数没有二阶导,该怎么办7. AUC是如何实现的…