TensorFlow实现逻辑回归模型

逻辑回归是一种经典的分类算法,广泛应用于二分类问题。本文将介绍如何使用TensorFlow框架实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来直观地观察模型的训练过程。

数据准备

首先,我们准备两类数据点,分别表示两个不同的类别。这些数据点将作为模型的输入特征。

# 1.散点输入
class1_points=np.array([[1.9,1.2],[1.5,2.1],[1.9,0.5],[1.5,0.9],[0.9,1.2],[1.1,1.7],[1.4,1.1]])
class2_points=np.array([[3.2,3.2],[3.7,2.9],[3.2,2.6],[1.7,3.3],[3.4,2.6],[4.1,2.3],[3.0,2.9]])

将两类数据点合并为一个矩阵,并为每个数据点分配相应的标签(0或1)。

#不用单独提取出x1_data 和x2_data
#框架会根据输入特征数自动提取
x_train=np.concatenate((class1_points,class2_points),axis=0)
y_train=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class2_points))))

将数据转换为TensorFlow张量,以便在模型中使用。

import tensorflow as tfx_train_tensor = tf.convert_to_tensor(x_train, dtype=tf.float32)
y_train_tensor = tf.convert_to_tensor(y_train, dtype=tf.float32)

模型定义

使用TensorFlow的tf.keras模块定义逻辑回归模型。模型包含一个输入层和一个输出层,输出层使用sigmoid激活函数。

def LogisticRegreModel():input = tf.keras.Input(shape=(2,))fc = tf.keras.layers.Dense(1, activation='sigmoid')(input)lr_model = tf.keras.models.Model(inputs=input, outputs=fc)return lr_modelmodel = LogisticRegreModel()

定义优化器和损失函数。这里使用随机梯度下降优化器和二元交叉熵损失函数。

opt = tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(optimizer=opt, loss="binary_crossentropy")

训练过程

训练模型时,我们记录每个epoch的损失值,并动态绘制决策边界和损失曲线。

 

import matplotlib.pyplot as pltfig, (ax1, ax2) = plt.subplots(1, 2)epochs = 500
epoch_list = []
epoch_loss = []for epoch in range(1, epochs + 1):y_pre = model.fit(x_train_tensor, y_train_tensor, epochs=50, verbose=0)epoch_loss.append(y_pre.history["loss"][0])epoch_list.append(epoch)w1, w2 = model.get_weights()[0].flatten()b = model.get_weights()[1][0]slope = -w1 / w2intercept = -b / w2x_min, x_max = 0, 5x = np.array([x_min, x_max])y = slope * x + interceptax1.clear()ax1.plot(x, y, 'r')ax1.scatter(x_train[:len(class1_points), 0], x_train[:len(class1_points), 1])ax1.scatter(x_train[len(class1_points):, 0], x_train[len(class1_points):, 1])ax2.clear()ax2.plot(epoch_list, epoch_loss, 'b')plt.pause(1)

结果展示

训练完成后,决策边界图将显示模型如何将两类数据分开,损失曲线图将显示模型在训练过程中的损失值变化。生成结果基本如图所示:

通过动态绘制决策边界和损失曲线,我们可以直观地观察模型的训练过程,了解模型如何逐渐学习数据的分布并优化决策边界。

总结

本文介绍了如何使用TensorFlow实现逻辑回归模型,并通过动态绘制决策边界和损失曲线来观察模型的训练过程。逻辑回归是一种简单而有效的分类算法,适用于二分类问题。通过TensorFlow框架,我们可以轻松地实现和训练逻辑回归模型,并利用其强大的功能来优化模型的性能。


完整代码

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# 1.散点输入
class1_points=np.array([[1.9,1.2],[1.5,2.1],[1.9,0.5],[1.5,0.9],[0.9,1.2],[1.1,1.7],[1.4,1.1]])
class2_points=np.array([[3.2,3.2],[3.7,2.9],[3.2,2.6],[1.7,3.3],[3.4,2.6],[4.1,2.3],[3.0,2.9]])#不用单独提取出x1_data 和x2_data
#框架会根据输入特征数自动提取
x_train=np.concatenate((class1_points,class2_points),axis=0)
y_train=np.concatenate((np.zeros(len(class1_points)),np.ones(len(class2_points))))
#转化为张量
x_train_tensor=tf.convert_to_tensor(x_train,dtype=tf.float32)
y_train_tensor=tf.convert_to_tensor(y_train,dtype=tf.float32)#2.定义前向模型
# 使用类的方式
# 先设置一下随机数种子
seed=0
tf.random.set_seed(0)def LogisticRegreModel():input=tf.keras.Input(shape=(2,))fc=tf.keras.layers.Dense(1,activation='sigmoid')(input)lr_model=tf.keras.models.Model(inputs=input,outputs=fc)return lr_model
#实例化网络
model=LogisticRegreModel()
#3.定义损失函数和优化器
#定义优化器
#需要输入模型参数和学习率
lr=0.1
opt=tf.keras.optimizers.SGD(learning_rate=0.01)
model.compile(optimizer=opt,loss="binary_crossentropy")# 最后画图
fig,(ax1,ax2)=plt.subplots(1,2)
#训练
epoches=500
epoch_list=[]
epoch_loss=[]
for epoch in range(1,epoches+1):# verbose=0 进度条不显示  epochs迭代次数y_pre=model.fit(x_train_tensor,y_train_tensor,epochs=50,verbose=0)# print(y_pre.history["loss"])epoch_loss.append(y_pre.history["loss"][0])epoch_list.append(epoch)w1,w2=model.get_weights()[0].flatten()b=model.get_weights()[1][0]#画左图# 使用斜率和截距画直线#目前将x2当作y轴 x1当作x轴# w1*x1+w2*x2+b=0#求出斜率和截距slope=-w1/w2intercept=-b/w2#绘制直线 开始结束位置x_min,x_max=0,5x=np.array([x_min,x_max])y=slope*x+interceptax1.clear()ax1.plot(x,y,'r')#画散点图ax1.scatter(x_train[:len(class1_points),0],x_train[:len(class1_points),1])ax1.scatter(x_train[len(class1_points):, 0],x_train[len(class1_points):, 1])#画右图ax2.clear()ax2.plot(epoch_list,epoch_loss,'b')plt.pause(1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/69467.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity git版本管理

创建仓库的时候添加了Unity的.gitignore模版,在这个时候就能自动过滤不需要的文件 打开git bash之后,步骤git版本管理-CSDN博客 如果报错,尝试重新进git 第一次传会耗时较长,之后的更新就很快了

【AI论文】扩散对抗后训练用于一步视频生成总结

摘要:扩散模型被广泛应用于图像和视频生成,但其迭代生成过程缓慢且资源消耗大。尽管现有的蒸馏方法已显示出在图像领域实现一步生成的潜力,但它们仍存在显著的质量退化问题。在本研究中,我们提出了一种在扩散预训练后针对真实数据…

低代码系统-产品架构案例介绍、明道云(十一)

明道云HAP-超级应用平台(Hyper Application Platform),其实就是企业级应用平台,跟微搭类似。 通过自设计底层架构,兼容各种平台,使用低代码做到应用搭建、应用运维。 企业级应用平台最大的特点就是隐藏在冰山下的功能很深&#xf…

实时数据处理与模型推理:利用 Spring AI 实现对数据的推理与分析

在现代企业中,实时数据处理与快速决策已经成为关键需求。通过集成 Spring AI,我们不仅可以高效地获取实时数据,还可以将这些数据输入到 AI 模型中进行推理与分析,以便生成实时的业务洞察。 本文将讲解如何通过 Spring AI 实现实时…

制造企业的成本核算

一、生产成本与制造费用的区别 (1)生产成本,是直接用于产品生产,构成产品实体的材料成本。 包括企业在生产经营过程中实际消耗的原材料、辅助材料、备品备件、外购半成品、燃料、动力包装物以及其它直接材料,和直接参加产品生产的工人工资,以及按生产工人的工资总额和规…

2025年AI手机集中上市,三星Galaxy S25系列上市

2025年被认为是AI手机集中爆发的一年,各大厂商都会推出搭载人工智能的智能手机。三星Galaxy S25系列全球上市了。 三星Galaxy S25系列包含S25、S25和S25 Ultra三款机型,起售价为800美元(约合人民币5800元)。全系搭载骁龙8 Elite芯…

【ESP32】ESP-IDF开发 | WiFi开发 | TCP传输控制协议 + TCP服务器和客户端例程

1. 简介 TCP(Transmission Control Protocol),全称传输控制协议。它的特点有以下几点:面向连接,每一个TCP连接只能是点对点的(一对一);提供可靠交付服务;提供全双工通信&…

2025数学建模美赛|赛题翻译|E题

2025数学建模美赛,E题赛题翻译 更多美赛内容持续更新中...

【Elasticsearch】Elasticsearch的查询

Elasticsearch的查询 DSL查询基础语句叶子查询全文检索查询matchmulti_match 精确查询termrange 复合查询算分函数查询bool查询 排序分页基础分页深度分页 高亮高亮原理实现高亮 RestClient查询基础查询叶子查询复合查询排序和分页高亮 数据聚合DSL实现聚合Bucket聚合带条件聚合…

什么是循环神经网络?

一、概念 循环神经网络(Recurrent Neural Network, RNN)是一类用于处理序列数据的神经网络。与传统的前馈神经网络不同,RNN具有循环连接,可以利用序列数据的时间依赖性。正因如此,RNN在自然语言处理、时间序列预测、语…

零售EDI:Costco EDI 项目须知

Costco 是全球领先的会员制仓储式零售商,致力于为会员提供高品质且价格实惠的商品。其经营范围涵盖食品、电子产品、家居用品、服装和办公设备等多个领域。 Costco 的 EDI 对接需求分析 为了更高效地管理其复杂的全球供应链,Costco 采用了先进的 EDI&am…

Kafka运维宝典 (三)- Kafka 最大连接数超出限制问题、连接超时问题、消费者消费时间超过限制问题详细介绍

Kafka运维宝典 (三) 文章目录 Kafka运维宝典 (三)一、Kafka Broker 配置中的最大连接数超出限制问题1. 错误原因2. 相关 Kafka 配置参数2.1 connections.max2.2 max.connections.per.ip2.3 num.network.threads2.4 connections.ma…

模板泛化类如何卸载释放内存

CustomWidget::~CustomWidget() {for (size_t i 0; i < buttonManager.registerItem.size(); i) {delete buttonManager.registerItem(exitButton);} } 以上该怎么写删除对象操作&#xff0c;类如下&#xff1a;template <typename T> class GenericManager { public…

在Linux系统上安装.NET

测试系统&#xff1a;openKylin(开放麒麟) 1.确定系统和架构信息&#xff1a; 打开终端&#xff08;Ctrl Alt T&#xff09;&#xff0c;输入cat /etc/os-release查看系统版本相关信息。 输入uname -m查看系统架构。确保你的系统和架构符合.NET 的要求&#xff0c;如果架构…

28. 【.NET 8 实战--孢子记账--从单体到微服务】--简易报表--报表定时器与报表数据修正

这篇文章是《.NET 8 实战–孢子记账–从单体到微服务》系列专栏的《单体应用》专栏的最后一片和开发有关的文章。在这片文章中我们一起来实现一个数据统计的功能&#xff1a;报表数据汇总。这个功能为用户查看月度、年度、季度报表提供数据支持。 一、需求 数据统计方面&…

深入探索C++17的std::any:类型擦除与泛型编程的利器

文章目录 基本概念构建方式构造函数直接赋值std::make_anystd::in_place_type 访问值值转换引用转换指针转换 修改器emplaceresetswap 观察器has_valuetype 使用场景动态类型的API设计类型安全的容器简化类型擦除实现 性能考虑动态内存分配类型转换和异常处理 总结 在C17的标准…

物管系统赋能智慧物业管理提升服务质量与工作效率的新风潮

内容概要 在当今的物业管理领域&#xff0c;物管系统的崛起为智慧物业管理带来了新的机遇和挑战。这些先进的系统能够有效整合各类信息&#xff0c;促进数字化管理&#xff0c;从而提升服务质量和工作效率。通过物管系统&#xff0c;物业管理者可以实时查看和分析各种数据&…

代码随想录算法训练营第三十八天-动态规划-完全背包-322. 零钱兑换

太难了 但听了前面再听这道题感觉递推公式也不是不难理解 动规五部曲 dp[j]代表装满容量为j&#xff08;也就是目标值&#xff09;的背包最少物品数量递推公式&#xff1a;dp[j] std::min(dp[j], dp[j - coins[i]] 1)当使用coins[i]这张纸币时&#xff0c;要向前找到容量为…

分组表格antd+ react +ts

import React from "react"; import { Table, Tag } from "antd"; import styles from "./index.less"; import GroupTag from "../Tag"; const GroupTable () > {const columns [{title: "姓名",dataIndex: "nam…

【JAVA实战】如何使用 Apache POI 在 Java 中写入 Excel 文件

大家好&#xff01;&#x1f31f; 在这篇文章中&#xff0c;我们将带你深入学习如何使用 Apache POI 在 Java 中编写 Excel 文件的技巧&#xff01;&#x1f4ca;&#x1f4da; 如果你是 Java 开发者&#xff0c;或者正在探索如何处理 Excel 文件的数据&#xff0c;那么这篇文章…