【tensorflow框架神经网络实现鸢尾花分类】

文章目录

  • 1、数据获取
  • 2、数据集构建
  • 3、模型的训练验证
  • 可视化训练过程

1、数据获取

  • 从sklearn中获取鸢尾花数据,并合并处理
from sklearn.datasets import load_iris
import pandas as pdx_data = load_iris().data
y_data = load_iris().targetx_data = pd.DataFrame(x_data, columns=['花萼长度','花萼宽度','花瓣长度','花瓣宽度'])
pd.set_option('display.unicode.east_asian_width', True)x_data['类别'] = y_data
x_data

在这里插入图片描述

2、数据集构建

  • 数据集构建包括:
    • 数据读取
    • 数据打乱
    • 数据划分
    • 小批量迭代器生成
import tensorflow as tf
import numpy as np
from sklearn.datasets import load_iris# 1、从sklearn包中datasets读取数据集
x_data = load_iris().data
y_data = load_iris().target# 2、数据打乱
np.random.seed(1)   # 使用相同的seed,使输入特征/标签一一对应
np.random.shuffle(x_data)
np.random.seed(1)
np.random.shuffle(y_data)
tf.random.set_seed(1)# 3、训练集、测试集划分
x_train, x_test = x_data[:-30], x_data[-30:]
y_train, y_test = y_data[:-30], y_data[-30:]# 4、小批量数据
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
train_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

3、模型的训练验证

# 定义超参数,预设变量
lr = 0.1
loss_all = 0
Epoch = 500
train_loss_list = []
test_acc = []# 定义神经网络的可训练参数
w = tf.Variable(tf.random.truncated_normal([4,3], stddev=0.1, seed=1))
b = tf.Variable(tf.random.truncated_normal([3], stddev=0.1, seed=1))# 循环迭代,训练参数
for epoch in range(Epoch):for step, (x_, y_) in enumerate(train_db):with tf.GradientTape() as tape:x_ = tf.cast(x_, tf.float32)y_pre = tf.matmul(x_, w) + by_pre = tf.nn.softmax(y_pre)y_lab = tf.one_hot(y_, depth=3)loss = tf.reduce_mean(tf.square(y_lab - y_pre))loss_all += loss.numpy()grads = tape.gradient(loss, [w,b])w.assign_sub(lr * grads[0])b.assign_sub(lr * grads[1])print(f'Epoch: {epoch}, loss: {loss_all/4}')train_loss_list.append(loss_all/4)loss_all = 0# 测试部分total_correct, total_number = 0, 0for x_,y_ in test_db:x_ = tf.cast(x_, tf.float32)y_pre = tf.matmul(x_, w) + by_pre = tf.nn.softmax(y_pre)y_p = tf.argmax(y_pre, axis=1)y_p = tf.cast(y_p, dtype=y_.dtype)correct = tf.cast(tf.equal(y_p, y_), dtype=tf.int32)correct = tf.reduce_sum(correct)total_correct += int(correct) total_number += x_.shape[0]acc = total_correct / total_numbertest_acc.append(acc)print("Test_acc:", acc)print("-"*30)

在这里插入图片描述

可视化训练过程

# 绘制测试Acc曲线和训练loss曲线
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(train_loss_list,'b-')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')ax1 = ax.twinx()
ax1.plot(test_acc,'r-')
ax1.set_ylabel('Acc')ax1.spines['left'].set_color('blue')
ax1.spines['right'].set_color('red')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/778516.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ros2相关代码记录

1.ros2概述 ROS2(Robot Operating System 2)是一个用于机器人应用程序的开源软件框架。它是ROS(Robot Operating System)的下一代版本,旨在改进和扩展原始ROS的特性,以适应更广泛的机器人应用场景和需求。…

Unity 实现鼠标左键进行射击

发射脚本实现思路 分析 确定用户交互方式:通过鼠标左键点击发射子弹。确定子弹发射逻辑:每次点击后有一定时间间隔才能再次发射。确定子弹发射源和方向:子弹从枪口(Transform)位置发射,沿枪口方向前进。 变…

Qt扫盲-QAssisant 集成其他qch帮助文档

QAssisant 集成其他qch帮助文档 一、概述二、Cmake qch例子1. 下载 Cmake.qch2. 添加qch1. 直接放置于Qt 帮助的目录下2. 在 QAssisant中添加 一、概述 QAssisant是一个很好的帮助文档,他提供了供我们在外部添加新的 qch帮助文档的功能接口,一般有两中添…

八大技术趋势案例(虚拟现实增强现实)

科技巨变,未来已来,八大技术趋势引领数字化时代。信息技术的迅猛发展,深刻改变了我们的生活、工作和生产方式。人工智能、物联网、云计算、大数据、虚拟现实、增强现实、区块链、量子计算等新兴技术在各行各业得到广泛应用,为各个领域带来了新的活力和变革。 为了更好地了解…

QT QInputDialog弹出消息框用法

使用QInputDialog类的静态方法来弹出对话框获取用户输入,缺点是不能自定义按钮的文字,默认为OK和Cancel: int main(int argc, char *argv[]) {QApplication a(argc, argv);bool isOK;QString text QInputDialog::getText(NULL, "Input …

李宏毅【生成式AI导论 2024】第6讲 大型语言模型修炼_第一阶段_ 自我学习累积实力

背景知识:机器怎么学会做文字接龙 详见:https://blog.csdn.net/qq_26557761/article/details/136986922?spm=1001.2014.3001.5501 在语言模型的修炼中,我们需要训练资料来找出数十亿个未知参数,这个过程叫做训练或学习。找到参数后,我们可以使用函数来进行文字接龙,拿…

【数据分析面试】3.编写数据选取函数(Python)

题目 给定了一个名为 students_df 的学生数据表格 nameagefavorite_colorgradeTim Voss19red91Nicole Johnson20yellow95Elsa Williams21green82John James20blue75Catherine Jones23green93 编写一个名为 grades_colors 的函数,以选择仅当学生喜欢的颜色是绿色或…

2024最新Guitar Pro 8.1中文版永久许可证激活

Guitar Pro是一款非常受欢迎的音乐制作软件,它可以帮助用户创建和编辑各种音乐曲谱。从其诞生以来就送专门为了编写吉他谱而研发迭代的。 尽管这款产品可能已经成为全球最受欢迎的吉他打谱软件,在编写吉他六线谱和乐队总谱中始终处于行业领先地位&#x…

ESCTF-密码赛题WP

*小学生的爱情* Base64解码获得flag *中学生的爱情* 社会主义核心价值观在线解码得到flag http://www.atoolbox.net/Tool.php?Id850 *高中生的爱情* U2FsdG开头为rabbit密码,又提示你密钥为love。本地toolfx密码工具箱解密。不知道为什么在线解密不行。 *大学生的爱情* …

jira安装与配置

1. 环境准备 环境要求 1) JDK1.8以上环境配置 2) Mysql数据库5.7.13 3) Jira版本7及破解包 1.1 JDK1.8安装配置 1) 首先下载 JDK1.8, - 网址:https://www.oracle.com/cn/java/technologies/javase/javase-jdk8-downloads.html - windows64 版&am…

机器学习优化算法(深度学习)

目录 预备知识 梯度 Hessian 矩阵(海森矩阵,或者黑塞矩阵) 拉格朗日中值定理 柯西中值定理 泰勒公式 黑塞矩阵(Hessian矩阵) Jacobi 矩阵 优化方法 梯度下降法(Gradient Descent) 随机…

Pytorch的hook函数

hook函数是勾子函数,用于在不改变原始模型结构的情况下,注入一些新的代码用于调试和检验模型,常见的用法有保留非叶子结点的梯度数据(Pytorch的非叶子节点的梯度数据在计算完毕之后就会被删除,访问的时候会显示为None&…

STM32CubeMX学习笔记28---FreeRTOS软件定时器

一、软件定时器简介 1 、基本概念 定时器,是指从指定的时刻开始,经过一个指定时间,然后触发一个超时事件,用户 可以自定义定时器的周期与频率。类似生活中的闹钟,我们可以设置闹钟每天什么时候响, 还能设置…

Unity | 工具类-UV滚动

一、内置渲染管线Shader Shader"Custom/ImageRoll" {Properties {_MainTex ("Main Tex", 2D) "white" {}_Width ("Width", float) 0.5_Distance ("Distance", float) 0}SubShader {Tags {"Queue""Trans…

2024.3.28学习笔记

今日学习韩顺平java0200_韩顺平Java_对象机制练习_哔哩哔哩_bilibili 今日学习p286-p294 继承 继承可以解决代码复用,让我们的编程更加靠近人类思维,当多个类存在相同的属性和方法时,可以从这些类中抽象出父类,在父类中定义这些…

Day24|回溯算法part01:理论基础、77. 组合

理论基础 回溯法,一般可以解决如下几种问题: 组合问题:N个数里面按一定规则找出k个数的集合切割问题:一个字符串按一定规则有几种切割方式子集问题:一个N个数的集合里有多少符合条件的子集排列问题:N个数…

如何通过vscode连接到wsl

下载wsl扩展 远程连接模式

go的通信Channel

go的通道channel是用于协程之间数据通信的一种方式 一、channel的结构 go源码:GitHub - golang/go: The Go programming language src/runtime/chan.go type hchan struct {qcount uint // total data in the queue 队列中当前元素计数,…

专题二_滑动窗口(2)

目录 1658. 将 x 减到 0 的最小操作数 解析 题解 904. 水果成篮 解析 题解 1658. 将 x 减到 0 的最小操作数 1658. 将 x 减到 0 的最小操作数 - 力扣&#xff08;LeetCode&#xff09; 解析 题解 class Solution { public:int minOperations(vector<int>& num…

MPDataDoc类介绍

MPDataDoc类介绍 使用mp数据库新接口mp_api.client.MPRester获取数据&#xff0c;例子如下&#xff1a; from mp_api.client import MPResterwith MPRester(API_KEY) as mpr:docs mpr.summary.search(material_ids["mp-1176451", "mp-561113"])以上代码返…