rust-candle学习笔记11-实现一个简单的自注意力

参考:about-pytorch

定义ScaledDotProductAttention结构体:

use candle_core::{Result, Device, Tensor};
use candle_nn::{Linear, Module, linear_no_bias, VarMap, VarBuilder, ops};struct ScaledDotProductAttention {wq: Linear,wk: Linear,wv: Linear,d_model: Tensor,device: Device,
}

为ScaledDotProductAttention结构体实现new方法:

impl ScaledDotProductAttention {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { wq: linear_no_bias(embedding_dim, out_dim, vb.pp("wq"))?, wk: linear_no_bias(embedding_dim, out_dim, vb.pp("wk"))?, wv: linear_no_bias(embedding_dim, out_dim, vb.pp("wv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

为结构体实现Module的forward trait:

impl Module for ScaledDotProductAttention {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let q = self.wq.forward(xs)?;let k = self.wk.forward(xs)?;let v = self.wv.forward(xs)?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

融合qkv实现:

定义ScaledDotProductAttentionFusedQKV结构体:

struct ScaledDotProductAttentionFusedQKV {w_qkv: Linear,d_model: Tensor,device: Device,
}

为结构体实现new方法:

impl ScaledDotProductAttentionFusedQKV {fn new(vb: VarBuilder, embedding_dim: usize, out_dim: usize, device: Device) -> Result<Self> {Ok(Self { w_qkv: linear_no_bias(embedding_dim, 3*out_dim, vb.pp("w_qkv"))?,d_model: Tensor::new(embedding_dim as f32, &device)?,device,})}
}

为结构体实现forward trait:

impl Module for ScaledDotProductAttentionFusedQKV {fn forward(&self, xs: &Tensor) -> Result<Tensor> {let qkv = self.w_qkv.forward(xs)?;let (batch_size, seq_len, _) = qkv.dims3()?;let qkv = qkv.reshape((batch_size, seq_len, 3, ()))?;let q = qkv.get_on_dim(2, 0)?;let q = q.reshape((batch_size, seq_len, ()))?;let k = qkv.get_on_dim(2, 1)?;let k = k.reshape((batch_size, seq_len, ()))?;let v = qkv.get_on_dim(2, 2)?;let v = v.reshape((batch_size, seq_len, ()))?;let attn_score = q.matmul(&k.t()?)?;let attn_score = attn_score.broadcast_div(&self.d_model.sqrt()?)?;let dim = attn_score.rank() - 1;let attn_weights = ops::softmax(&attn_score, dim)?;let attn_output = attn_weights.matmul(&v)?;Ok(attn_output)}
}

测试:

fn main() -> Result<()> {let device = Device::cuda_if_available(0)?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let input = Tensor::from_vec(vec![0.43f32, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55, 0.43, 0.15, 0.89, 0.55, 0.87, 0.66,0.57, 0.85, 0.64,0.22, 0.58, 0.33,0.77, 0.25, 0.10,0.05, 0.80, 0.55], (2, 6, 3), &device)?;// let model = ScaledDotProductAttention::new(vb.clone(), 3, 2, device.clone())?;let model = ScaledDotProductAttentionFusedQKV::new(vb.clone(), 3, 2, device.clone())?;let output = model.forward(&input)?;println!("output: {:?}\n", output);println!("output: {:?}\n", output.to_vec3::<f32>()?);Ok(())
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/80724.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spark MySQL数据库配置

Spark 连接 MySQL 数据库的配置 要让 Spark 与 MySQL 数据库实现连接&#xff0c;需要进行以下配置步骤。下面为你提供详细的操作指南和示例代码&#xff1a; 1. 添加 MySQL JDBC 驱动依赖 你得把 MySQL 的 JDBC 驱动添加到 Spark 的类路径中。可以通过以下两种方式来完成&a…

web 自动化之 KDT 关键字驱动详解

一、什么是关键字驱动&#xff1f; 1、什么是关键字驱动&#xff1f;&#xff08;以关键字函数驱动测试&#xff09; 关键字驱动又叫动作字驱动&#xff0c;把项目业务封装成关键字函数&#xff0c;再基于关键字函数实现自动化测试 2、关键字驱动测试原理 关键字驱动测试是一…

Java使用POI+反射灵活的控制字段导出Excel

前端传入哪些字段&#xff0c;后端就导出哪些到Excel表格中&#xff0c;具体代码实现如下 controller /*** 用户导出* param dto*/PostMapping("/exportUser")public void exportCharterOrder(RequestBody UserExportDTO dto){userService.exportUser(dto);} serv…

Qt/C++面试【速通笔记八】—Qt的事件处理机制

在Qt中&#xff0c;事件处理机制是应用程序与用户或系统交互的核心。通过事件处理&#xff0c;Qt能够响应用户的输入、窗口的变化、定时器的触发等各种情况。 1. 事件循环&#xff08;Event Loop&#xff09; 在Qt应用程序中&#xff0c;事件循环是事件处理机制的基础。事件循…

TTL (Time-To-Live) 解析

文章目录 TTL (Time-To-Live) 解析&#xff1a;网络与Java中的应用一、TTL的定义二、TTL在网络中的应用1. **路由和数据包的生命周期**2. **DNS中的TTL**3. **防止环路** 三、TTL在Java中的应用1. **缓存管理**2. **Java中的ThreadLocal**3. **网络通信中的TTL** 四、TTL的注意…

HDFS的客户端操作(2)文件上传

我们向/maven下上传一个文件。 要用到的api是put (或者copyFormLocalFile&#xff09;。核心代码如下。 public void testCopyFromLocalFile() throws IOException, InterruptedException, URISyntaxException {// 1 获取文件系统Configuration configuration new Configurati…

光谱相机的光电信号转换

光谱相机的光电信号转换是将分光后的光学信息转化为可处理的数字信号的核心环节&#xff0c;具体分为以下关键步骤&#xff1a; 一、分光后光信号接收与光电转换 ‌分光元件作用‌ 光栅/棱镜/滤光片等分光元件将入射光分解为不同波长单色光&#xff0c;投射至探测器阵列表面…

网络协议分析 实验二 IP分片与IPv6

文章目录 索引及重要内容实验2 IP 高级实验实验2.1 IPv4协议分片实验实验2.2 IPV6协议实验2.3 ARP初级 索引及重要内容 实验2 IP 高级实验 实验2.1 IPv4协议分片实验 icmp的不可达报文 实验2.2 IPV6协议 实验2.3 ARP初级 arp –a 查看ARP缓存表内容 arp –s IP地址(格式&…

20、map和set、unordered_map、un_ordered_set的复现

一、map 1、了解 map的使用和常考面试题等等&#xff0c;看这篇文章 map的key是有序的 &#xff0c;值不可重复 。插入使用 insert的效率更高&#xff0c;而在"更新map的键值对时&#xff0c;使用 [ ]运算符效率更高 。" 注意 map 的lower和upper那2个函数&#x…

基于 Amazon Bedrock 和 Amazon Connect 打造智能客服自助服务 – 设计篇

随着 GenAI 技术不断的发展和演进&#xff0c;人工智能技术广泛地被应用在呼叫中心服务领域&#xff0c;主要包括虚拟坐席&#xff08;即自助服务&#xff09;、坐席助手和呼叫中心运营的数据洞察和智能分析。本博客主要针对自助服务应用场景的实现。 1. 传统自助服务系统瓶颈 …

java高效实现爬虫

一、前言 在Web爬虫技术中&#xff0c;Selenium作为一款强大的浏览器自动化工具&#xff0c;能够模拟真实用户操作&#xff0c;有效应对JavaScript渲染、Ajax加载等复杂场景。而集成代理服务则能够解决IP限制、地域访问限制等问题。本文将详细介绍如何利用JavaSelenium快代理实…

【计算机视觉】OpenCV实战项目:基于OpenCV的车牌识别系统深度解析

基于OpenCV的车牌识别系统深度解析 1. 项目概述2. 技术原理与算法设计2.1 图像预处理1) 自适应光照补偿2) 边缘增强 2.2 车牌定位1) 颜色空间筛选2) 形态学操作3) 轮廓分析 2.3 字符分割1) 投影分析2) 连通域筛选 2.4 字符识别 3. 实战部署指南3.1 环境配置3.2 项目代码解析 4.…

Python核心数据类型全解析:字符串、列表、元组、字典与集合

导读&#xff1a; Python 是一门功能强大且灵活的编程语言&#xff0c;而其核心数据类型是构建高效程序的基础。本文深入剖析了 Python 的五大核心数据类型——字符串、列表、元组、字典和集合&#xff0c;结合实际应用场景与最佳实践&#xff0c;帮助读者全面掌握这些数据类型…

GPT-4.1和GPT-4.1-mini系列模型支持微调功能,助力企业级智能应用深度契合业务需求

微软继不久前发布GPT-4.1系列模型后&#xff0c;Azure OpenAI服务&#xff08;国际版&#xff09;现已正式开放对GPT-4.1和GPT-4.1-mini的微调功能&#xff0c;并通过Azure AI Foundry&#xff08;国际版&#xff09;提供完整的部署和管理解决方案。这一重大升级标志着企业级AI…

构造+简单树状

昨日的牛客周赛算是比较简单的&#xff0c;其中最后一道构造题目属实眼前一亮。 倒数第二个题目也是一个很好的模拟题目&#xff08;考验对二叉树的理解和代码的细节&#xff09; 给定每一层的节点个数&#xff0c;自己拟定一个父亲节点&#xff0c;构造一个满足条件的二叉树。…

apache2的默认html修改

使用127.0.0.1的时候&#xff0c;默认打开的是index.html&#xff0c;可以通过配置文件修改成我们想要的html vi /etc/apache2/mods-enabled/dir.conf <IfModule mod_dir.c>DirectoryIndex WS.html index.html index.cgi index.pl index.php index.xhtml index.htm <…

mysql性能提升方法大汇总

前言 最近在开发自己的小程序的时候&#xff0c;由于业务功能对系统性能的要求很高&#xff0c;系统性能损耗又主要在mysql上&#xff0c;而业务功能的数据表很多&#xff0c;单表数据量也很大&#xff0c;又涉及到很多场景的数据查询&#xff0c;所以我针对mysql调用做了优化…

多模态RAG与LlamaIndex——1.deepresearch调研

摘要 关键点&#xff1a; 多模态RAG技术通过结合文本、图像、表格和视频等多种数据类型&#xff0c;扩展了传统RAG&#xff08;检索增强生成&#xff09;的功能。LlamaIndex是一个开源框架&#xff0c;支持多模态RAG&#xff0c;提供处理文本和图像的模型、嵌入和索引功能。研…

LabVIEW中算法开发的系统化解决方案与优化

在 LabVIEW 开发环境中&#xff0c;算法实现是连接硬件数据采集与上层应用的核心环节。由于图形化编程范式与传统文本语言存在差异&#xff0c;LabVIEW 中的算法开发需要特别关注执行效率、代码可维护性以及与硬件资源的适配性。本文从算法架构设计、性能优化到工程实现&#x…

OpenCV中的光流估计方法详解

文章目录 一、引言二、核心算法原理1. 光流法基本概念2. 算法实现步骤 三、代码实现详解1. 初始化设置2. 特征点检测3. 光流计算与轨迹绘制 四、实际应用效果五、优化方向六、结语 一、引言 在计算机视觉领域&#xff0c;运动目标跟踪是一个重要的研究方向&#xff0c;广泛应用…