NLP(3)--利用nn反向计算参数

前言

仅记录学习过程,有问题欢迎讨论

获取数据

自定义一个方程,获取一批数据X,Y

import matplotlib.pyplot as pyplot
import math
import sysX = [0.01 * x for x in range(100)]
Y = [2*x**2 + 3*x + 4 for x in X]
print(X)
print(Y)
pyplot.scatter(X, Y, color='red')
pyplot.show()
input()

利用nn模型计算w

X = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18,0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35000000000000003,0.36, 0.37, 0.38, 0.39, 0.4, 0.41000000000000003, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47000000000000003, 0.48, 0.49,0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.5700000000000001, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66,0.67, 0.68, 0.6900000000000001, 0.7000000000000001, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8,0.81, 0.8200000000000001, 0.8300000000000001, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93,0.9400000000000001, 0.9500000000000001, 0.96, 0.97, 0.98, 0.99]
Y = [4.0, 4.0302, 4.0608, 4.0918, 4.1232, 4.155, 4.1872, 4.2198, 4.2528, 4.2862, 4.32, 4.3542, 4.3888, 4.4238, 4.4592,4.495, 4.5312, 4.5678, 4.6048, 4.6422, 4.68, 4.7181999999999995, 4.7568, 4.7958, 4.8352, 4.875, 4.9152000000000005,4.9558, 4.9968, 5.0382, 5.08, 5.122199999999999, 5.1648, 5.2078, 5.2512, 5.295, 5.3392, 5.3838, 5.4288, 5.4742,5.5200000000000005, 5.5662, 5.6128, 5.6598, 5.7072, 5.755, 5.8032, 5.851800000000001, 5.9008, 5.9502, 6.0, 6.0502,6.1008, 6.1518, 6.203200000000001, 6.255000000000001, 6.3072, 6.3598, 6.4128, 6.4662, 6.52, 6.5742, 6.6288, 6.6838,6.7392, 6.795, 6.8512, 6.9078, 6.9648, 7.022200000000001, 7.08, 7.138199999999999, 7.1968, 7.2558, 7.3152, 7.375,7.4352, 7.4958, 7.5568, 7.6182, 7.680000000000001, 7.7422, 7.8048, 7.867800000000001, 7.9312, 7.994999999999999,8.0592, 8.1238, 8.1888, 8.2542, 8.32, 8.3862, 8.4528, 8.5198, 8.587200000000001, 8.655000000000001, 8.7232, 8.7918,8.8608, 8.9302]# 标签
def func(x):y = w1 * x**2 + w2 * x + w3return y# 损失函数
def loss(y_pre, y_true):return (y_true - y_pre) ** 2# 随机定义 w
w1, w2, w3 = 1, 2, 3
# 学习率
lr = 0.1# 训练过程
for epoch in range(500):epoch_loss = 0for x, y_true in zip(X, Y):y_pre = func(x)# 本轮loss 总值epoch_loss += loss(y_pre, y_true)# 梯度计算grad_w1 = 2 * (y_pre - y_true) * x ** 2grad_w2 = 2 * (y_pre - y_true) * xgrad_w3 = 2 * (y_pre - y_true)# 权重更新w1 = w1 - lr * grad_w1  # sgdw2 = w2 - lr * grad_w2w3 = w3 - lr * grad_w3# 本轮结束 计算平均lossepoch_loss /= len(X)print("第%d轮, loss %f" % (epoch, epoch_loss))if epoch_loss < 0.00001:breakprint(f"训练后权重:w1:{w1} w2:{w2} w3:{w3}")
# #使用训练后模型输出预测值
Yp = [func(i) for i in X]
# 预测值与真实值比对数据分布
pyplot.scatter(X, Y, color="red")
pyplot.scatter(X, Yp)
pyplot.show()

优化梯度、权重计算

X = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18,0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35000000000000003,0.36, 0.37, 0.38, 0.39, 0.4, 0.41000000000000003, 0.42, 0.43, 0.44, 0.45, 0.46, 0.47000000000000003, 0.48, 0.49,0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.5700000000000001, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66,0.67, 0.68, 0.6900000000000001, 0.7000000000000001, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8,0.81, 0.8200000000000001, 0.8300000000000001, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93,0.9400000000000001, 0.9500000000000001, 0.96, 0.97, 0.98, 0.99]
Y = [4.0, 4.0302, 4.0608, 4.0918, 4.1232, 4.155, 4.1872, 4.2198, 4.2528, 4.2862, 4.32, 4.3542, 4.3888, 4.4238, 4.4592,4.495, 4.5312, 4.5678, 4.6048, 4.6422, 4.68, 4.7181999999999995, 4.7568, 4.7958, 4.8352, 4.875, 4.9152000000000005,4.9558, 4.9968, 5.0382, 5.08, 5.122199999999999, 5.1648, 5.2078, 5.2512, 5.295, 5.3392, 5.3838, 5.4288, 5.4742,5.5200000000000005, 5.5662, 5.6128, 5.6598, 5.7072, 5.755, 5.8032, 5.851800000000001, 5.9008, 5.9502, 6.0, 6.0502,6.1008, 6.1518, 6.203200000000001, 6.255000000000001, 6.3072, 6.3598, 6.4128, 6.4662, 6.52, 6.5742, 6.6288, 6.6838,6.7392, 6.795, 6.8512, 6.9078, 6.9648, 7.022200000000001, 7.08, 7.138199999999999, 7.1968, 7.2558, 7.3152, 7.375,7.4352, 7.4958, 7.5568, 7.6182, 7.680000000000001, 7.7422, 7.8048, 7.867800000000001, 7.9312, 7.994999999999999,8.0592, 8.1238, 8.1888, 8.2542, 8.32, 8.3862, 8.4528, 8.5198, 8.587200000000001, 8.655000000000001, 8.7232, 8.7918,8.8608, 8.9302]# 标签
def func(x):y = w1 * x**2 + w2 * x + w3return y# 损失函数
def loss(y_pre, y_true):return (y_true - y_pre) ** 2# 随机定义 w
w1, w2, w3 = 1, 2, 3
# 学习率
lr = 0.1batch_size = 20
# 训练过程
for epoch in range(500):epoch_loss = 0count = 0grad_w1 = 0grad_w2 = 0grad_w3 = 0for x, y_true in zip(X, Y):count += 1y_pre = func(x)# 本轮loss 总值epoch_loss += loss(y_pre, y_true)# 梯度计算grad_w1 += 2 * (y_pre - y_true) * x ** 2grad_w2 += 2 * (y_pre - y_true) * xgrad_w3 += 2 * (y_pre - y_true)# 更新权重if count == batch_size:count = 0# 权重更新w1 = w1 - lr * grad_w1/batch_size  # sgdw2 = w2 - lr * grad_w2/batch_sizew3 = w3 - lr * grad_w3/batch_size# 本轮结束 计算平均lossepoch_loss /= len(X)print("第%d轮, loss %f" % (epoch, epoch_loss))if epoch_loss < 0.00001:breakprint(f"训练后权重:w1:{w1} w2:{w2} w3:{w3}")
# #使用训练后模型输出预测值
Yp = [func(i) for i in X]
# 预测值与真实值比对数据分布
pyplot.scatter(X, Y, color="red")
pyplot.scatter(X, Yp)
pyplot.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/382.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

KV Cache 技术分析

原文&#xff1a;Notion – The all-in-one workspace for your notes, tasks, wikis, and databases. 1 什么是KV Cache LLM&#xff08;大型语言模型&#xff09;中的 Attention 机制中的 KV Cache&#xff08;键值缓存&#xff09;主要作用是存储键值对&#xff0c;以避免在…

ChatGPT又多了一个强有力的竞争对手:Meta发布Llama 3开源模型!附体验地址

大家好&#xff0c;我是木易&#xff0c;一个持续关注AI领域的互联网技术产品经理&#xff0c;国内Top2本科&#xff0c;美国Top10 CS研究生&#xff0c;MBA。我坚信AI是普通人变强的“外挂”&#xff0c;所以创建了“AI信息Gap”这个公众号&#xff0c;专注于分享AI全维度知识…

ChatPDF代码解读2

这段代码定义了一个名为`ChatPDF`的类,它结合了文本相似性模型和生成模型,用于处理和生成基于特定文档语料库的自然语言回答。以下是对代码的详细解读: 1. **导入依赖**:代码开始处导入了多个Python库,包括`argparse`(命令行参数解析)、`hashlib`(哈希函数)、`os`(操…

力扣:两数之和

知识点&#xff1a; 动态数组的创建&#xff1a; #include<stdlib.h> arr (int*)malloc(len * sizeof(int)); 如何使用sacnf输入数组&#xff1a; scanf 函数在读取输入时&#xff0c;会自动跳过空格&#xff08;空格、制表符、换行符等&#xff09;和换行符&#…

vscode绿绿主题setting config

下载插件Green Tree Theme 选greentree ctrl shift p找到setting {"workbench.colorTheme": "Green Tree","editor.fontSize": 16.5, // 字号"workbench.colorCustomizations": {"[Green Tree]": {"activityBarBadge.…

日志事件ID

日志排查时&#xff0c;通常会根据事件ID搜索日志。 1、安全日志 用户登录事件&#xff1a; 4624&#xff1a;登录成功 4625&#xff1a;登录失败 4634&#xff1a;注销本地登录用户 4647&#xff1a;注销远程登录的用户 4648&#xff1a;使用显式凭证尝试登录 4672&am…

用pigeon kotlin swift写一个自己的插件

文章目录 1. 创建一个flutter plugin项目2. 引入依赖3. 创建pigeons文件夹和message.dart4. 执行生成各个平台文件的命令5. base_plugin.dart6. BasePlugin.kt7. BasePlugin.swift8. 遇到的问题9. [源码](https://github.com/githubityu/base_plugin) 1. 创建一个flutter plugi…

算法一:数字 - 两数之和

给定一个整数数组 nums 和一个目标值 target&#xff0c;请你在该数组中找出和为目标值的那 两个 整数&#xff0c;并返回他们的数组下标。 你可以假设每种输入只会对应一个答案。但是&#xff0c;数组中同一个元素不能使用两遍。 来源&#xff1a;力扣(LeetCode) 链接&#xf…

scala---基础核心知识(变量定义,数据类型,流程控制,方法定义,函数定义)

一、什么是scala Scala 是一种多范式的编程语言&#xff0c;其设计初衷是要集成面向对象编程和函数式编程的各种特性。Scala运行于Java平台&#xff08;Java虚拟机&#xff09;&#xff0c;并兼容现有的Java程序。 二、为什么要学习scala 1、优雅 2、速度快 3、能融合到hado…

管道流设计模式结合业务

文章目录 流程图代码实现pomcontextEventContextBizTypeAbstractEventContext filterEventFilterAbstractEventFilterEventFilterChainFilterChainPipelineDefaultEventFilterChain selectorFilterSelectorDefaultFilterSelector 调用代码PipelineApplicationcontrollerentitys…

浅析binance新币OMNI的前世今生

大盘跌跌不休&#xff0c;近期唯一的指望就是binance即将上线的OMNI 。虽然目前查到的空投数量不及预期&#xff0c;但OMNI能首发币安&#xff0c;确实远超预期了。 OMNI代币总量1亿&#xff0c;初始流通仅10,391,492枚&#xff0c;其中币安Lanchpool可挖350万枚 对于OMNI这个…

设计模式——模版模式21

模板方法模式在超类中定义了一个事务流程的框架&#xff0c; 允许子类在不修改结构的情况下重写其中一个或者多个特定步骤。下面以ggbond的校招礼盒发放为例。 设计模式&#xff0c;一定要敲代码理解 模版抽象 /*** author ggbond* date 2024年04月18日 17:32* 发送奖品*/ p…

50.HarmonyOS鸿蒙系统 App(ArkUI)web组件实现简易浏览器

50.HarmonyOS鸿蒙系统 App(ArkUI)web组件实现简易浏览器 配置网络访问权限&#xff1a; 跳转任务&#xff1a; Button(转到).onClick(() > {try {// 点击按钮时&#xff0c;通过loadUrl&#xff0c;跳转到www.example1.comthis.webviewController.loadUrl(this.get_url);} …

代码随想录第39天|62.不同路径 63. 不同路径 II

62.不同路径 62. 不同路径 - 力扣&#xff08;LeetCode&#xff09; 代码随想录 (programmercarl.com) 动态规划中如何初始化很重要&#xff01;| LeetCode&#xff1a;62.不同路径_哔哩哔哩_bilibili 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标…

Codeforces Round 782 (Div. 2) D. Reverse Sort Sum

题目 思路&#xff1a; #include <bits/stdc.h> using namespace std; #define int long long #define pb push_back #define fi first #define se second #define lson p << 1 #define rson p << 1 | 1 const int maxn 1e6 5, inf 1e9, maxm 4e4 5; co…

Rust常见陷阱 | 线程间传递消息导致主线程无法结束

在多线程编程中,线程之间的通信是一个不可或缺的话题。尤其是在Rust语言中,由于其特有的所有权机制,线程通信需要更加仔细地处理。本文将深入讨论使用Rust标准库中的消息通道时如何避免主线程被阻塞的问题,并提供详尽的代码示例来辅助理解。 问题描述 当我们在Rust中使用…

【WP】猿人学4 雪碧图、样式干扰

https://match.yuanrenxue.cn/match/4 探索 首先打开Fiddler&#xff0c;发现每个包的除了page参数一样&#xff0c;然后重放攻击可以实现&#xff0c;尝试py复现 Python可以正常拿到数据&#xff0c;这题不考请求&#xff0c;这题的难点原来在于数据的加密&#xff0c;这些数字…

用10个Kimichat提示词5分钟创建一门在线课程

●研究市场并在生成式AI主题内找到一个特定细分市场&#xff0c;这一市场尚未被现有课程充分覆盖。使用在线研究来收集关于当前可用课程类型的信息&#xff0c;以及市场上存在哪些空白。利用这些信息创建一个填补空白并吸引对“生成式AI”感兴趣的特定受众群体的课程。确定课程…

面试经典算法系列之二叉树17 -- 验证二叉树

面试经典算法32 - 验证二叉树 LeetCode.98 公众号&#xff1a;阿Q技术站 问题描述 给你一个二叉树的根节点 root &#xff0c;判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下&#xff1a; 节点的左子树只包含 小于 当前节点的数。节点的右子树只包含 大于 当…

云解析DNS是什么?

说起云解析&#xff0c;相信很多用户都比较陌生&#xff0c;对于刚刚接触互联网的小白并不了解什么是云解析DNS&#xff0c;下面为您详解一下以上问题。 云解析DNS是什么 云解析 DNS(Domain Name System&#xff0c;简称 DNS) 一种安全、快速、稳定、可靠的权威 DNS 解析管理…