sglang 大模型推理框架支持的EAGLE 1,2,3

文章目录

      • EAGLE 系列模型的演进与核心机制
      • 关键参数与训练逻辑
      • 思考

参考来源:https://docs.sglang.com.cn/backend/speculative_decoding.html
https://github.com/SafeAILab/EAGLE
EAGLE3 https://arxiv.org/pdf/2503.01840

EAGLE 系列模型的演进与核心机制

EAGLE 基础架构
草稿模型通过特征序列和 token 序列预测下一个特征向量,基于原始 LLM 的最后一个隐藏状态生成候选。采样后的 token 与原始序列以树状结构扩展,分支因子由speculative_eagle_topk控制,确保上下文连贯性。扩展后的树结构重新作为输入迭代生成。

EAGLE-2 的优化
引入动态分支评估机制,草稿模型主动评估扩展分支的可能性,提前终止低概率分支的扩展。扩展阶段结束后,通过重排序筛选前speculative_num_draft_tokens个节点作为最终草稿 token,减少冗余计算。

--speculative-token-map参数设置为true以启用高频 token 优化功能。该参数通常在模型推理或训练配置文件中进行设置。

EAGLE-3 的改进
移除特征预测目标,整合低层与中间层特征提升表示能力。采用 on-policy 训练方式,使模型在推理阶段的行为与训练目标更一致,进一步优化生成质量与效率。

关键参数与训练逻辑

  • speculative_eagle_topk:控制每步扩展的分支数量,影响生成多样性与计算开销。
  • speculative_num_draft_tokens:决定保留的候选 token 数量,平衡生成速度与准确性。
  • On-policy 训练:通过对齐训练与推理阶段的策略,减少分布偏移问题。

  • https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py

核心代码部分

def_prepare_decoder_attention_mask(self,attention_mask,input_shape,inputs_embeds,past_key_values_length):# create causal mask# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]combined_attention_mask=Noneifinput_shape[-1]>1:combined_attention_mask=_make_causal_mask(input_shape,inputs_embeds.dtype,device=inputs_embeds.device,past_key_values_length=past_key_values_length,)ifattention_maskisnotNone:# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]expanded_attn_mask=_expand_mask(attention_mask,inputs_embeds.dtype,tgt_len=input_shape[-1]).to(inputs_embeds.device)combined_attention_mask=(expanded_attn_maskifcombined_attention_maskisNoneelseexpanded_attn_mask+combined_attention_mask)returncombined_attention_mask@torch.no_grad()defdataprepare(self,input_ids,attention_mask,loss_mask):device=input_ids.device outs=self.target_model(input_ids=input_ids,attention_mask=attention_mask)hidden_states0=outs.hidden_states[0]hidden_states1=outs.hidden_states[1]hidden_states2=outs.hidden_states[2]hidden_states=torch.cat((hidden_states0,hidden_states1,hidden_states2),dim=-1)# hidden_states=torch.cat((hidden_states0,hidden_states1),dim=-1)target=outs.logits target=padding(target,left=False)input_ids=padding(input_ids,left=False)iftargetisnotNone:target=target.to(device)loss_mask=loss_mask[...,None]loss_mask=loss_mask.to(device)returnhidden_states,target,loss_mask,input_idsdefforward(self,# hidden_states,input_ids,attention_mask:Optional[torch.Tensor]=None,position_ids:Optional[torch.LongTensor]=None,past_key_values:Optional[List[torch.FloatTensor]]=None,use_cache:Optional[bool]=None,output_attentions:Optional[bool]=None,output_hidden_states:Optional[bool]=None,loss_mask:Optional[torch.Tensor]=None,):hidden_states,target,loss_mask,input_ids=self.dataprepare(input_ids,attention_mask,loss_mask)batch_size,seq_length,_=hidden_states.shape seq_length_with_past=seq_length past_key_values_length=0# with torch.no_grad():# inputs_embeds = self.embed_tokens(input_ids)# inputs_embeds = inputs_embeds.detach()ifself.trainingandself.gradient_checkpointingandnothidden_states.requires_grad:hidden_states.requires_grad=Truehidden_states=self.fc(hidden_states)ifpast_key_valuesisnotNone:past_key_values_length=past_key_values[0][0].shape[2]seq_length_with_past=seq_length_with_past+past_key_values_lengthifposition_idsisNone:device=hidden_states.device position_ids=torch.arange(past_key_values_length,seq_length+past_key_values_length,dtype=torch.long,device=device)position_ids=position_ids.unsqueeze(0).view(-1,seq_length)else:position_ids=position_ids.view(-1,seq_length).long()ifattention_maskisNone:attention_mask=torch.ones((batch_size,seq_length_with_past),dtype=torch.bool,device=hidden_states.device)attention_mask=self._prepare_decoder_attention_mask(attention_mask,(batch_size,seq_length),hidden_states,past_key_values_length)ifself.gradient_checkpointingandself.training:ifuse_cache:use_cache=Falseplosses=[]vlosses=[]acces=[]cache_hidden=[[],[]]foridxinrange(self.length):last=idx==self.length-1inputs_embeds=self.embed_tokens(input_ids)ifself.trainingandself.gradient_checkpointingandnotinputs_embeds.requires_grad:inputs_embeds.requires_grad=Trueinputs_embeds=inputs_embeds.to(hidden_states.dtype)ifself.gradient_checkpointingandself.training:defcreate_custom_forward(module):defcustom_forward(*inputs):# None for past_key_valuereturnmodule(*inputs,None,output_attentions)returncustom_forward layer_outputs,cache_hidden=torch.utils.checkpoint.checkpoint(create_custom_forward(self.midlayer),inputs_embeds,hidden_states,cache_hidden,attention_mask,position_ids,)else:layer_outputs,cache_hidden=self.midlayer(input_emb=inputs_embeds,hidden_states=hidden_states,cache_hidden=cache_hidden,attention_mask=attention_mask,position_ids=position_ids,past_key_value=None,output_attentions=output_attentions,use_cache=True,)hidden_states_out=layer_outputs[0]# cache_hidden.append(layer_outputs[1])# kv_cahce = layer_outputs[-1]withtorch.no_grad():# hidden_states_target = padding(hidden_states, left=False)target_head=target target_max_token=target_head.argmax(-1)# Move d2t to the same device as target_max_tokenself.t2d=self.t2d.to(target_max_token.device)target_mask=self.t2d[target_max_token]target_mask=target_mask[...,None].int()position_mask=target_mask*loss_mask target_head=target_head[...,self.t2d]target_head=target_head.float()target_p=nn.Softmax(dim=2)(target_head)target_p=target_p.detach()hidden_states=hidden_states_out hidden_states_out=self.norm(hidden_states_out)logits=self.lm_head(hidden_states_out)logits=logits.float()out_logp=nn.LogSoftmax(dim=2)(logits)plogp=target_p*out_logp loss=-torch.sum(position_mask*plogp,2).mean()plosses.append(loss)withtorch.no_grad():acces.append(((logits.argmax(-1)==target_p.argmax(-1))*position_mask.squeeze(-1)).sum().item()/(loss_mask.sum().item()+1e-6))ifnotlast:input_ids=padding(input_ids,left=False)target=padding(target,left=False)loss_mask=padding(loss_mask,left=False)returnplosses,vlosses,acces

思考

》 FASTMTP与EAGLE3相比,谁更快一些?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1014903.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

延凡科技 EMS 智慧云平台:3 万起订阅,中小用能单位的 “云端全能源管家”

延凡科技 EMS 智慧云平台是专为工厂、园区、楼宇、医院等中小用能单位打造的云原生能源管理解决方案,聚焦 “降本节能、碳排合规、云端运维、数据驱动” 核心目标,采用 SaaS 订阅模式,整合物联网感知、云边协同、AI 能效优化算法,…

拦截器注册InterceptorRegistry 实现讲解

1.核心概念InterceptorRegistry 是 Spring MVC 提供的拦截器注册器,用于配置拦截器的拦截规则。2.主要方法addInterceptor(): 添加拦截器 addPathPatterns(): 指定要拦截的路径 excludePathPatterns(): 指定要排除的路径 路径匹配规则 /api/**: 匹配 /api/ 下的所有…

汇编语言全接触-27.工具提示控件

我们将学习工具提示控件:它是什么如何创建和使用.下载例子理论:工具提示是当鼠标在某特定区域上停留时显示的一个矩形窗口.工具提示窗口包含一些编程者想要显示的文本.在这点上,工具提示同状态栏的作用是一样的,所不同的是工具提示当单击或者远离指定区域的时候就会消逝,你可能…

汇编语言全接触-26.启动画面

上一章我们学习了位图的使用.在这一章我们要用上帝赋予我们的创造力来融会贯通上一章我们学到的知识.那就是研究如何用位图来创建启动画面. 你可以在这里下载示范: the example. 理论首先,我们先要搞清楚什么是启动画面.举个简单的例子:我们启动某些作的专业一点的程序时(比如N…

验证IP地址(一)

我们先来看题目描述:给定两个 没有重复元素 的数组 nums1 和 nums2 ,其中nums1 是 nums2 的子集。找到 nums1 中每个元素在 nums2 中的下一个比其大的值。nums1 中数字 x 的下一个更大元素是指 x 在 nums2 中对应位置的右边的第一个比 x 大的元素。如果不…

医院管理|基于springboot 医院管理系统(源码+数据库+文档)

医院管理 目录 基于springboot vue医院管理系统 一、前言 二、系统功能演示 三、技术选型 四、其他项目参考 五、代码参考 六、测试参考 七、最新计算机毕设选题推荐 八、源码获取: 基于springboot vue医院管理系统 一、前言 博主介绍:✌️大…

浅谈:算法中的斐波那契数(一)

我们先来看题目描述:斐波那契数,通常用 F(n) 表示,形成的序列称为斐波那契数列。该数列由 0 和 1 开始,后面的每一项数字都是前面两项数字的和。也就是:F(0) 0, F(1) 1 F(N) F(N - 1) F(N - 2), 其中 N > 1.给…

测试的“元认知”:智能体如何评估自身可靠性?

在软件测试领域,自动化与智能化正以前所未有的速度重塑工作流程。随着人工智能代理(智能体)广泛应用于测试用例生成、缺陷预测和持续集成,一个关键问题浮出水面:这些智能体如何像人类测试专家一样,对自身行…

10.8 总结

10.8 总结 作业回顾 1.1 索引练习节选 s hello 1 world 2 hello 3 Python # 获取s的长度 print(len(s)) # 30 # 获取第4个字符 print(s[3]) # l # 获取最后一个字符 print(s[-1]) # n # 获取第7个字符 print(s[6]) # 1 # 获取倒数第7个字符 print(s[-7]) # 空格【不显…

【Hadoop+Spark+python毕设】物联网网络安全威胁数据分析系统、计算机毕业设计、包括数据爬取、数据分析、数据可视化、Hadoop、实战教学

🎓 作者:计算机毕设小月哥 | 软件开发专家 🖥️ 简介:8年计算机软件程序开发经验。精通Java、Python、微信小程序、安卓、大数据、PHP、.NET|C#、Golang等技术栈。 🛠️ 专业服务 🛠️ 需求定制化开发源码提…

9.28总结

9.28总结 知识回顾 # 1. 封装一个函数:获取指定数据的阶乘 【没有指定数据的话默认求10的阶乘】 默认参数 # 阶乘 比如5!5*4*3*2*1 # 未知数据 有1个 # 是否需要返回结果 def factorial(num10):result 1for i in range(num, 0, -1):result * ireturn…

零基础学JAVA--Day34(Map接口+HashTable+HashMap+TreeSet+TreeMap+开发中如何选择集合实现类?(重要)) - 指南

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

电影院购票|基于springboot 电影院购票系统(源码+数据库+文档)

电影院购票 目录 基于springboot vue电影院购票系统 一、前言 二、系统功能演示 三、技术选型 四、其他项目参考 五、代码参考 六、测试参考 七、最新计算机毕设选题推荐 八、源码获取: 基于springboot vue电影院购票系统 一、前言 博主介绍&#xff1a…

C#+VisionMaster联合开发(二)_操作流程

1、获取方案中的流程列表 // 加载流程列表 ProcessInfoList processInfoList = VmSolutionMain.GetAllProcedureList(); if (processInfoList.nNum > 0) {var processNames = processInfoList.astProcessInfo.ToLis…

本地部署DeepSeek

ollama终端的方式部署参考:ollama本地部署 智谱API Key获取 LM Studio 它是模型的托管平台,可以把模型加载后,作为服务器向外提供服务器,本身也具有简单的对话框可以聊天。 :https://lmstudio.ai/ 在左下角改为开发者…

AI驱动的手动测试变革:赋能而非替代

随着大语言模型和智能自动化技术的飞速发展,软件测试领域正迎来前所未有的变革浪潮。传统手动测试作为软件质量保障的基石,面临着效率提升与价值重塑的双重挑战。 AI时代手动测试的困境与机遇 传统手动测试的局限性 手动测试长期面临着测试覆盖率低、…

航空机票预定系统|基于springboot 航空机票预定系统(源码+数据库+文档)

航空机票预定 目录 基于springboot vue航空机票预定系统 一、前言 二、系统功能演示 ​三、技术选型 四、其他项目参考 五、代码参考 六、测试参考 七、最新计算机毕设选题推荐 八、源码获取: 基于springboot vue航空机票预定系统 一、前言 博主介绍&am…

[Windows] 剪映自动预合成v1.0

[Windows] 剪映自动预合成v1.0 链接:https://pan.xunlei.com/s/VOgRWgF_QfvslGjXSYwZaeDXA1?pwdrd56# 从零散的元素 【进入】预合成状态,一键完成。 配合47kb的【剪映草稿助手】还是不错的。

低代码平台的测试挑战:测试从业者的新战场

随着低代码开发平台在企业数字化转型中的广泛应用,软件测试领域正面临前所未有的范式转变。据Gartner预测,到2025年,70%的新应用将由低代码平台开发,这一趋势正在重新定义测试工程师的角色定位和方法体系。作为测试从业者&#xf…