【2025】Datawhale AI春训营-RNA结构预测(AI+创新药)-Task2笔记

【2025】Datawhale AI春训营-RNA结构预测(AI+创新药)-Task2笔记

本文对Task2提供的进阶代码进行理解。

任务描述

Task2的任务仍然是基于给定的RNA三维骨架结构,生成一个或多个RNA序列,使得这些序列能够折叠并尽可能接近给定的目标三维骨架结构。这是一个RNA逆折叠的过程。

将RNA序列折叠成特定三维结构的过程是一个RNA折叠的过程。

在Task2中,继续使用算法进行RNA逆折叠。评估标准是序列的恢复率,即算法生成的RNA序列在多大程度上能与真实能够折叠成目标结构的RNA序列相似。

代码理解

1、导入模块

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import TransformerConv, LayerNorm
from torch_geometric.nn import radius_graph
from Bio import SeqIO
import math

2、配置参数

# 配置参数
class Config:seed = 42device = "cuda" if torch.cuda.is_available() else "cpu"batch_size = 16 if torch.cuda.is_available() else 8  # 根据显存调整lr = 0.001epochs = 50seq_vocab = "AUCG"coord_dims = 7  hidden_dim = 256num_layers = 4  # 减少层数防止显存溢出k_neighbors = 20  dropout = 0.1rbf_dim = 16num_heads = 4amp_enabled = True  # 混合精度训练

3、定义几何生成器

# 几何特征生成器
class GeometricFeatures:@staticmethoddef rbf(D, D_min=0., D_max=20., D_count=16):device = D.deviceD_mu = torch.linspace(D_min, D_max, D_count, device=device)D_mu = D_mu.view(*[1]*len(D.shape), -1)D_sigma = (D_max - D_min) / D_countD_expand = D.unsqueeze(-1)return torch.exp(-((D_expand - D_mu)/D_sigma) ** 2)@staticmethoddef dihedrals(X, eps=1e-7):X = X.to(torch.float32)L = X.shape[0]dX = X[1:] - X[:-1]U = F.normalize(dX, dim=-1)# 计算连续三个向量u_prev = U[:-2]u_curr = U[1:-1]u_next = U[2:]# 计算法向量n_prev = F.normalize(torch.cross(u_prev, u_curr, dim=-1), dim=-1)n_curr = F.normalize(torch.cross(u_curr, u_next, dim=-1), dim=-1)# 计算二面角cosD = (n_prev * n_curr).sum(-1)cosD = torch.clamp(cosD, -1+eps, 1-eps)D = torch.sign((u_prev * n_curr).sum(-1)) * torch.acos(cosD)# 填充处理if D.shape[0] < L:D = F.pad(D, (0,0,0,L-D.shape[0]), "constant", 0)return torch.stack([torch.cos(D[:,:5]), torch.sin(D[:,:5])], -1).view(L,-1)@staticmethoddef direction_feature(X):dX = X[1:] - X[:-1]return F.pad(F.normalize(dX, dim=-1), (0,0,0,1))

4、定义图构建器

# 图构建器
class RNAGraphBuilder:@staticmethoddef build_graph(coord, seq):assert coord.shape[1:] == (7,3), f"坐标维度错误: {coord.shape}"coord = torch.tensor(coord, dtype=torch.float32)# 节点特征node_feats = [coord.view(-1, 7 * 3),  # [L,21]GeometricFeatures.dihedrals(coord[:,:6,:]),  # [L,10]GeometricFeatures.direction_feature(coord[:,4,:])  # [L,3]]x = torch.cat(node_feats, dim=-1)  # [L,34]# 边构建pos = coord[:,4,:]edge_index = radius_graph(pos, r=20.0, max_num_neighbors=Config.k_neighbors)# 边特征row, col = edge_indexedge_vec = pos[row] - pos[col]edge_dist = torch.norm(edge_vec, dim=-1, keepdim=True)edge_feat = torch.cat([GeometricFeatures.rbf(edge_dist).squeeze(1),  # [E,16]F.normalize(edge_vec, dim=-1)  # [E,3]], dim=-1)  # [E,19]# 标签y = torch.tensor([Config.seq_vocab.index(c) for c in seq], dtype=torch.long)return Data(x=x, edge_index=edge_index, edge_attr=edge_feat, y=y)

5、定义模型结构

# 模型架构
class RNAGNN(nn.Module):def __init__(self):super().__init__()# 节点特征编码self.feat_encoder = nn.Sequential(nn.Linear(34, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim),nn.Dropout(Config.dropout))# 边特征编码(关键修复)self.edge_encoder = nn.Sequential(nn.Linear(19, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim),nn.Dropout(Config.dropout))# Transformer卷积层self.convs = nn.ModuleList([TransformerConv(Config.hidden_dim,Config.hidden_dim // Config.num_heads,heads=Config.num_heads,edge_dim=Config.hidden_dim,  # 匹配编码后维度dropout=Config.dropout) for _ in range(Config.num_layers)])# 残差连接self.mlp_skip = nn.ModuleList([nn.Sequential(nn.Linear(Config.hidden_dim, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim)) for _ in range(Config.num_layers)])# 分类头self.cls_head = nn.Sequential(nn.Linear(Config.hidden_dim, Config.hidden_dim),nn.ReLU(),LayerNorm(Config.hidden_dim),nn.Dropout(Config.dropout),nn.Linear(Config.hidden_dim, len(Config.seq_vocab)))self.apply(self._init_weights)def _init_weights(self, module):if isinstance(module, nn.Linear):nn.init.xavier_uniform_(module.weight)if module.bias is not None:nn.init.constant_(module.bias, 0)def forward(self, data):x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr# 边特征编码(关键步骤)edge_attr = self.edge_encoder(edge_attr)  # [E,19] -> [E,256]# 节点编码h = self.feat_encoder(x)# 消息传递for i, (conv, skip) in enumerate(zip(self.convs, self.mlp_skip)):h_res = conv(h, edge_index, edge_attr=edge_attr)h = h + skip(h_res)if i < len(self.convs)-1:h = F.relu(h)h = F.dropout(h, p=Config.dropout, training=self.training)return self.cls_head(h)

6、定义数据增强类

# 数据增强
class CoordTransform:@staticmethoddef random_rotation(coords):device = torch.device(Config.device)coords_tensor = torch.from_numpy(coords).float().to(device)angle = np.random.uniform(0, 2*math.pi)rot_mat = torch.tensor([[math.cos(angle), -math.sin(angle), 0],[math.sin(angle), math.cos(angle), 0],[0, 0, 1]], device=device)return (coords_tensor @ rot_mat.T).cpu().numpy()

7、定义数据集类

# 数据集类
class RNADataset(torch.utils.data.Dataset):def __init__(self, coords_dir, seqs_dir, augment=False):self.samples = []self.augment = augmentfor fname in os.listdir(coords_dir):# 加载坐标coord = np.load(os.path.join(coords_dir, fname))coord = np.nan_to_num(coord, nan=0.0)# 数据增强if self.augment and np.random.rand() > 0.5:coord = CoordTransform.random_rotation(coord)# 加载序列seq_id = os.path.splitext(fname)[0]seq_path = os.path.join(seqs_dir, f"{seq_id}.fasta")seq = str(next(SeqIO.parse(seq_path, "fasta")).seq)# 构建图self.samples.append(RNAGraphBuilder.build_graph(coord, seq))def __len__(self): return len(self.samples)def __getitem__(self, idx): return self.samples[idx]

8、训练函数

# 训练函数
def train(model, loader, optimizer, scheduler, criterion):model.train()scaler = torch.cuda.amp.GradScaler(enabled=Config.amp_enabled)total_loss = 0for batch in loader:batch = batch.to(Config.device)optimizer.zero_grad()with torch.cuda.amp.autocast(enabled=Config.amp_enabled):logits = model(batch)loss = criterion(logits, batch.y)scaler.scale(loss).backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)scaler.step(optimizer)scaler.update()total_loss += loss.item()scheduler.step()return total_loss / len(loader)

9、评估函数

# 评估函数
def evaluate(model, loader):model.eval()total_correct = total_nodes = 0with torch.no_grad():for batch in loader:batch = batch.to(Config.device)logits = model(batch)preds = logits.argmax(dim=1)total_correct += (preds == batch.y).sum().item()total_nodes += batch.y.size(0)return total_correct / total_nodes

10、主函数

if __name__ == "__main__":# 初始化torch.manual_seed(Config.seed)if torch.cuda.is_available():torch.cuda.manual_seed_all(Config.seed)torch.backends.cudnn.benchmark = True# 数据集train_set = RNADataset("./RNA_design_public/RNAdesignv1/train/coords","./RNA_design_public/RNAdesignv1/train/seqs",augment=True)# 划分数据集train_size = int(0.8 * len(train_set))val_size = (len(train_set) - train_size) // 2test_size = len(train_set) - train_size - val_sizetrain_set, val_set, test_set = torch.utils.data.random_split(train_set, [train_size, val_size, test_size])# 数据加载train_loader = torch_geometric.loader.DataLoader(train_set, batch_size=Config.batch_size, shuffle=True,pin_memory=True,num_workers=4)val_loader = torch_geometric.loader.DataLoader(val_set, batch_size=Config.batch_size)test_loader = torch_geometric.loader.DataLoader(test_set, batch_size=Config.batch_size)# 模型初始化model = RNAGNN().to(Config.device)optimizer = optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=0.01)scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.epochs)criterion = nn.CrossEntropyLoss()# 训练循环best_acc = 0for epoch in range(Config.epochs):train_loss = train(model, train_loader, optimizer, scheduler, criterion)val_acc = evaluate(model, val_loader)print(f"Epoch {epoch+1}/{Config.epochs} | Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")if val_acc > best_acc:best_acc = val_acctorch.save(model.state_dict(), "best_model.pth")# 最终测试model.load_state_dict(torch.load("best_model.pth"))test_acc = evaluate(model, test_loader)print(f"\nFinal Test Accuracy: {test_acc:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/77895.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vim 命令复习

命令模式下的命令及快捷键 # dd删除光所在行的内容 # ndd从光标所在行开始向下删除n行 # yy复制光标所在行的内容 # nyy复制光标所在行向下n行的内容 # p将复制的内容粘贴到光标所在行以下&#xff08;小写&#xff09; # P将复制的内容粘贴到光标所在行以上&#xff08;大写&…

哪些心电图表现无缘事业编体检呢?

根据《公务员录用体检通用标准》心血管系统条款及事业单位体检实施细则&#xff0c;心电图不合格主要涉及以下类型及处置方案&#xff1a; 一、心律失常类 早搏&#xff1a;包括房性早搏、室性早搏和交界性早搏。如果每分钟早搏次数较多&#xff08;如超过5次&#xff09;&…

Linux学习——UDP

编程的整体框架 bind&#xff1a;绑定服务器&#xff1a;TCP地址和端口号 receivefrom()&#xff1a;阻塞等待客户端数据 sendto():指定服务器的IP地址和端口号&#xff0c;要发送的数据 无连接尽力传输&#xff0c;UDP:是不可靠传输 实时的音视频传输&#x…

ReAct Agent 实战:基于DeepSeek从0到1实现大模型Agent的探索模式

写在前面:动态思考,边想边做 大型语言模型(LLM)的崛起开启了通用人工智能(AGI)的无限遐想。但要让 LLM 从一个被动的“文本生成器”转变为能够主动解决问题、与环境交互的智能体(Agent),我们需要赋予它思考、行动和学习的能力。ReAct (Reason + Act) 框架正是实现这一…

从物理到预测:数据驱动的深度学习的结构化探索及AI推理

在当今科学探索的时代&#xff0c;理解的前沿不再仅仅存在于我们书写的方程式中&#xff0c;也存在于我们收集的数据和构建的模型中。在物理学和机器学习的交汇处&#xff0c;一个快速发展的领域正在兴起&#xff0c;它不仅观察宇宙&#xff0c;更是在学习宇宙。 AI推理 我们…

结合地理数据处理

CSV 文件不仅可以存储表格数据&#xff0c;还可以与地理空间数据结合&#xff0c;实现更强大的地理处理功能。例如&#xff0c;你可以将 CSV 文件中的坐标数据转换为点要素类&#xff0c;然后进行空间分析。 示例&#xff1a;将 CSV 文件中的坐标数据转换为点要素类 假设我们有…

SpringBoot中6种自定义starter开发方法

在SpringBoot生态中,starter是一种特殊的依赖,它能够自动装配相关组件,简化项目配置。 自定义starter的核心价值在于: • 封装复杂的配置逻辑,实现开箱即用 • 统一技术组件的使用规范,避免"轮子"泛滥 • 提高开发效率,减少重复代码 方法一:基础配置类方式 …

滚珠导轨松动会导致哪些影响?

直线导轨用于高精度或快速直线往复运动场所&#xff0c;且能够担负一定的扭矩&#xff0c;在高负载的情况下实现高精度的直线运动。它主要由导轨和滑块组成&#xff0c;其中导轨作为固定元件&#xff0c;滑块则在其上进行往复直线运动。但是滚珠导轨松动会导致哪些影响&#xf…

从零开始搭建Django博客②--Django的服务器内容搭建

本文主要在Ubuntu环境上搭建&#xff0c;为便于研究理解&#xff0c;采用SSH连接在虚拟机里的ubuntu-24.04.2-desktop系统搭建&#xff0c;当涉及一些文件操作部分便于通过桌面化进行理解&#xff0c;通过Nginx代理绑定域名&#xff0c;对外发布。 此为从零开始搭建Django博客…

ZLMediaKit支持JT1078实时音视频

ZLMediaKit 对 JT1078 实时音视频协议的支持主要通过其扩展版本或与其他中间件结合实现。以下是基于搜索结果的综合分析&#xff1a; 一、ZLMediaKit 原生支持能力 开源版本的基础支持 ZLMediaKit 开源版本本身未直接集成 JT1078 协议解析模块&#xff0c;但可通过 RTP 推流功能…

Java队列(Queue)核心操作与最佳实践:深入解析与面试指南

文章目录 概述一、Java队列核心实现类对比1. LinkedList2. ArrayDeque3. PriorityQueue 二、核心操作API与时间复杂度三、经典使用场景与最佳实践场景1&#xff1a;BFS层序遍历&#xff08;树/图&#xff09;场景2&#xff1a;滑动窗口最大值&#xff08;单调队列&#xff09; …

MetaGPT智能体框架深度解析:记忆模块设计与应用实践

在AI智能体技术从单点突破迈向系统工程的关键阶段&#xff0c;MetaGPT凭借其创新的记忆架构重新定义了多智能体协作范式。本文深度解构其革命性的三级记忆系统&#xff0c;揭秘支撑10倍效能提升的知识蒸馏算法与动态上下文控制策略&#xff0c;通过企业级应用案例与性能基准测试…

集结号海螺捕鱼服务器调度与房间分配机制详解:六

本篇围绕服务器调度核心逻辑进行剖析&#xff0c;重点讲解用户连接过程、房间分配机制、服务端并发策略及常见性能瓶颈优化。适用于具备中高级 C 后端开发经验的读者&#xff0c;覆盖网络会话池、逻辑服调度器与房间生命周期管理等关键模块。 一、服务器结构概览 整体系统采用…

【电子通识】热敏打印机是怎么形成(打印)图像和文字的?

在我们身边&#xff0c;热敏打印方式常见用于装饰贴纸、便利店的小票。此外&#xff0c;物流及食品条码标签、身份证件、机票・火车票、X光片、食品日期印刷等&#xff0c;很多打印都用到了热敏打印头。 热敏打印头的蓄热层(涂釉层)上分布着一排加热元件&#xff08;发热线&…

SQL注入漏洞中会使用到的函数

目录 一、信息获取函数 1. 通用函数 2. 元数据查询&#xff08;INFORMATION_SCHEMA&#xff09; 二、字符串操作函数 1. 字符串连接 2. 字符串截取 3. 编码/解码 三、报错注入专用函数 1. MySQL 2. SQL Server 3. PostgreSQL 四、时间盲注函数 1. 通用延迟 2. 计…

车载信息安全架构 --- 汽车网络安全

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 周末洗了一个澡,换了一身衣服,出了门却不知道去哪儿,不知道去找谁,漫无目的走着,大概这就是成年人最深的孤独吧! 旧人不知我近况,新人不知我过…

Linux423 删除用户

查找 上面已查过&#xff1a;无法使用sudo 新开个终端试试 之前开了一个终端&#xff0c;按照deepseek排查 计划再开一个进程 开一个终端 后强制删除时显示&#xff1a;此事将被报告

《从卷积核到数字解码:CNN 手写数字识别实战解析》

文章目录 一、手写数字识别的本质与挑战二、使用步骤1.导入torch库以及与视觉相关的torchvision库2.下载datasets自带的手写数字的数据集到本地 三、完整代码展示 一、手写数字识别的本质与挑战 手写数字识别的核心是&#xff1a;从二维像素矩阵中提取具有判别性的特征&#x…

UniOcc:自动驾驶占用预测和预报的统一基准

25年3月来自 UC Riverside、U Wisconsin 和 TAMU 的论文"UniOcc: A Unified Benchmark for Occupancy Forecasting and Prediction in Autonomous Driving"。 UniOcc 是一个全面统一的占用预测基准&#xff08;即基于历史信息预测未来占用&#xff09;和基于摄像头图…

模型量化核心技术解析:从算法原理到工业级实践

一、模型量化为何成为大模型落地刚需&#xff1f; 算力困境&#xff1a;175B参数模型FP32推理需0.5TB内存&#xff0c;超出主流显卡容量 速度瓶颈&#xff1a;FP16推理延迟难以满足实时对话需求&#xff08;如客服场景<200ms&#xff09; 能效挑战&#xff1a;边缘设备运行…