rust-candle学习笔记10-使用Embedding

参考:about-pytorch

candle-nn提供embedding()初始化Embedding方法:

pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {let embeddings = vb.get_with_hints((in_size, out_size),"weight",crate::Init::Randn {mean: 0.,stdev: 1.,},)?;Ok(Embedding::new(embeddings, out_size))
}

 candle Embedding初体验:

其中Tokenizer和dataset的构造详情参考:rust-candle学习笔记9-使用tokenizers加载qwen3分词,使用分词器处理文本

use candle_nn::{embedding, Embedding, Module, VarBuilder, VarMap};fn main() -> Result<()> {let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;let vocab_size = tokenizer.get_vocab_size(true);let text = read_txt("assets/the-verdict.txt")?;let device = Device::cuda_if_available(0)?;let dataset = TokenDataset::new(text, tokenizer, 32, 16, device.clone())?;let (inputs, targets) = dataset.get_item(0)?;println!(" inputs: {:?}\n", inputs);println!(" targets: {:?}\n", targets);let len = dataset.len();println!("{:?}", len);let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let embedding = embedding(vocab_size, 5, vb)?;let x_embedding = embedding.forward(&inputs)?;let y_embedding = embedding.forward(&targets)?;println!(" x_embedding: {:?}\n", x_embedding);println!("{:?}", x_embedding.to_vec2::<f32>()?);println!(" y_embedding: {:?}\n", y_embedding);println!("{:?}", y_embedding.to_vec2::<f32>()?);Ok(())
}

实现正余弦位置编码:

struct PositionEmbedding {pos_embedding: Tensor,device: Device
}
impl PositionEmbedding {fn new(seq_len: usize, embedding_dim: usize, device: Device) -> Result<Self> {if embedding_dim % 2 != 0 {return Err(Box::new(candle_core::Error::msg("embedding_dim must be even")));}let mut pos_embedding_vec: Vec<f32> = Vec::with_capacity(seq_len * embedding_dim);let w_const: f32 = 10000.0;for t in 0..seq_len {let i_max = embedding_dim / 2;for i in 0..i_max {let denominator = w_const.powf(2.0 * i as f32 / embedding_dim as f32);let pos_sin_i = (t as f32 / denominator).sin();let pos_cos_i = (t as f32 / denominator).cos();pos_embedding_vec.push(pos_sin_i);pos_embedding_vec.push(pos_cos_i);}}let pos_embedding = Tensor::from_vec(pos_embedding_vec, (seq_len, embedding_dim), &device)?;Ok(Self { pos_embedding, device })}
}

测试:

注意:candle 不同维度tensor相加直接用+会报错,要显示的调用广播加,高维tensor和低维tensor谁加谁都可以

fn main() -> Result<()> {let tokenizer = Tokenizer::from_file("assets/qwen3/tokenizer.json")?;let vocab_size = tokenizer.get_vocab_size(true);let text = read_txt("assets/the-verdict.txt")?;let device = Device::cuda_if_available(0)?;let seq_len = 32;let dataset = TokenDataset::new(text, tokenizer, seq_len, 16, device.clone())?;let batch_size: usize = 6;let mut loader = DataLoader::new(dataset, batch_size, true);loader.reset();let (x, y) = loader.next().unwrap()?;let varmap = VarMap::new();let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &device);let embedding_dim: usize = 256;let embedding = embedding(vocab_size, embedding_dim, vb)?;let x_embedding = embedding.forward(&x)?;let y_embedding = embedding.forward(&y)?;println!(" x_embedding: {:?}\n", x_embedding);println!(" y_embedding: {:?}\n", y_embedding);let pos_embedding = PositionEmbedding::new(seq_len, embedding_dim, device.clone())?;let pos_emb = pos_embedding.pos_embedding;// candle 不同维度tensor相加直接用+会报错,// 广播加要显示的调用// 下面两种方式都可以let x_input = x_embedding.broadcast_add(&pos_emb)?;// let x_input = pos_emb.broadcast_add(&x_embedding)?;println!(" x_input: {:?}\n", x_input);Ok(())
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/80412.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python小酷库系列:Munch,用对象的访问方式访问dict

Munch&#xff0c;用对象的访问方式访问dict 基本使用1、创建一个 Munch 对象2、使用字典初始化3、访问不存在的字段4、嵌套结构支持5、合并操作6、应用场景说明 进阶功能1、嵌套写入&#xff1a;创建不存在的子对象2、序列化&#xff08;转回 dict&#xff09;3、深度拷贝结构…

对称加密以及非对称加密

对称加密和非对称加密是两种不同的加密方式&#xff0c;它们在加密原理、密钥管理、安全性和性能等方面存在区别&#xff0c;以下是具体分析&#xff1a; 加密原理 对称加密&#xff1a;通信双方使用同一把密钥进行加密和解密。就像两个人共用一把钥匙&#xff0c;用这把钥匙锁…

[JAVAEE]HTTP协议(2.0)

响应报文格式 响应报文格式由首行&#xff0c;响应头&#xff08;header&#xff09;&#xff0c;空行&#xff0c;正文&#xff08;body&#xff09; 组成 响应报文首行包括 1.版本号 如HTTP/1.1 2.状态码(如200) 描述了请求的结果 3.状态码描述(如OK) 首行——状态码…

Spring Boot 之MCP Server开发全介绍

Spring AI 的 MCP(模型上下文协议,Model Context Protocol)服务器启动器为在 Spring Boot 应用程序中设置 MCP 服务器提供了自动配置功能。它使得 MCP 服务器功能能够与 Spring Boot 的自动配置系统实现无缝集成。 MCP 服务器启动器具备以下特性: MCP 服务器组件的自动配置…

YOLOv8 对象检测任务的标注、训练和部署过程

YOLOv8 对象检测任务的标注、训练和部署过程 在计算机视觉领域&#xff0c;对象检测是一项基础且重要的任务&#xff0c;YOLOv8 作为当前先进的实时对象检测模型&#xff0c;以其高效性和准确性受到广泛关注。从数据准备到最终模型部署&#xff0c;整个流程包含多个关键环节&a…

电池热管理CFD解决方案,为新能源汽车筑安全防线

在全球能源结构加速转型的大背景下&#xff0c;新能源汽车产业异军突起&#xff0c;成为可持续发展的重要驱动力。而作为新能源汽车 “心脏” 的电池系统&#xff0c;其热管理技术的优劣&#xff0c;直接决定了车辆的安全性、续航里程和使用寿命。电池在充放电过程中会产生大量…

Redis 数据类型:掌握 NoSQL 的基石

Redis (Remote Dictionary Server) 是一种开源的、内存中的数据结构存储系统&#xff0c;通常用作数据库、缓存和消息代理。 它的高性能和丰富的数据类型使其成为现代应用程序开发中不可或缺的一部分。 本文将深入探讨 Redis 的核心数据类型&#xff0c;帮助你更好地理解和利用…

MLX-Audio:高效音频合成的新时代利器

MLX-Audio&#xff1a;高效音频合成的新时代利器 现代社会的快节奏生活中&#xff0c;对语音技术的需求越来越高。无论是个性化语音助手&#xff0c;还是内容创作者所需的高效音频生成工具&#xff0c;语音技术都发挥着不可或缺的作用。今天&#xff0c;我们将介绍一个创新的开…

Kafka单机版安装部署

目录 1.1、概述1.2、系统环境1.3、ZooKeeper的作用1.4、部署流程1.4.1、下载安装包1.4.2、解压文件1.4.3、创建日志目录1.4.4、配置Kafka1.4.5、启动Kafka服务1.4.6、启动成功验证 1.5、创建Topic测试1.6、消息生产与消费测试1.6.1、启动生产者1.6.2、启动消费者 1.1、概述 Kaf…

【C++设计模式之Observer观察者模式】

Observer观察者模式 模式定义动机(Motivation)结构(Structure)应用场景一&#xff08;气象站&#xff09;实现步骤1.定义观察者接口2.定义被观察者(主题)接口3.实现具体被观察者对象(气象站)4.实现具体观察者(例如&#xff1a;显示屏)5.main.cpp中使用示例6.输出结果7. 关键点 …

资产月报怎么填?资产月报填报指南

资产月报是企业对固定资产进行定期检查和管理的重要工具&#xff0c;它能够帮助管理者了解资产的使用情况、维护状况和财务状况&#xff0c;从而为资产的优化配置和决策提供依据。填写资产月报时&#xff0c;除了填报内容外&#xff0c;还需要注意格式的规范性和数据的准确性。…

UG471 之 SelectIO 逻辑资源

背景 《ug471》介绍了Xilinx 7 系列 SelectIO 的输入/输出特性及逻辑资源的相关内容。 第 1 章《SelectIO Resources》介绍了输出驱动器和输入接收器的电气特性&#xff0c;并通过大量实例解析了各类标准接口的实现。 第 2 章《SelectIO Logic Resources》介绍了输入输出数据…

C++ 内存泄漏相关

ASAN 参考链接 https://blog.csdn.net/wonengguwozai/article/details/129593186https://www.cnblogs.com/greatsql/p/16256926.htmlhttps://zhuanlan.zhihu.com/p/700505587小demo // leak.c #include <stdio.h> #include <stdlib.h> #include <string.h>…

计算人声录音后电平的大小(dB SPL->dBFS)

计算人声录音后电平的大小 这里笔记记录一下&#xff0c;怎么计算已知大小的声音&#xff0c;经过麦克风、声卡录制后软件内录得的音量电平值。&#xff08;文章最后将计算过程整理为Python代码&#xff0c;方便复用&#xff09; 假设用正常说话的声音大小65dB&#xff08;SP…

【MySQL数据库】C/C++连接数据库

MySQL要想在C/C下使用&#xff0c;就必须要有 MySQL 提供的头文件和相关的库。 在Ubuntu系统上&#xff0c;使用 apt install mysql-server 安装MySQL服务器后&#xff0c;仅安装了MySQL数据库服务本身&#xff0c;并没有安装MySQL开发所需的库和头文件。因此&#xff0c;在尝试…

Kubernetes调度策略深度解析:NodeSelector与NodeAffinity的正确打开方式

在Kubernetes集群管理中&#xff0c;如何精准控制Pod的落点&#xff1f;本文将深入解析两大核心调度策略的差异&#xff0c;并通过生产案例教你做出正确选择。 一、基础概念快速理解 1.1 NodeSelector&#xff08;节点选择器&#xff09; 核心机制&#xff1a;通过标签硬匹配…

Golang的linux运行环境的安装与配置

很多新手在学go时&#xff0c;linux下的配置环境一头雾水&#xff0c;总结下&#xff0c;可供参考&#xff01; --------------------------------------Golang的运行环境的安装与配置-------------------------------------- 将压缩包放在/home/tools/下 解压 tar -zxvf g…

自定义实现elementui的锚点

背景 前不久有个需求&#xff0c;上半部分是el-step步骤条&#xff0c;下半部分是一些文字说明&#xff0c;需要实现点击步骤条中某个步骤自定义定位到对应部分的文字说明&#xff0c;同时滚动内容区域的时候还要自动选中对应区域的步骤。element-ui-plus的有锚点这个组件&…

Oracle Fusion常用表

模块表名表描述字段说明sodoo_headers_all销售订单头表sodoo_lines_all销售订单行表sodoo_fulfill_lines_all销售订单明细行表popo_headers_all采购订单头表popo_lines_all采购订单行表popo_line_locations_all采购订单分配表popo_distributions_all采购订单发运表invEGP_SYSTE…

面试常问系列(一)-神经网络参数初始化-之-softmax

背景 本文内容还是对之前关于面试题transformer的一个延伸&#xff0c;详细讲解一下softmax 面试常问系列(二)-神经网络参数初始化之自注意力机制-CSDN博客 Softmax函数的梯度特性与输入值的幅度密切相关&#xff0c;这是Transformer中自注意力机制需要缩放点积结果的关键原…