Lag-Llama时间序列模型简单实现数据预测

前言:

最近在大模型预测,简单了解了lag-llama开源项目,网上也有很多讲解原理的,这里就将如何快速上手使用说一下,只懂得一点点皮毛,有错误的地方欢迎大佬指出。

简单介绍:

Lag-Llama 是一个开源的时间序列预测模型,基于 Transformer 架构设计,专注于利用 滞后特征(Lagged Features) 捕捉时间序列的长期依赖关系。其核心思想是将传统时间序列分析中的滞后算子(Lags)与现代深度学习结合,实现对复杂时序模式的高效建模。

GitHup地址:GitHub - time-series-foundation-models/lag-llama: Lag-Llama: Towards Foundation Models for Probabilistic Time Series Forecasting

相关技术原理:...(搜一下很多文章讲的都非常好)

实现模型预测:

1.下载模型文件

从 HuggingFace下载,如果网络原因访问不了,建议从魔搭社区下载(lag-Llama · 模型库)

2.准备数据集

参考文档:pandas.DataFrame based dataset - GluonTS documentation

以我测试数据举例:

3.完整代码:(需要替换模型文件地址和数据集地址)

from itertools import islicefrom matplotlib import pyplot as plt
import matplotlib.dates as mdatesimport torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_datasetfrom gluonts.dataset.pandas import PandasDataset
import pandas as pdfrom lag_llama.gluon.estimator import LagLlamaEstimatordef get_lag_llama_predictions(dataset, prediction_length, device, num_samples, context_length=32, use_rope_scaling=False):# 模型文件地址ckpt = torch.load("/models/lag-Llama/lag-llama.ckpt", map_location=device, weights_only=False)  # Uses GPU since in this Colab we use a GPU.estimator_args = ckpt["hyper_parameters"]["model_kwargs"]rope_scaling_arguments = {"type": "linear","factor": max(1.0, (context_length + prediction_length) / estimator_args["context_length"]),}estimator = LagLlamaEstimator(# 模型文件地址ckpt_path="/models/lag-Llama/lag-llama.ckpt",prediction_length=prediction_length,context_length=context_length,# Lag-Llama was trained with a context length of 32, but can work with any context length# estimator argsinput_size=estimator_args["input_size"],n_layer=estimator_args["n_layer"],n_embd_per_head=estimator_args["n_embd_per_head"],n_head=estimator_args["n_head"],scaling=estimator_args["scaling"],time_feat=estimator_args["time_feat"],rope_scaling=rope_scaling_arguments if use_rope_scaling else None,batch_size=1,num_parallel_samples=100,device=device,)lightning_module = estimator.create_lightning_module()transformation = estimator.create_transformation()predictor = estimator.create_predictor(transformation, lightning_module)forecast_it, ts_it = make_evaluation_predictions(dataset=dataset,predictor=predictor,num_samples=num_samples)forecasts = list(forecast_it)tss = list(ts_it)return forecasts, tssimport pandas as pd
from gluonts.dataset.pandas import PandasDataseturl = ("/lag-llama/history.csv"
)
df = pd.read_csv(url, index_col=0, parse_dates=True)# Set numerical columns as float32
for col in df.columns:# Check if column is not of string typeif df[col].dtype != 'object' and pd.api.types.is_string_dtype(df[col]) == False:df[col] = df[col].astype('float32')# Create the Pandas
dataset = PandasDataset.from_long_dataframe(df, target="target", item_id="item_id")backtest_dataset = dataset
# 预测长度
prediction_length = 24  # Define your prediction length. We use 24 here since the data is of hourly frequency
# 样本数
num_samples = 1  # number of samples sampled from the probability distribution for each timestep
device = torch.device("cuda:1")  # You can switch this to CPU or other GPUs if you'd like, depending on your environmentforecasts, tss = get_lag_llama_predictions(backtest_dataset, prediction_length, device, num_samples)# 提取第一个时间序列的预测结果
forecast = forecasts[0]
print('=================================')
# 概率预测的完整样本(形状: [num_samples, prediction_length])
samples = forecast.samples
print(samples)

关键参数说明:

参数

说明

prediction_length

预测的未来时间步长

context_length

模型输入的历史时间步长(需 >= 季节性周期)

num_samples

概率预测的采样次数(值越大,概率区间越准)

checkpoint_path

预训练模型权重路径(需提前下载)

freq

时间序列频率(如 "H" 小时、"D" 天)

结果:

这里只是给出了简单的代码实现,想要更好的效果还需深入研究!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/899488.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Plastiform复制胶泥:高精度表面复制与测量的高效工具

在工业制造和质量检测领域,表面复制和测量是确保产品质量的关键环节。Plastiform复制胶泥作为一种创新材料,凭借其出色的性能和多样化的应用,为用户提供了可靠的解决方案。它能够快速捕捉复杂表面的细节,确保测量结果的准确性&…

AI大模型、机器学习以及AI Agent开源社区和博客

以下梳理了适合学习 AI大模型、机器学习、AI Agent和多模态技术 的英文网站、社区、官网和博客,按类别分类整理: 一、官方网站与开源平台 1. AI大模型 (Large Language Models) • OpenAI • 官网: openai.com • 内容: GPT系列模型文档、研究论文、AP…

python 上下文管理器with

with 上下文管理器 上下文管理器示例如下:若想不使用with关键字 上下文管理器 任何实现了 enter() 和 exit() 方法的对象都可称之为上下文管理器,上下文管理器对象可以使用 with 关键字。 必须同时具有__enter__和__exit__,就可以使用with语句…

买卖股票的最佳时机(121)

121. 买卖股票的最佳时机 - 力扣&#xff08;LeetCode&#xff09; 解法&#xff1a; class Solution { public:int maxProfit(vector<int>& prices) {int cur_min prices[0];int max_profit 0;for (int i 1; i < prices.size(); i) {if (prices[i] > cur…

CesiumJS 本地数据瓦片加载南北两极出现圆点问题

const imageryProvider new UrlTemplateImageryProvider({url: "/gisimg/{z}/{x}/{reverseY}.png",minimumLevel: 0,maximumLevel: 19})上面这段代码是加载本地切片&#xff0c;但是有个致命问题就是会出现南北两极显示蓝色圆点 解决方案&#xff1a; 加上这句话&am…

Linux编译器gcc/g++使用完全指南:从编译原理到动静态链接

一、gcc/g基础认知 在Linux开发环境中&#xff0c;gcc和g是我们最常用的编译器工具&#xff1a; gcc&#xff1a;GNU C Compiler&#xff0c;专门用于编译C语言程序g&#xff1a;GNU C Compiler&#xff0c;用于编译C程序&#xff08;也可编译C语言&#xff09; &#x1f4cc…

Vue学习笔记集--computed

computed 在 Vue 3 的 Composition API 中&#xff0c;computed 用于定义响应式计算属性 它的核心特性是自动追踪依赖、缓存计算结果&#xff08;依赖未变化时不会重新计算&#xff09; 基本用法 1. 定义只读计算属性 import { ref, computed } from vue;const count ref(…

飞致云荣获“Alibaba Cloud Linux最佳AI镜像服务商”称号

2025年3月24日&#xff0c;阿里云云市场联合龙蜥社区发布“2024年度Alibaba Cloud Linux最佳AI镜像服务商”评选结果。 经过主办方的严格考量&#xff0c;飞致云&#xff08;即杭州飞致云信息科技有限公司&#xff09;凭借旗下MaxKB开源知识库问答系统、1Panel开源面板、Halo开…

Vue如何利用Postman和Axios制作小米商城购物车----简版

实现功能&#xff1a;全选、单选、购物数量显示、合计价格显示 实现效果如下&#xff1a; 思路&#xff1a; 1.数据要利用写在Postman里面&#xff0c;通过地址来调用Postman里面的数据。 2.写完数据后&#xff0c;给写的数据一个名字&#xff0c;然后加上一个空数组&#xf…

第一篇:系统分析师首篇

目录 一、目标二、计划三、完成情况1.宏观思维导图2.过程中的团队管理和其它方面的思考 四、意外之喜(最少2点)1.计划内的明确认知和思想的提升标志2.计划外的具体事情提升内容和标志 一、目标 通过参加考试&#xff0c;训练学习能力&#xff0c;而非单纯以拿证为目的。 1.在复…

CSS学习笔记4——盒子模型

目录 盒子模型是什么&#xff1f; 盒子模型的组成 一、div标签 二、边框属性 1、border-style:边框样式 2、border-width:边框宽度 3、border-color:边框颜色、border&#xff1a;综合设置 4、border-radius:圆角边框 5、border-image&#xff1a;图像边框 三、边距属性…

复现文献中的三维重建图像生成,包括训练、推理和可视化

要复现《One - 2 - 3 - 45 Fast Single Image to 3D Objects with Consistent Multi - View Generation and 3D Diffusion (CVPR)2024》文献中的三维重建图像生成&#xff0c;包括训练、推理和可视化&#xff0c;并且确保代码能正常运行&#xff0c;下面是基本的实现步骤和示例…

stable diffusion 本地部署教程 2025最新版

前提&#xff1a; 需要环境 git git下载地址Git - Downloading Package ​ 直接装即可 python3.10.6 下载地址 Python Release Python 3.10.6 | Python.org ​ 记得python环境一定要3.10.6&#xff01;&#xff01;&#xff01; 第一个版本 项目地址https://github.…

【二刷代码随想录】螺旋矩阵求解方法、推荐习题

一、求解方法 &#xff08;1&#xff09;按点模拟路径 在原有坐标的基准上&#xff0c;叠加 横纵坐标 的变化值&#xff0c;求出下一位置&#xff0c;并按题完成要求。但需注意转角的时机判断&#xff0c;特别是最后即将返回上一出发点的位置。 &#xff08;2&#xff09;按层…

从Manus到OpenManus:AI智能体技术如何重塑未来生活场景?

从Manus到OpenManus&#xff1a;AI智能体技术如何重塑未来生活场景&#xff1f; 一、现状&#xff1a;AI智能体技术面临的三大核心矛盾 &#xff08;通过分析用户高频痛点与市场反馈提炼&#xff09; 能力与门槛的失衡 Manus展示的复杂任务处理能力&#xff08;如股票分析、代…

迭代器与可迭代对象

概念层面&#xff1a; 可迭代对象&#xff1a; 一个可迭代对象是指任何可以返回一个迭代器的对象。换句话说&#xff0c;它实现了 __iter__() 方法 比如&#xff1a;列表、元组、字典、字符串、集合等 直接通过 for 循环使用&#xff0c;因为 for 循环内部会调用其 __iter__(…

总结PostgreSQL创建数据库失败的解决办法

作者&#xff1a;朱金灿 来源&#xff1a;clever101的专栏 系统环境是Windows 11 专业版&#xff0c;PostgreSQL版本是17。在运行sql语句创建数据库时出现错误&#xff1a; 閿欒: template database \"template1\" has a collation version mismatch DETAIL: Th…

Mybatis源码 插件机制

简介 插件是一种常见的扩展方式&#xff0c;大多数开源框架也都支持用户通过添加自定义插件的方式来扩展或者改变原有的功能&#xff0c;MyBatis中也提供的有插件&#xff0c;虽然叫插件&#xff0c;但是实际上是通过拦截器(Interceptor)实现的&#xff0c;在MyBatis的插件模块…

Android14 SystemUI中添加第三方AIDL

由于特殊需求&#xff0c;需要在SystemUI中添加第三方AIDL&#xff0c;去做一些客制化的修改。现在记录一下AIDL添加的过程。 1.将AIDL文件拷贝到frameworks/base/packages/SystemUI/src/下&#xff0c;我要添加的AIDL文件是com/test/myctr/IDevicectr.aidl&#xff0c;添加后的…

Binlog、Redo log、Undo log的区别

一、binlog和redo log的区别 特性binlogredo log记录对象记录的是 MySQL 服务器的事务操作&#xff0c;针对的是整个数据库实例。记录的是 InnoDB 存储引擎的数据页变化&#xff0c;针对的是具体的存储引擎层面。记录内容记录的是事务的逻辑操作&#xff0c;例如 SQL 语句&…