基于Transformer的多资产收益预测模型实战(附PyTorch实现与避坑指南)

基于Transformer的多资产收益预测模型实战(附PyTorch模型训练及可视化完整代码)


一、项目背景与目标

在量化投资领域,利用时间序列数据预测资产收益是核心任务之一。传统方法如LSTM难以捕捉资产间的复杂依赖关系,而Transformer架构通过自注意力机制能有效建模多资产间的联动效应。
本文将从零开始构建一个基于PyTorch的多资产收益预测模型,涵盖数据生成、特征工程、模型设计、训练及可视化全流程,适合深度学习与量化投资的初学者入门。

二、核心技术栈

  • 数据处理:Pandas/Numpy(数据生成与预处理)
  • 深度学习框架:PyTorch(模型构建与训练)
  • 可视化:Matplotlib(结果分析)
  • 核心算法:Transformer(自注意力机制)

三、数据生成与预处理

1. 模拟金融数据生成

我们通过以下步骤生成包含5只资产的时间序列数据:

  • 市场基准因子:模拟市场整体趋势(几何布朗运动)
  • 行业因子:引入周期性波动区分不同行业(如科技、消费、能源)
  • 特质因子:每只资产的独立噪声
def generate_market_data(days=2000, n_assets=5):  np.random.seed(42)  market = np.cumprod(1 + np.random.normal(0.0003, 0.015, days))  # 市场基准  assets = []  sector_map = {0: "Tech", 1: "Tech", 2: "Consume", 3: "Consume", 4: "Energy"}  for i in range(n_assets):  sector_factor = 0.3 * np.sin(i * 0.8 + np.linspace(0, 10 * np.pi, days))  # 行业周期因子  idiosyncratic = np.cumprod(1 + np.random.normal(0.0002, 0.02, days))  # 特质因子  price = market * (1 + sector_factor) * idiosyncratic  # 价格合成  assets.append(price)  dates = pd.date_range("2015-01-01", periods=days)  return pd.DataFrame(np.array(assets).T, index=dates, columns=[f"Asset_{i}" for i in range(n_assets)])  

2. 数据形状说明

生成的DataFrame形状为[2000天, 5资产],索引为时间戳,列名为Asset_0到Asset_4。

四、特征工程:从价格到可训练数据

1. 基础时间序列特征

为每只资产计算以下特征:

  • 收益率(Return):相邻日价格变化率
  • 波动率(Volatility):20日滚动标准差年化
  • 移动平均(MA10):10日价格移动平均
  • 行业相对强弱(Sector_RS):资产价格与所属行业平均价格的比值
def create_features(data, lookback=60):  n_assets = data.shape[1]  sector_map = {0: "Tech", 1: "Tech", 2: "Consume", 3: "Consume", 4: "Energy"}  features = []  for i, asset in enumerate(data.columns):  df = pd.DataFrame()  df["Return"] = data[asset].pct_change()  df["Volatility"] = df["Return"].rolling(20).std() * np.sqrt(252)  # 年化波动率  df["MA10"] = data[asset].rolling(10).mean()  # 计算行业相对强弱  sector = sector_map[i]  sector_cols = [col for col in data.columns if sector_map[int(col.split("_")[1])] == sector]  df["Sector_RS"] = data[asset] / data[sector_cols].mean(axis=1)  features.append(df.dropna())  # 去除NaN  # 对齐时间索引  common_idx = features[0].index  for df in features[1:]:  common_idx = common_idx.intersection(df.index)  features = [df.loc[common_idx] for df in features]  # 构建3D特征张量 [样本数, 时间步, 资产数, 特征数]  X = np.stack([np.stack([feat.iloc[i-lookback:i] for i in range(lookback, len(feat))], axis=0) for feat in features], axis=2)  # 标签:未来5日平均收益率  y = np.array([data.loc[common_idx].iloc[i:i+5].pct_change().mean().values for i in range(lookback, len(common_idx))])  return X, y  

2. 输入输出形状

  • 特征张量X形状:[样本数, 时间步(60), 资产数(5), 特征数(4)]
  • 标签y形状:[样本数, 资产数(5)](每个样本对应5只资产的未来5日平均收益率)

五、Transformer模型构建:核心架构解析

1. 模型设计目标

  • 处理多资产时间序列:同时输入5只资产的历史数据
  • 捕捉时间依赖资产间依赖:通过位置编码和自注意力机制
  • 输出多资产收益预测:回归问题,使用MSE损失

2. 关键组件解析

(1)资产嵌入层(Asset Embedding)

将每个资产的4维特征映射到64维隐空间:

self.asset_embed = nn.Linear(n_features=4, d_model=64)  

输入形状:(batch, seq_len, assets, features) → 输出:(batch, seq_len, assets, d_model)

(2)位置编码(Positional Embedding)

由于Transformer无内置时序信息,需手动添加位置编码:

self.time_pos = nn.Parameter(torch.randn(1, lookback=60, 1, d_model=64))  # 时间位置编码  
self.asset_pos = nn.Parameter(torch.randn(1, 1, n_assets=5, d_model=64))  # 资产位置编码  
  • 通过广播机制与资产嵌入相加,分别捕获时间和资产维度的位置信息。
(3)自定义Transformer编码器层(Custom Transformer Encoder Layer)

继承PyTorch原生层,返回注意力权重以可视化:

class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):  def __init__(self, d_model, nhead, dim_feedforward=256, dropout=0.1):  super(

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/80837.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

养生:打造健康生活的全方位策略

在生活节奏不断加快的当下,养生已成为提升生活质量、维护身心平衡的重要方式。从饮食、运动到睡眠,再到心态调节,各个方面的养生之道共同构建起健康生活的坚实基础。以下为您详细介绍养生的关键要点,助您拥抱健康生活。 饮食养生…

轻型汽车鼓式液压制动器系统设计

一、设计基础参数 1.1 整车匹配参数 参数项数值范围整备质量1200-1500kg最大设计车速160km/h轮胎规格195/65 R15制动法规要求GB 12676-2014 1.2 制动性能指标 制动减速度:≥6.2m/s(0型试验) 热衰退率:≤30%(连续10…

无法更新Google Chrome的解决问题

解决问题:原文链接:【百分百成功】Window 10 Google Chrome无法启动更新检查(错误代码为1:0x80004005) google谷歌chrome浏览器无法更新Chrome无法更新至最新版本? 下载了 就是更新Google Chrome了

【AAAI 2025】 Local Conditional Controlling for Text-to-Image Diffusion Models

Local Conditional Controlling for Text-to-Image Diffusion Models(文本到图像扩散模型的局部条件控制) 文章目录 内容摘要关键词作者及研究团队项目主页01 研究领域待解决问题02 论文解决的核心问题03 关键解决方案04 主要贡献05 相关研究工作06 解决…

Kuka AI音乐AI音乐开发「人声伴奏分离」 —— 「Kuka Api系列|中文咬字清晰|AI音乐API」第6篇

导读 今天我们来了解一下 Kuka API 的人声与伴奏分离功能。 所谓“人声伴奏分离”,顾名思义,就是将一段完整的音频拆分为两个独立的轨道:一个是人声部分,另一个是伴奏(乐器)部分。 这个功能在音乐创作和…

Idea 设置编码UTF-8 Idea中 .properties 配置文件中文乱码

Idea 设置编码UTF-8 Idea中 .properties 配置文件中文乱码 一、设置编码 1、步骤: File -> Setting -> Editor -> File encodings --> 设置编码二、配置文件中文乱码 1、步骤: File -> Setting -> Editor -> File encodings ->…

Xilinx FPGA PCIe | XDMA IP 核 / 应用 / 测试 / 实践

注:本文为 “Xilinx FPGA 中 PCIe 技术与 XDMA IP 核的应用” 相关文章合辑。 图片清晰度受引文原图所限。 略作重排,未整理去重。 如有内容异常,请看原文。 FPGA(基于 Xilinx)中 PCIe 介绍以及 IP 核 XDMA 的使用 N…

sqli—labs第六关——双引号报错注入

一:判断输入类型 首先测试 ?id1,?id1,?id1",页面回显均无变化 所以我们采用简单的布尔测试,分别测试数字型,单引号,双引号 然后发现,只有在测试到双引号注入的时候符合关键…

【TroubleShoot】禁用Unity Render Graph API 兼容模式

使用Unity 6时新建了项目,有一个警告提示: The project currently uses the compatibility mode where the Render Graph API is disabled. Support for this mode will be removed in future Unity versions. Migrate existing ScriptableRenderPasses…

图形学、人机交互、VR/AR、可视化等领域文献速读【持续更新中...】

(1)笔者在时间有限的情况下,想要多积累一些自身课题之外的新文献、新知识,所以开了这一篇文章。 (2)想通过将文献喂给大模型,并向大模型提问的方式来快速理解文献的重要信息(如基础i…

Hadoop-HDFS-Packet含义及作用

在 HDFS(Hadoop Distributed File System)中,Packet 是数据读写过程中用于数据传输的基本单位。它是 HDFS 客户端与数据节点(DataNode)之间进行数据交互时的核心概念,尤其在写入和读取文件时,Pa…

显示的图标跟UI界面对应不上。

图片跟UI界面不符合。 要找到对应dp的值。UI的dp要跟代码里的xml文件里的dp要对应起来。 蓝湖里设置一个宽度给对应上。然后把对应的值填入xml. 一个屏幕上的图片到底是用topmarin来设置,还是用bottommarin来设置。 因为第一节,5,7 车厢的…

【taro3 + vue3 + webpack4】在微信小程序中的请求封装及使用

前言 正在写一个 以taro3 vue3 webpack4为基础框架的微信小程序,之前一直没有记咋写的,现在总结记录一下。uniapp vite 的后面出。 文章目录 前言一、创建环境配置文件二、 配置 Taro 环境变量三、 创建请求封装四、如何上传到微信小程序体验版1.第二…

LeetCode:513、找树左下角的值

//递归法 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeNode right) {* t…

采用均线策略来跟踪和投资基金

策略来源#睿思量化#小程序 截图来源#睿思量化#小程序 在基金投资中,趋势跟踪策略是一种备受关注的交易方法。本文将基于两张关于广发电子信息传媒股票 A(代码:005310)的图片资料,详细阐述这一策略的应用与效果。 从第…

leetcode刷题---二分查找

力扣题目链接 二分查找算法使用前提&#xff1a;有序数组&#xff1b;数组内无重复元素 易错点&#xff1a; 1.while循环的边界条件&#xff1a;如到底是 while(left < right) 还是 while(left < right) 2.if条件后right&#xff0c;left的取值&#xff1a;到底是 right …

(leetcode) 力扣100 10.和为K的子数组(前缀和+哈希)

题目 给你一个整数数组 nums 和一个整数 k &#xff0c;请你统计并返回 该数组中和为 k 的子数组的个数 。 子数组是数组中元素的连续非空序列。 数据范围 1 < nums.length < 2 * 104 -1000 < nums[i] < 1000 -107 < k < 107 样例 示例 1&#xff1a; 输…

遨游卫星电话与普通手机有什么区别?

在数字化浪潮席卷全球的今天&#xff0c;通信设备的角色早已超越传统语音工具&#xff0c;成为连接物理世界与数字世界的核心枢纽。然而&#xff0c;当普通手机在都市丛林中游刃有余时&#xff0c;面对偏远地区、危险作业场景的应急通信需求&#xff0c;其局限性便显露无遗。遨…

在Linux中如何使用Kill(),向进程发送发送信号

kill()函数 #include <sys/types.h> #include <signal.h> int kill(pid_t pid, int sig); 函数参数和返回值含义如下: pid:参数 pid 为正数的情况下,用于指定接收此信号的进程 pid;除此之外,参数 pid 也可设置为 0 或-1 以及小于-1 等不同值,稍后给说明。 …

Java SpringMVC 和 MyBatis 整合关键配置详解

目录 一、数据源配置二、MyBatis 工厂配置三、Mapper 扫描配置四、SpringMVC 配置五、整合示例实体类Mapper 接口Mapper XML 文件Service 类控制器JSP 页面六、总结在 Java Web 开发中,SpringMVC 和 MyBatis 是两个常用框架。SpringMVC 负责 Web 层的请求处理和视图渲染,MyBa…