Pytorch高阶API示范——DNN二分类模型

代码部分:

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset"""
准备数据
"""#正负样本数量
n_positive,n_negative = 2000,2000#生成正样本, 小圆环分布
r_p = 5.0 + torch.normal(0.0,1.0,size = [n_positive,1])
theta_p = 2*np.pi*torch.rand([n_positive,1])
Xp = torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis = 1)
Yp = torch.ones_like(r_p)#生成负样本, 大圆环分布
r_n = 8.0 + torch.normal(0.0,1.0,size = [n_negative,1])
theta_n = 2*np.pi*torch.rand([n_negative,1])
Xn = torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis = 1)
Yn = torch.zeros_like(r_n)#汇总样本
X = torch.cat([Xp,Xn],axis = 0)
Y = torch.cat([Yp,Yn],axis = 0)#可视化
plt.figure(figsize = (6,6))
plt.scatter(Xp[:,0],Xp[:,1],c = "r")
plt.scatter(Xn[:,0],Xn[:,1],c = "g")
plt.legend(["positive","negative"])plt.show()"""
#构建输入数据管道
"""
ds = TensorDataset(X,Y)
dl = DataLoader(ds,batch_size = 10,shuffle=True)"""
2, 定义模型
"""
class DNNModel(nn.Module):def __init__(self):super(DNNModel, self).__init__()self.fc1 = nn.Linear(2,4)self.fc2 = nn.Linear(4,8)self.fc3 = nn.Linear(8,1)def forward(self,x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))y = nn.Sigmoid()(self.fc3(x))return ydef loss_func(self,y_pred,y_true):return nn.BCELoss()(y_pred,y_true)def metric_func(self,y_pred,y_true):y_pred = torch.where(y_pred > 0.5, torch.ones_like(y_pred, dtype=torch.float32),torch.zeros_like(y_pred, dtype=torch.float32))acc = torch.mean(1 - torch.abs(y_true - y_pred))return acc@propertydef optimizer(self):return torch.optim.Adam(self.parameters(), lr=0.001)model = DNNModel()# 测试模型结构
(features,labels) = next(iter(dl))
predictions = model(features)loss = model.loss_func(predictions,labels)
metric = model.metric_func(predictions,labels)print("init loss:",loss.item())
print("init metric:",metric.item())"""
3,训练模型
"""
def train_step(model, features, labels):# 正向传播求损失predictions = model(features)loss = model.loss_func(predictions,labels)metric = model.metric_func(predictions,labels)# 反向传播求梯度loss.backward()# 更新模型参数model.optimizer.step()model.optimizer.zero_grad()return loss.item(),metric.item()# 测试train_step效果
features,labels = next(iter(dl))    #非for循环就用next
train_step(model,features,labels)def train_model(model,epochs):for epoch in range(1,epochs+1):loss_list,metric_list = [],[]for features, labels in dl:lossi,metrici = train_step(model,features,labels)loss_list.append(lossi)metric_list.append(metrici)loss = np.mean(loss_list)metric = np.mean(metric_list)if epoch%100==0:print("epoch =",epoch,"loss = ",loss,"metric = ",metric)train_model(model,epochs = 300)# 结果可视化
fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize = (12,5))
ax1.scatter(Xp[:,0],Xp[:,1], c="r")
ax1.scatter(Xn[:,0],Xn[:,1],c = "g")
ax1.legend(["positive","negative"])
ax1.set_title("y_true")Xp_pred = X[torch.squeeze(model.forward(X)>=0.5)]
Xn_pred = X[torch.squeeze(model.forward(X)<0.5)]ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c = "r")
ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c = "g")
ax2.legend(["positive","negative"])
ax2.set_title("y_pred")plt.show()

结果展示:

数据部分:

在这里插入图片描述

结果分类:

在这里插入图片描述

思考:

本文中的DNN模型,将loss(损失),metric(准确率),optimizer(优化器)的定义放在了DNN网络中,这也产生了一系列的问题。首先在调用这些函数时,需要用网络名+“.”来调用。例如:loss = model.loss_func(predictions,labels)
但是这里最重要的一点是 def optimizer(self):在optimizer函数上面必须有:@property,若没有将会出现AttributeError: 'function' object has no attribute 'step'的报错。
@property装饰器的作用:我们可以使用@property装饰器来创建只读属性,@property装饰器会将方法转换为相同名称的只读属性,可以与所定义的属性配合使用,这样可以防止属性被修改。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389376.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OO期末总结

$0 写在前面 善始善终&#xff0c;临近期末&#xff0c;为一学期的收获和努力画一个圆满的句号。 $1 测试与正确性论证的比较 $1-0 什么是测试&#xff1f; 测试是使用人工操作或者程序自动运行的方式来检验它是否满足规定的需求或弄清预期结果与实际结果之间的差别的过程。 它…

puppet puppet模块、file模块

转载&#xff1a;http://blog.51cto.com/ywzhou/1577356 作用&#xff1a;通过puppet模块自动控制客户端的puppet配置&#xff0c;当需要修改客户端的puppet配置时不用在客户端一一设置。 1、服务端配置puppet模块 &#xff08;1&#xff09;模块清单 [rootpuppet ~]# tree /et…

数据可视化及其重要性:Python

Data visualization is an important skill to possess for anyone trying to extract and communicate insights from data. In the field of machine learning, visualization plays a key role throughout the entire process of analysis.对于任何试图从数据中提取和传达见…

熊猫数据集_熊猫迈向数据科学的第三部分

熊猫数据集Data is almost never perfect. Data Scientist spend more time in preprocessing dataset than in creating a model. Often we come across scenario where we find some missing data in data set. Such data points are represented with NaN or Not a Number i…

Pytorch有关张量的各种操作

一&#xff0c;创建张量 1. 生成float格式的张量: a torch.tensor([1,2,3],dtype torch.float)2. 生成从1到10&#xff0c;间隔是2的张量: b torch.arange(1,10,step 2)3. 随机生成从0.0到6.28的10个张量 注意&#xff1a; (1).生成的10个张量中包含0.0和6.28&#xff…

mongodb安装失败与解决方法(附安装教程)

安装mongodb遇到的一些坑 浪费了大量的时间 在此记录一下 主要是电脑系统win10企业版自带的防火墙 当然还有其他的一些坑 一般的问题在第6步骤都可以解决&#xff0c;本教程的安装步骤不够详细的话 请自行百度或谷歌 安装教程很多 我是基于node.js使用mongodb结合Robo 3T数…

【洛谷算法题】P1046-[NOIP2005 普及组] 陶陶摘苹果【入门2分支结构】Java题解

&#x1f468;‍&#x1f4bb;博客主页&#xff1a;花无缺 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! 本文由 花无缺 原创 收录于专栏 【洛谷算法题】 文章目录 【洛谷算法题】P1046-[NOIP2005 普及组] 陶陶摘苹果【入门2分支结构】Java题解&#x1f30f;题目…

web性能优化(理论)

什么是性能优化&#xff1f; 就是让用户感觉你的网站加载速度很快。。。哈哈哈。 分析 让我们来分析一下从用户按下回车键到网站呈现出来经历了哪些和前端相关的过程。 缓存 首先看本地是否有缓存&#xff0c;如果有符合使用条件的缓存则不需要向服务器发送请求了。DNS查询建立…

python多项式回归_如何在Python中实现多项式回归模型

python多项式回归Let’s start with an example. We want to predict the Price of a home based on the Area and Age. The function below was used to generate Home Prices and we can pretend this is “real-world data” and our “job” is to create a model which wi…

充分利用UC berkeleys数据科学专业

By Kyra Wong and Kendall Kikkawa黄凯拉(Kyra Wong)和菊川健多 ( Kendall Kikkawa) 什么是“数据科学”&#xff1f; (What is ‘Data Science’?) Data collection, an important aspect of “data science”, is not a new idea. Before the tech boom, every industry al…

文本二叉树折半查询及其截取值

using System;using System.ComponentModel;using System.Data;using System.Drawing;using System.Text;using System.Windows.Forms;using System.Collections;using System.IO;namespace CS_ScanSample1{ /// <summary> /// Logic 的摘要说明。 /// </summary> …

nn.functional 和 nn.Module入门讲解

本文来自《20天吃透Pytorch》 一&#xff0c;nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API。 利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数&#xff0c;模型层&#xff0c;损失函数)。 Pytorch和神经网络…

10.30PMP试题每日一题

SC>0&#xff0c;CPI<1&#xff0c;说明项目截止到当前&#xff1a;A、进度超前&#xff0c;成本超值B、进度落后&#xff0c;成本结余C、进度超前&#xff0c;成本结余D、无法判断 答案将于明天和新题一起揭晓&#xff01; 10.29试题答案&#xff1a;A转载于:https://bl…

02-web框架

1 while True:print(server is waiting...)conn, addr server.accept()data conn.recv(1024) print(data:, data)# 1.得到请求的url路径# ------------dict/obj d["path":"/login"]# d.get(”path“)# 按着http请求协议解析数据# 专注于web业…

ai驱动数据安全治理_AI驱动的Web数据收集解决方案的新起点

ai驱动数据安全治理Data gathering consists of many time-consuming and complex activities. These include proxy management, data parsing, infrastructure management, overcoming fingerprinting anti-measures, rendering JavaScript-heavy websites at scale, and muc…

从Text文本中读值插入到数据库中

/// <summary> /// 转换数据&#xff0c;从Text文本中导入到数据库中 /// </summary> private void ChangeTextToDb() { if(File.Exists("Storage Card/Zyk.txt")) { try { this.RecNum.Visibletrue; SqlCeCommand sqlCreateTable…

Dataset和DataLoader构建数据通道

重点在第二部分的构建数据通道和第三部分的加载数据集 Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。 Dataset定义了数据集的内容&#xff0c;它相当于一个类似列表的数据结构&#xff0c;具有确定的长度&#xff0c;能够用索引获取数据集中的元素。 而D…

铁拳nat映射_铁拳如何重塑我的数据可视化设计流程

铁拳nat映射It’s been a full year since I’ve become an independent data visualization designer. When I first started, projects that came to me didn’t relate to my interests or skills. Over the past eight months, it’s become very clear to me that when cl…

Django2 Web 实战03-文件上传

作者&#xff1a;Hubery 时间&#xff1a;2018.10.31 接上文&#xff1a;接上文&#xff1a;Django2 Web 实战02-用户注册登录退出 视频是一种可视化媒介&#xff0c;因此视频数据库至少应该存储图像。让用户上传文件是个很大的隐患&#xff0c;因此接下来会讨论这俩话题&#…

BZOJ.2738.矩阵乘法(整体二分 二维树状数组)

题目链接 BZOJ洛谷 整体二分。把求序列第K小的树状数组改成二维树状数组就行了。 初始答案区间有点大&#xff0c;离散化一下。 因为这题是一开始给点&#xff0c;之后询问&#xff0c;so可以先处理该区间值在l~mid的修改&#xff0c;再处理询问。即二分标准可以直接用点的标号…