nlp培训重点

SGD梯度下降公式:

\theta _{t+1}=\theta _{t} - lr * \frac{\partial f}{\partial \theta _{t}}

当梯度大于0时,\theta _{t}变小,往左边找梯度接近0的值。

当梯度小于0时,\theta _{t}减去一个负数会变大,往右边找梯度接近0的值,此时梯度从负数到0上升

#coding:utf8import torch
import torch.nn as nn
import numpy as np
import copy"""
基于pytorch的网络编写
手动实现梯度计算和反向传播
加入激活函数
"""class TorchModel(nn.Module):def __init__(self, hidden_size):super(TorchModel, self).__init__()self.layer = nn.Linear(hidden_size, hidden_size, bias=False) #w = hidden_size * hidden_size  wx+b -> wxself.activation = torch.sigmoidself.loss = nn.functional.mse_loss  #loss采用均方差损失#当输入真实标签,返回loss值;无真实标签,返回预测值def forward(self, x, y=None):y_pred = self.layer(x)y_pred = self.activation(y_pred)if y is not None:return self.loss(y_pred, y)else:return y_pred#自定义模型,接受一个参数矩阵作为入参
class DiyModel:def __init__(self, weight):self.weight = weightdef forward(self, x, y=None):x = np.dot(x, self.weight.T)y_pred = self.diy_sigmoid(x)if y is not None:return self.diy_mse_loss(y_pred, y)else:return y_pred#sigmoiddef diy_sigmoid(self, x):return 1 / (1 + np.exp(-x))#手动实现mse,均方差lossdef diy_mse_loss(self, y_pred, y_true):return np.sum(np.square(y_pred - y_true)) / len(y_pred)#手动实现梯度计算def calculate_grad(self, y_pred, y_true, x):#前向过程# wx = np.dot(self.weight, x)# sigmoid_wx = self.diy_sigmoid(wx)# loss = self.diy_mse_loss(sigmoid_wx, y_true)#反向过程# 均方差函数 (y_pred - y_true) ^ 2 / n 的导数 = 2 * (y_pred - y_true) / n , 结果为2维向量grad_mse = 2/len(x) * (y_pred - y_true)# sigmoid函数 y = 1/(1+e^(-x)) 的导数 = y * (1 - y), 结果为2维向量grad_sigmoid = y_pred * (1 - y_pred)# wx矩阵运算,见ppt拆解, wx = [w11*x0 + w21*x1, w12*x0 + w22*x1]#导数链式相乘grad_w11 = grad_mse[0] * grad_sigmoid[0] * x[0]grad_w12 = grad_mse[1] * grad_sigmoid[1] * x[0]grad_w21 = grad_mse[0] * grad_sigmoid[0] * x[1]grad_w22 = grad_mse[1] * grad_sigmoid[1] * x[1]grad = np.array([[grad_w11, grad_w12],[grad_w21, grad_w22]])#由于pytorch存储做了转置,输出时也做转置处理return grad.T#梯度更新
def diy_sgd(grad, weight, learning_rate):return weight - learning_rate * grad#adam梯度更新
def diy_adam(grad, weight):#参数应当放在外面,此处为保持后方代码整洁简单实现一步alpha = 1e-3  #学习率beta1 = 0.9   #超参数beta2 = 0.999 #超参数eps = 1e-8    #超参数t = 0         #初始化mt = 0        #初始化vt = 0        #初始化#开始计算t = t + 1gt = gradmt = beta1 * mt + (1 - beta1) * gtvt = beta2 * vt + (1 - beta2) * gt ** 2mth = mt / (1 - beta1 ** t)vth = vt / (1 - beta2 ** t)weight = weight - (alpha * mth/ (np.sqrt(vth) + eps))return weightx = np.array([-0.5, 0.1])  #输入
y = np.array([0.1, 0.2])  #预期输出#torch实验
torch_model = TorchModel(2)
torch_model_w = torch_model.state_dict()["layer.weight"]
print(torch_model_w, "初始化权重")
numpy_model_w = copy.deepcopy(torch_model_w.numpy())
#numpy array -> torch tensor, unsqueeze的目的是增加一个batchsize维度
torch_x = torch.from_numpy(x).float().unsqueeze(0) 
torch_y = torch.from_numpy(y).float().unsqueeze(0)
#torch的前向计算过程,得到loss
torch_loss = torch_model(torch_x, torch_y)
print("torch模型计算loss:", torch_loss)
# #手动实现loss计算
diy_model = DiyModel(numpy_model_w)
diy_loss = diy_model.forward(x, y)
print("diy模型计算loss:", diy_loss)# # #设定优化器
learning_rate = 0.1
# optimizer = torch.optim.SGD(torch_model.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(torch_model.parameters())
# optimizer.zero_grad()
# #
# # #pytorch的反向传播操作
torch_loss.backward()
print(torch_model.layer.weight.grad, "torch 计算梯度")  #查看某层权重的梯度# # #手动实现反向传播
grad = diy_model.calculate_grad(diy_model.forward(x), y, x)
print(grad, "diy 计算梯度")
# #
# #torch梯度更新
optimizer.step()
# # #查看更新后权重
update_torch_model_w = torch_model.state_dict()["layer.weight"]
print(update_torch_model_w, "torch更新后权重")
# #
# # #手动梯度更新
# diy_update_w = diy_sgd(grad, numpy_model_w, learning_rate)
diy_update_w = diy_adam(grad, numpy_model_w)
print(diy_update_w, "diy更新权重")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62766.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt5语法的connect指定多个重载信号槽函数中的具体某一个

Qt5新语法的connect函数,使用起来更加简洁明了,但如果信号槽有同名的多个重载函数,只用类名和函数名就无法绑定,这时,可以使用qOverload来指定参数类型,例如: connect(ui->comboBox, qOverlo…

如何在Spark中使用gbdt模型分布式预测

这目录 1 训练gbdt模型2 第三方包python环境打包3 Spark中使用gbdt模型3.1 spark配置文件3.2 主函数main.py 4 spark任务提交 1 训练gbdt模型 我们可以基于lightgbm快速的训练一个gbdt模型,训练相对比较简单,只要把训练样本处理好,几行代码可…

38 基于单片机的宠物喂食(ESP8266、红外、电机)

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于STC89C52单片机,采用L298N驱动连接P2.3和P2.4口进行电机驱动, 然后串口连接P3.0和P3.1模拟ESP8266, 红外传感器连接ADC0832数模转换器连接单片机的P1.0~P1.…

Python 【图像分类】之 PyTorch 进行猫狗分类功能的实现(Swanlab训练可视化/ Gradio 实现猫狗分类 Demo)

Python 【图像分类】之 PyTorch 进行猫狗分类功能的实现(Swanlab训练可视化/ Gradio 实现猫狗分类 Demo) 目录 Python 【图像分类】之 PyTorch 进行猫狗分类功能的实现(Swanlab训练可视化/ Gradio 实现猫狗分类 Demo) 一、简单介绍 二、PyTorch 三、CNN 1、神经网络 2、卷…

HTML5系列(7)-- Web Storage 实战指南

前端技术探索系列:HTML5 Web Storage 实战指南 🗄️ 致读者:本地存储的新纪元 👋 前端开发者们, 今天我们将深入探讨 HTML5 中的 Web Storage 技术,这是一个强大的本地存储解决方案,让我们能…

week 6 - SQL Select II

Overview 1. Joins 包括交叉连接(Cross)、内连接(Inner)、自然连接(Natural)、外连接(Outer) 2. ORDER BY to produce ordered output 3. 聚合函数(Aggregate Functio…

算法训练营day23(二叉树09:修建二叉搜索树,有序数组转化为平衡二叉搜索树,二叉搜索树转化为累加树,二叉树专题总结)

第六章 二叉树part09今日内容:● 669. 修剪二叉搜索树 ● 108.将有序数组转换为二叉搜索树 ● 538.把二叉搜索树转换为累加树 ● 总结篇 详细布置 669. 修剪二叉搜索树 这道题目比较难,比 添加增加和删除节点难的多,建议先看视频理解。题目…

C语言操作符深度解析

目录 一、操作符的分类 1、算术操作符 1、1、 和- 1、2、* 1、3、/ 1、4、% 2、赋值操作符:和复合赋值 2、1、连续赋值 2、2、复合赋值符 3、单⽬操作符:、--、、- 3、1、和-- 3、1、1、前置 3、1、2、后置 3、2、1、前置-- 3、2、2、后…

打造高质量技术文档的关键要素(结合MATLAB)

在技术的浩瀚海洋中,一份优秀的技术文档宛如精准的航海图。它不仅是知识传承的载体,也是团队协作的桥梁,更是产品成功的幕后英雄。打造出色的技术文档并非易事,以下将从多个方向探讨如何做到这一点。 文章目录 方向一:…

《C++与人工智能:照亮能源可持续发展之路》

在全球对能源需求持续攀升以及对可持续发展日益重视的当下,如何有效解决能源领域的复杂问题成为了亟待攻克的关键挑战。而 C与人工智能技术的融合,正犹如一盏明灯,为能源管理、可再生能源预测等方面开辟出全新的路径,有力地推动着…

Python 深度学习框架之Keras库详解

文章目录 Python 深度学习框架之Keras库详解一、引言二、Keras的特点和优势1、用户友好2、多网络支持3、跨平台运行 三、Keras的安装和环境配置1、软硬件环境2、Python虚拟环境 四、使用示例1、MNIST手写数字识别 五、总结 Python 深度学习框架之Keras库详解 一、引言 Keras是…

【大语言模型】ACL2024论文-23 检索增强的多语言知识编辑

【大语言模型】ACL2024论文-23 检索增强的多语言知识编辑 目录 文章目录 【大语言模型】ACL2024论文-23 检索增强的多语言知识编辑目录摘要研究背景问题与挑战如何解决核心创新点算法模型实验效果(包含重要数据与结论)相关工作后续优化方向 后记 检索增强…

android user版本默认usb模式为充电模式

android插入usb时会切换至默认设置的模式,debug版本为adb,user版本为mtp protected long getChargingFunctions() {// if ADB is enabled, reset functions to ADB// else enable MTP as usual.if (isAdbEnabled()) {return UsbManager.FUNCTION_ADB;} e…

_C#_串口助手_字符串拼接缺失问题(未知原理)

最近使用WPF开发串口助手时,遇到一个很奇怪的问题,无论是主线程、异步还是多线程,当串口接收速度达到0.016s一次以上,就会发生字符串缺失问题并且很卡。而0.016s就一切如常,仿佛0.015s与0.016s是天堑之隔。 同一份代码…

CF Round988 题解报告

/***实力还是要努力 D 赛时我过了&#xff0c;就不讲了&#xff0c;毕竟我过的也大概是简单题&#xff1b; 代码&#xff1a; #include<iostream> #include<queue> using namespace std; #define int long long int t; int n,m,l; struct hurdle{int l,r,len; …

基于Python的猎聘网招聘数据采集与可视化分析

1.1项目简介 在现代社会&#xff0c;招聘市场的竞争日趋激烈&#xff0c;企业和求职者都希望能够更有效地找到合适的机会与人才。猎聘网作为国内领先的人力资源服务平台&#xff0c;汇聚了大量的招聘信息和求职者数据&#xff0c;为研究招聘市场趋势提供了丰富的素材。基于Pyt…

HTML 常用标签属性汇总一<input> 标签

1.可选的属性 DTD 指示此属性允许在哪种 DTD 中使用.SStrict, TTransitional, FFrameset。 属性 值 描述 DTD accept mime_type 规定通过文件上传来提交的文件的类型。 STF align ​​​​​​​ left center right middle bottom 不赞成使用。规定图像输入的对齐方式…

基于Java Springboot高校社团微信小程序

一、作品包含 源码数据库设计文档万字PPT全套环境和工具资源部署教程 二、项目技术 前端技术&#xff1a;Html、Css、Js、Vue、Element-ui 数据库&#xff1a;MySQL 后端技术&#xff1a;Java、Spring Boot、MyBatis 三、运行环境 开发工具&#xff1a;IDEA/eclipse 微信…

springboot(20)(删除文章分类。获取、更新、删除文章详细)(Validation分组校验)

目录 一、删除文章分类功能。 &#xff08;1&#xff09;接口文档。 1、请求路径、请求参数。 2、请求参数。 3、响应数据。 &#xff08;2&#xff09;实现思路与代码书写。 1、controller层。 2、service接口业务层。 3、serviceImpl实现类。 4、mapper层。 5、后端接口测试。…

vim 显示行数和删除内容操作

在 Vim 中&#xff0c;显示行数和删除内容是两个常见的操作&#xff0c;结合使用可以帮助你更加高效地编辑文件。以下是关于如何在 Vim 中显示行数和删除内容的详细说明&#xff1a; 1. 显示行数 显示绝对行号 绝对行号会显示每一行的实际行号&#xff0c;适合你查看文件的大…