一、项目背景与设计思路
1. 为什么“端到端 CNN”在医疗中经常失败?
很多教程喜欢这样做:
CT 图像 → CNN → 预测是否患病
但在真实医疗场景中,问题很快会暴露:
数据量不够(几百 ~ 几千)
批次差异大(不同医院 / 设备)
医生需要解释模型结果
模型上线后性能漂移严重
👉 这不是 CNN 不强,而是医疗场景不适合“一把梭”。
2. 更成熟的工程方案:CNN + XGBoost
医学影像 → CNN → 高阶影像特征 ↓ XGBoost / RF / LR ↓ 疾病风险预测
这个结构的优势是:
CNN 专注于特征表达
XGBoost 专注于稳定决策
小样本也能工作
方便做可解释性
二、项目整体结构设计
medical_prediction/ ├── data/ │ ├── images/ │ ├── clinical.csv │ └── labels.csv ├── cnn/ │ ├── dataset.py │ ├── model.py │ └── train_cnn.py ├── feature/ │ └── extract_features.py ├── ml/ │ ├── train_xgb.py │ └── evaluate.py └── main_pipeline.py
这是一个“真实可维护”的结构,不是 Notebook 玩具
三、Step 1:医学影像数据准备与 Dataset 构建
1️⃣ 自定义 Dataset(PyTorch)
# cnn/dataset.py import torch from torch.utils.data import Dataset import numpy as np class MedicalImageDataset(Dataset): def __init__(self, images, labels): self.images = images self.labels = labels def __len__(self): return len(self.labels) def __getitem__(self, idx): x = self.images[idx] y = self.labels[idx] return torch.tensor(x, dtype=torch.float32), torch.tensor(y)2️⃣ 医疗影像预处理经验(非常关键)
真实项目中通常需要:
归一化(HU 值 / 强度)
Resize
中心裁剪
简单增强(翻转、噪声)
不要一上来就疯狂数据增强,医疗里很容易引入伪特征。
四、Step 2:CNN 模型设计
1️⃣ CNN 设计原则
不追求太深
不追求 ImageNet 那套
目标是“稳定特征”而不是极致精度
2️⃣ CNN 模型代码
# cnn/model.py import torch import torch.nn as nn class MedicalCNN(nn.Module): def __init__(self): super().__init__() self.features = nn.Sequential( nn.Conv2d(1, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2) ) self.classifier = nn.Linear(32 * 7 * 7, 2) def forward(self, x, return_feature=False): x = self.features(x) x = x.view(x.size(0), -1) if return_feature: return x return self.classifier(x)五、Step 3:CNN 训练
1️⃣ 训练代码
# cnn/train_cnn.py import torch import torch.nn as nn import torch.optim as optim from cnn.model import MedicalCNN model = MedicalCNN() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=1e-3) for epoch in range(15): model.train() images = torch.randn(64, 1, 28, 28) labels = torch.randint(0, 2, (64,)) outputs = model(images) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() print(f"Epoch {epoch+1}, Loss={loss.item():.4f}")👉 工程经验
CNN 不必训到极致
过拟合反而会让特征“失真”
我通常在 loss 稳定后就停
六、Step 4:CNN 特征提取
# feature/extract_features.py import torch import numpy as np from cnn.model import MedicalCNN model = MedicalCNN() model.eval() def extract_features(images): with torch.no_grad(): feats = model(images, return_feature=True) return feats.cpu().numpy()images = torch.randn(300, 1, 28, 28) cnn_features = extract_features(images) print(cnn_features.shape)七、Step 5:融合临床特征
clinical_features = np.random.randn(300, 6) X = np.concatenate( [cnn_features, clinical_features], axis=1 ) y = np.random.randint(0, 2, 300)👉影像 + 临床 = 医疗 AI 的基本盘
八、Step 6:XGBoost 训练
from xgboost import XGBClassifier from sklearn.model_selection import train_test_split from sklearn.metrics import roc_auc_score X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42 ) model = XGBClassifier( n_estimators=400, max_depth=5, learning_rate=0.03, subsample=0.8, colsample_bytree=0.8, eval_metric="logloss" ) model.fit(X_train, y_train) y_prob = model.predict_proba(X_test)[:, 1] print("AUC:", roc_auc_score(y_test, y_prob))九、Step 7:可解释性
1️⃣ SHAP 示例
import shap explainer = shap.TreeExplainer(model) shap_values = explainer.shap_values(X_test) shap.summary_plot(shap_values, X_test)👉 你可以清楚看到:
哪些影像特征重要
哪些临床指标起决定作用
十、真实医疗项目的 5 条血泪经验
1️⃣ 不要迷信大模型
2️⃣ 稳定性 > 精度
3️⃣ 特征质量 > 网络深度
4️⃣ 医生信任比 AUC 更重要
5️⃣CNN + XGBoost 是成熟方案,不是退而求其次
十一、总结
CNN 解决“看不懂影像”的问题
XGBoost 解决“怎么做决定”的问题
这不是妥协,而是工程智慧。