SVM在高光谱分类中的优势
| 优势 | 说明 |
|---|---|
| 小样本学习 | 在高光谱标注样本有限的情况下仍能有效学习 |
| 高维处理 | 适合处理高光谱数据的高维特征 |
| 非线性分类 | 通过核函数处理复杂的非线性分类问题 |
| 泛化能力强 | 基于结构风险最小化原理,泛化性能好 |
完整MATLAB实现代码
1. 数据加载与预处理
function [data, labels, feature_names] = load_hyperspectral_data()
% LOAD_HYPERSPECTRAL_DATA  Generate a simulated hyperspectral data set.
%
%   [data, labels, feature_names] = load_hyperspectral_data()
%
%   Returns
%     data          - num_samples-by-num_bands matrix of simulated reflectance.
%     labels        - num_samples-by-1 vector of class ids in 1..num_classes.
%     feature_names - 1-by-num_bands cell array: 'Band_1', 'Band_2', ...
%
%   To work with real data (e.g. Indian Pines), replace the simulation with:
%       load('Indian_pines.mat');
%       load('Indian_pines_gt.mat');
%
%   Every row is assigned to a class. Class boundaries are spread with
%   linspace so that a sample count that is not a multiple of the class
%   count does not leave trailing rows with label 0 and all-zero spectra.

    fprintf('生成模拟高光谱数据...\n');

    num_samples = 2000;   % total number of samples
    num_bands   = 200;    % number of spectral bands
    num_classes = 6;      % number of land-cover classes

    data   = zeros(num_samples, num_bands);
    labels = zeros(num_samples, 1);

    % Each class gets its own amplitude for the shared sinusoidal signature.
    class_centers = linspace(0.3, 0.8, num_classes);
    band_profile  = sin((1:num_bands) / 50);

    % Row boundaries: class i owns rows edges(i)+1 .. edges(i+1).
    edges = round(linspace(0, num_samples, num_classes + 1));

    for i = 1:num_classes
        rows = (edges(i) + 1):edges(i + 1);

        % Class mean spectrum plus Gaussian noise (sigma = 0.1) per band.
        center_vals   = class_centers(i) * band_profile + 0.2;
        data(rows, :) = center_vals + 0.1 * randn(numel(rows), num_bands);
        labels(rows)  = i;
    end

    % Feature names (one per band).
    feature_names = arrayfun(@(x) sprintf('Band_%d', x), 1:num_bands, ...
        'UniformOutput', false);

    fprintf('数据维度: %d × %d\n', size(data, 1), size(data, 2));
    fprintf('类别数量: %d\n', num_classes);
    fprintf('类别分布: \n');
    tabulate(labels);
end
2. 特征提取与降维
function [features_selected, selected_indices, transform] = feature_selection_hyperspectral(data, labels, method)
% FEATURE_SELECTION_HYPERSPECTRAL  Feature extraction / band selection.
%
%   [features_selected, selected_indices, transform] = ...
%       feature_selection_hyperspectral(data, labels, method)
%
%   Inputs
%     data   - samples-by-bands matrix.
%     labels - class label per sample (only used by 'RF').
%     method - 'PCA' : principal components covering >= 95% of the variance
%              'SPA' : successive projections algorithm (up to 30 bands)
%              'RF'  : bagged-tree permutation importance (up to 50 bands)
%              'None': keep every band
%              Any other name (including 'CARS', which is not implemented)
%              keeps every band and emits a warning.
%
%   Outputs
%     features_selected - the reduced feature matrix for DATA.
%     selected_indices  - for 'PCA' the component numbers 1..k; for the
%                         other methods the selected column indices of DATA.
%     transform         - (optional) struct describing how to apply the
%                         same reduction to unseen data:
%                           .Method  - the method name
%                           .Mean    - PCA only: 1-by-bands training mean
%                           .Coeff   - PCA only: bands-by-k loadings
%                           .Indices - band-selection methods: column ids
%                         For PCA, new data must be projected as
%                         (new_data - transform.Mean) * transform.Coeff;
%                         indexing raw columns with selected_indices is not
%                         equivalent.

    transform = struct('Method', method, 'Mean', [], 'Coeff', [], 'Indices', []);

    switch method
        case 'PCA'
            % Principal component analysis; mu is the per-band mean that
            % pca subtracted before computing the scores.
            [coeff, score, ~, ~, explained, mu] = pca(data);

            % Keep the leading components whose cumulative share is >= 95%.
            cum_explained = cumsum(explained);
            num_components = find(cum_explained >= 95, 1);
            if isempty(num_components)
                num_components = numel(explained);
            end

            features_selected = score(:, 1:num_components);
            selected_indices = 1:num_components;
            transform.Mean = mu;
            transform.Coeff = coeff(:, 1:num_components);

            fprintf('PCA选择 %d 个主成分 (累计方差: %.2f%%)\n', ...
                num_components, cum_explained(num_components));

        case 'SPA'
            % Successive projections algorithm.
            num_selected = min(30, size(data, 2));
            selected_indices = successive_projections_algorithm(data, num_selected);
            features_selected = data(:, selected_indices);
            transform.Indices = selected_indices;

        case 'RF'
            % Rank bands by out-of-bag permutation importance of a bagged
            % tree ensemble and keep the top 50.
            tree = fitensemble(data, labels, 'Bag', 100, 'Tree', 'Type', 'Classification');
            imp = oobPermutedPredictorImportance(tree);
            [~, idx] = sort(imp, 'descend');
            selected_indices = idx(1:min(50, length(idx)));
            features_selected = data(:, selected_indices);
            transform.Indices = selected_indices;

        otherwise
            % Default: use all features. Only 'None' is an expected way to
            % get here; anything else is probably a typo or an
            % unimplemented method, so say so instead of failing silently.
            if ~strcmp(method, 'None')
                warning('feature_selection_hyperspectral:unknownMethod', ...
                    '未实现的特征选择方法 "%s",将使用全部特征', method);
            end
            features_selected = data;
            selected_indices = 1:size(data, 2);
            transform.Indices = selected_indices;
    end
end

function selected_indices = successive_projections_algorithm(data, k)
% SUCCESSIVE_PROJECTIONS_ALGORITHM  Pick k minimally collinear columns.
%
%   selected_indices = successive_projections_algorithm(data, k)
%
%   The first column is the one with the largest variance. Each later
%   column is the one with the largest norm after projecting every column
%   onto the orthogonal complement of the columns chosen so far.
%
%   The projection is done by deflating a residual matrix one direction at
%   a time. Subtracting the individual projections onto each previously
%   chosen (non-orthogonal) column and summing them is not a projection
%   onto their span, so it is not used here.
%
%   k is clamped to the number of columns; the result is 1-by-min(k, p).

    p = size(data, 2);
    k = min(k, p);
    selected_indices = zeros(1, k);
    if k == 0
        return;
    end

    % Initial band: largest variance across samples.
    [~, selected_indices(1)] = max(var(data, 0, 1));

    residual = data;
    for i = 2:k
        % Remove the direction of the most recently selected column from
        % every column. Earlier directions were already removed.
        direction = residual(:, selected_indices(i - 1));
        direction_norm = norm(direction);
        if direction_norm > 0
            direction = direction / direction_norm;
            residual = residual - direction * (direction' * residual);
        end

        residual_norms = sqrt(sum(residual .^ 2, 1));
        % Never reselect a column, even if all residuals are zero.
        residual_norms(selected_indices(1:i - 1)) = -Inf;
        [~, selected_indices(i)] = max(residual_norms);
    end
end
3. SVM分类器实现
function svm_model = train_svm_classifier(features, labels, kernel_type)
% TRAIN_SVM_CLASSIFIER  Train a one-vs-one multiclass SVM.
%
%   svm_model = train_svm_classifier(features, labels, kernel_type)
%
%   Inputs
%     features    - samples-by-features training matrix (raw, unscaled).
%     labels      - class label per sample.
%     kernel_type - 'linear', 'rbf' or 'polynomial'.
%
%   Output
%     svm_model   - model returned by fitcecoc.
%
%   Standardisation is left to the learners ('Standardize', true), so the
%   training statistics travel with the model. The features are not
%   z-scored here: doing so would train on scaled data while callers pass
%   raw features to predict.
%
%   Raises an error for an unsupported kernel_type.

    switch kernel_type
        case 'linear'
            template = templateSVM('KernelFunction', 'linear', ...
                'BoxConstraint', 1, ...
                'Standardize', true);
        case 'rbf'
            template = templateSVM('KernelFunction', 'rbf', ...
                'BoxConstraint', 1, ...
                'KernelScale', 'auto', ...
                'Standardize', true);
        case 'polynomial'
            template = templateSVM('KernelFunction', 'polynomial', ...
                'BoxConstraint', 1, ...
                'PolynomialOrder', 3, ...
                'Standardize', true);
        otherwise
            error('train_svm_classifier:unknownKernel', ...
                '不支持的核函数类型: %s', kernel_type);
    end

    % Train the multiclass (one-vs-one ECOC) SVM.
    svm_model = fitcecoc(features, labels, ...
        'Learners', template, ...
        'Coding', 'onevsone', ...
        'Verbose', 1);

    fprintf('SVM分类器训练完成 (核函数: %s)\n', kernel_type);
end

function [accuracy, confusion_mat, class_report] = evaluate_svm_model(model, features_test, labels_test)
% EVALUATE_SVM_MODEL  Accuracy, confusion matrix and per-class metrics.
%
%   [accuracy, confusion_mat, class_report] = ...
%       evaluate_svm_model(model, features_test, labels_test)
%
%   Inputs
%     model         - trained classifier accepted by predict.
%     features_test - samples-by-features test matrix.
%     labels_test   - true label per test sample (row or column vector).
%
%   Outputs
%     accuracy      - fraction of correctly classified samples.
%     confusion_mat - confusionmat(true, predicted).
%     class_report  - struct array, one element per class present in
%                     labels_test, with fields Class, Precision, Recall,
%                     F1_Score and Support.

    % Force a column so the comparison with the predictions is
    % element-wise rather than broadcast into a matrix.
    labels_test = labels_test(:);

    labels_pred = predict(model, features_test);
    labels_pred = labels_pred(:);

    accuracy = sum(labels_pred == labels_test) / length(labels_test);
    confusion_mat = confusionmat(labels_test, labels_pred);

    % Per-class metrics; eps keeps the ratios finite when a class is never
    % predicted or never present.
    unique_labels = unique(labels_test);
    class_report = struct();
    for i = 1:length(unique_labels)
        is_true = (labels_test == unique_labels(i));
        is_pred = (labels_pred == unique_labels(i));

        true_positive  = sum(is_true & is_pred);
        false_positive = sum(~is_true & is_pred);
        false_negative = sum(is_true & ~is_pred);

        precision = true_positive / (true_positive + false_positive + eps);
        recall    = true_positive / (true_positive + false_negative + eps);
        f1_score  = 2 * (precision * recall) / (precision + recall + eps);

        class_report(i).Class     = unique_labels(i);
        class_report(i).Precision = precision;
        class_report(i).Recall    = recall;
        class_report(i).F1_Score  = f1_score;
        class_report(i).Support   = sum(is_true);
    end

    fprintf('测试集准确率: %.4f\n', accuracy);
end
4. 参数优化与交叉验证
function [best_model, best_params] = optimize_svm_parameters(features, labels)
% OPTIMIZE_SVM_PARAMETERS  Bayesian optimisation of SVM hyper-parameters.
%
%   [best_model, best_params] = optimize_svm_parameters(features, labels)
%
%   Inputs
%     features - samples-by-features training matrix.
%     labels   - class label per sample.
%
%   Outputs
%     best_model  - fitcecoc model trained on all data with best_params.
%     best_params - results.XAtMinObjective from bayesopt. It contains
%                   BoxConstraint, plus KernelScale when the RBF kernel
%                   was used.
%
%   Kernel choice: RBF when there are more than 10 features, otherwise
%   linear (in which case only BoxConstraint is searched).

    % Search space (log-scaled).
    box_constraint = optimizableVariable('BoxConstraint', [0.1, 100], 'Transform', 'log');
    kernel_scale   = optimizableVariable('KernelScale',   [0.1, 100], 'Transform', 'log');

    if size(features, 2) > 10
        kernel_function = 'rbf';
        variables = [box_constraint, kernel_scale];
    else
        kernel_function = 'linear';
        variables = box_constraint;
    end

    % Objective: 5-fold cross-validated classification error.
    fun = @(params) svm_crossval_loss(features, labels, kernel_function, params);

    results = bayesopt(fun, variables, ...
        'MaxTime', 300, ...
        'IsObjectiveDeterministic', false, ...
        'NumSeedPoints', 10, ...
        'Verbose', 1);

    best_params = results.XAtMinObjective;

    % Final model on the full training set with the best parameters.
    template = build_svm_template(kernel_function, best_params);
    best_model = fitcecoc(features, labels, 'Learners', template);

    fprintf('参数优化完成\n');
end

function loss = svm_crossval_loss(features, labels, kernel_function, params)
% SVM_CROSSVAL_LOSS  5-fold cross-validated error for one parameter set.
%
%   loss = svm_crossval_loss(features, labels, kernel_function, params)
%
%   params must expose .BoxConstraint and, for the 'rbf' kernel,
%   .KernelScale (a one-row table from bayesopt or a struct both work).
%
%   The partitioned model is requested directly from fitcecoc via
%   'KFold', so no full-data model is trained just to be cross-validated.

    template = build_svm_template(kernel_function, params);
    cv_model = fitcecoc(features, labels, 'Learners', template, 'KFold', 5);
    loss = kfoldLoss(cv_model);
end

function template = build_svm_template(kernel_function, params)
% BUILD_SVM_TEMPLATE  Standardising SVM learner template for PARAMS.
%   KernelScale is only passed for the 'rbf' kernel.

    if strcmp(kernel_function, 'rbf')
        template = templateSVM('KernelFunction', kernel_function, ...
            'BoxConstraint', params.BoxConstraint, ...
            'KernelScale', params.KernelScale, ...
            'Standardize', true);
    else
        template = templateSVM('KernelFunction', kernel_function, ...
            'BoxConstraint', params.BoxConstraint, ...
            'Standardize', true);
    end
end
5. 完整的分类系统
function hyperspectral_svm_classification_system()
% HYPERSPECTRAL_SVM_CLASSIFICATION_SYSTEM  End-to-end SVM classification demo.
%
%   Runs the whole pipeline: load data, split 70/30, build three feature
%   sets (PCA, RF, None), train linear and RBF SVMs on each, run Bayesian
%   hyper-parameter optimisation on the PCA features, then plot and print
%   a summary. Takes no inputs and returns nothing; all output goes to
%   figures and the command window.

    close all; clc;
    fprintf('=== 高光谱遥感图像SVM分类系统 ===\n\n');

    %% 1. Data loading and exploration
    fprintf('步骤1: 数据加载...\n');
    [data, labels] = load_hyperspectral_data();

    figure('Position', [100, 100, 1200, 800]);
    subplot(2, 3, 1);
    plot_mean_spectra(data, labels);
    title('各类别平均光谱曲线');

    %% 2. Pre-processing: stratified 70% / 30% hold-out split
    fprintf('步骤2: 数据预处理...\n');
    rng(42);  % fixed seed for reproducibility
    cv = cvpartition(labels, 'HoldOut', 0.3);

    data_train   = data(cv.training, :);
    labels_train = labels(cv.training);
    data_test    = data(cv.test, :);
    labels_test  = labels(cv.test);

    fprintf('训练集: %d 样本\n', size(data_train, 1));
    fprintf('测试集: %d 样本\n', size(data_test, 1));

    %% 3. Feature selection
    fprintf('步骤3: 特征选择...\n');
    feature_methods = {'PCA', 'RF', 'None'};
    feature_results = struct();

    for i = 1:length(feature_methods)
        method = feature_methods{i};
        fprintf(' 使用 %s 方法进行特征选择...\n', method);

        [features_train, selected_idx] = feature_selection_hyperspectral( ...
            data_train, labels_train, method);

        % The test set must go through the transform fitted on the
        % training set. For PCA that is a projection, not column indexing.
        features_test = reduce_test_features(method, data_train, data_test, selected_idx);

        feature_results(i).Method          = method;
        feature_results(i).FeaturesTrain   = features_train;
        feature_results(i).FeaturesTest    = features_test;
        feature_results(i).SelectedIndices = selected_idx;
    end

    %% 4. SVM training and evaluation
    fprintf('步骤4: SVM模型训练...\n');
    kernel_types = {'linear', 'rbf'};
    results = struct();
    result_count = 1;

    for feat_idx = 1:length(feature_methods)
        for kernel_idx = 1:length(kernel_types)
            fprintf(' 训练: %s特征 + %s核SVM\n', ...
                feature_methods{feat_idx}, kernel_types{kernel_idx});

            svm_model = train_svm_classifier( ...
                feature_results(feat_idx).FeaturesTrain, ...
                labels_train, kernel_types{kernel_idx});

            [accuracy, confusion_mat, class_report] = evaluate_svm_model( ...
                svm_model, ...
                feature_results(feat_idx).FeaturesTest, ...
                labels_test);

            results(result_count).FeatureMethod   = feature_methods{feat_idx};
            results(result_count).KernelType      = kernel_types{kernel_idx};
            results(result_count).Accuracy        = accuracy;
            results(result_count).ConfusionMatrix = confusion_mat;
            results(result_count).ClassReport     = class_report;
            results(result_count).Model           = svm_model;

            result_count = result_count + 1;
        end
    end

    %% 5. Hyper-parameter optimisation (on the PCA features)
    fprintf('步骤5: 参数优化...\n');
    best_feat_idx = 1;
    [optimized_model, best_params] = optimize_svm_parameters( ...
        feature_results(best_feat_idx).FeaturesTrain, labels_train);

    [opt_accuracy, opt_confusion, opt_report] = evaluate_svm_model( ...
        optimized_model, ...
        feature_results(best_feat_idx).FeaturesTest, ...
        labels_test);

    % The optimiser only tunes KernelScale when it picked the RBF kernel,
    % so the presence of that variable tells us which kernel was used.
    if istable(best_params) && ...
            ismember('KernelScale', best_params.Properties.VariableNames)
        optimized_kernel = 'rbf';
    else
        optimized_kernel = 'linear';
    end

    results(result_count).FeatureMethod   = 'PCA_Optimized';
    results(result_count).KernelType      = optimized_kernel;
    results(result_count).Accuracy        = opt_accuracy;
    results(result_count).ConfusionMatrix = opt_confusion;
    results(result_count).ClassReport     = opt_report;
    results(result_count).Model           = optimized_model;
    results(result_count).BestParams      = best_params;

    %% 6. Visualisation
    fprintf('步骤6: 结果可视化...\n');
    plot_classification_results(results, data_test, labels_test, feature_results);

    %% 7. Summary
    print_performance_summary(results);
end

function features_test = reduce_test_features(method, data_train, data_test, selected_idx)
% REDUCE_TEST_FEATURES  Apply the training-set feature reduction to test data.
%   For 'PCA', selected_idx holds component numbers, so the test rows are
%   centred with the training mean and projected onto the training
%   loadings. The two-output feature-selection call does not return the
%   loadings, so they are recomputed from the same training matrix.
%   For every other method selected_idx holds column indices.

    if strcmp(method, 'PCA')
        [coeff, ~, ~, ~, ~, mu] = pca(data_train);
        features_test = (data_test - mu) * coeff(:, selected_idx);
    else
        features_test = data_test(:, selected_idx);
    end
end
6. 可视化函数
function plot_mean_spectra(data, labels)
% PLOT_MEAN_SPECTRA  Mean spectrum per class with a +/- 1 std band.
%
%   plot_mean_spectra(data, labels)
%
%   data   - samples-by-bands matrix.
%   labels - class label per sample.
%
%   Draws into the current axes. Only the mean curves appear in the
%   legend; the shaded std bands are hidden from it.

    unique_labels = unique(labels);
    colors = lines(length(unique_labels));

    hold on;
    for i = 1:length(unique_labels)
        class_data    = data(labels == unique_labels(i), :);
        mean_spectrum = mean(class_data, 1);
        std_spectrum  = std(class_data, 0, 1);
        x = 1:length(mean_spectrum);

        plot(x, mean_spectrum, 'Color', colors(i, :), 'LineWidth', 2, ...
            'DisplayName', sprintf('Class %d', unique_labels(i)));

        % Shaded standard-deviation band, excluded from the legend.
        patch([x, fliplr(x)], ...
            [mean_spectrum + std_spectrum, fliplr(mean_spectrum - std_spectrum)], ...
            colors(i, :), 'FaceAlpha', 0.2, 'EdgeColor', 'none', ...
            'HandleVisibility', 'off');
    end
    hold off;

    xlabel('波段');
    ylabel('反射率');
    legend('show');
    grid on;
end

function plot_classification_results(results, data_test, labels_test, feature_results)
% PLOT_CLASSIFICATION_RESULTS  Six-panel overview of the experiment results.
%
%   plot_classification_results(results, data_test, labels_test, feature_results)
%
%   results         - struct array with fields FeatureMethod, KernelType,
%                     Accuracy, ConfusionMatrix and ClassReport.
%   data_test       - accepted for interface compatibility; not used.
%   labels_test     - true test labels, used to colour the PCA scatter.
%   feature_results - struct array with fields Method, FeaturesTest and
%                     SelectedIndices. Entries are looked up by Method
%                     name, so their order does not matter.

    figure('Position', [100, 100, 1400, 1000]);

    % 1. Accuracy comparison
    subplot(2, 3, 1);
    accuracies = [results.Accuracy];
    methods = cellfun(@(x, y) sprintf('%s\n%s', x, y), ...
        {results.FeatureMethod}, {results.KernelType}, ...
        'UniformOutput', false);
    bar(accuracies);
    set(gca, 'XTick', 1:length(accuracies), ...
        'XTickLabel', methods, 'XTickLabelRotation', 45);
    ylabel('准确率');
    title('不同方法准确率比较');
    grid on;
    for i = 1:length(accuracies)
        text(i, accuracies(i) + 0.01, sprintf('%.3f', accuracies(i)), ...
            'HorizontalAlignment', 'center');
    end

    % 2. Confusion matrix of the best model
    subplot(2, 3, 2);
    [~, best_idx] = max(accuracies);
    best_confusion = results(best_idx).ConfusionMatrix;
    imagesc(best_confusion);
    colorbar;
    title(sprintf('最佳模型混淆矩阵\n(%s + %s)', ...
        results(best_idx).FeatureMethod, results(best_idx).KernelType));
    xlabel('预测类别');
    ylabel('真实类别');

    % 3. Per-class F1 score of the best model
    subplot(2, 3, 3);
    best_report = results(best_idx).ClassReport;
    f1_scores = [best_report.F1_Score];
    class_labels = [best_report.Class];
    bar(f1_scores);
    set(gca, 'XTick', 1:length(f1_scores), ...
        'XTickLabel', arrayfun(@num2str, class_labels, 'UniformOutput', false));
    ylabel('F1分数');
    title('各类别F1分数');
    grid on;

    % 4. Bands selected by RF (only when RF was part of the experiment)
    subplot(2, 3, 4);
    rf_result_idx  = find(strcmp({results.FeatureMethod}, 'RF'), 1);
    rf_feature_idx = find(strcmp({feature_results.Method}, 'RF'), 1);
    if ~isempty(rf_result_idx) && ~isempty(rf_feature_idx)
        num_selected = length(feature_results(rf_feature_idx).SelectedIndices);
        plot(1:num_selected, ones(1, num_selected), 'o-');
        title('选择的特征波段');
        xlabel('特征索引');
        ylabel('选择状态');
        grid on;
    end

    % 5. Scatter of the first two PCA components of the test set
    subplot(2, 3, 5);
    pca_feature_idx = find(strcmp({feature_results.Method}, 'PCA'), 1);
    if ~isempty(pca_feature_idx)
        pca_features = feature_results(pca_feature_idx).FeaturesTest;
        if size(pca_features, 2) >= 2
            scatter(pca_features(:, 1), pca_features(:, 2), 30, labels_test, 'filled');
            xlabel('第一主成分');
            ylabel('第二主成分');
            title('PCA投影可视化');
            colorbar;
        end
    end

    % 6. Accuracy by experiment number
    subplot(2, 3, 6);
    plot(accuracies, 'o-', 'LineWidth', 2, 'MarkerSize', 8);
    xlabel('实验编号');
    ylabel('准确率');
    title('模型性能趋势');
    grid on;

    sgtitle('高光谱图像SVM分类结果分析', 'FontSize', 14, 'FontWeight', 'bold');
end

function print_performance_summary(results)
% PRINT_PERFORMANCE_SUMMARY  Print an accuracy table and the best model.
%
%   print_performance_summary(results)
%
%   results - struct array with fields FeatureMethod, KernelType and
%             Accuracy, optionally BestParams.
%
%   Best parameters are printed only when the winning entry actually has
%   them. A struct array shares its field names across all elements, so
%   isfield alone cannot tell whether this particular entry was optimised.

    fprintf('\n=== 性能总结 ===\n');
    fprintf('%-20s %-10s %-8s\n', '方法', '核函数', '准确率');
    fprintf('----------------------------------------\n');

    for i = 1:length(results)
        fprintf('%-20s %-10s %.4f\n', ...
            results(i).FeatureMethod, ...
            results(i).KernelType, ...
            results(i).Accuracy);
    end

    [best_acc, best_idx] = max([results.Accuracy]);
    fprintf('\n最佳模型: %s + %s核SVM\n', ...
        results(best_idx).FeatureMethod, results(best_idx).KernelType);
    fprintf('最佳准确率: %.4f\n', best_acc);

    if isfield(results, 'BestParams') && ~isempty(results(best_idx).BestParams)
        fprintf('最佳参数:\n');
        disp(results(best_idx).BestParams);
    end
end
7. 预测新样本
function [predicted_labels, scores] = predict_new_samples(model, new_data, feature_method, transform)
% PREDICT_NEW_SAMPLES  Classify unseen samples with a trained model.
%
%   [predicted_labels, scores] = predict_new_samples(model, new_data)
%   [...] = predict_new_samples(model, new_data, feature_method)
%   [...] = predict_new_samples(model, new_data, feature_method, transform)
%
%   Inputs
%     model          - trained classifier accepted by predict.
%     new_data       - samples-by-bands matrix in the ORIGINAL band space.
%     feature_method - (optional) name of the feature method used at
%                      training time; 'None' or omitted means no reduction.
%     transform      - (optional) the reduction fitted on the training
%                      set, given as either
%                        * a numeric vector of selected column indices, or
%                        * a struct with fields
%                            .Coeff / .Mean : PCA loadings (bands-by-k)
%                                             and training mean (1-by-bands)
%                            .Indices       : selected column indices
%                          (.Coeff takes precedence when non-empty).
%
%   Outputs
%     predicted_labels, scores - as returned by predict(model, ...).
%
%   Without TRANSFORM the function falls back to re-running feature
%   selection on new_data and warns. That fallback cannot reproduce the
%   training-time features: a PCA fitted on new data has a different basis
%   and possibly a different number of components, and label-based methods
%   have no labels to work with.
%
%   The features are NOT z-scored here. The SVM learners in this file are
%   created with 'Standardize', true, so scaling is left to the model.
%   Scaling a batch by its own statistics would turn a single sample into
%   all zeros.

    if nargin < 3 || isempty(feature_method)
        feature_method = 'None';
    end
    has_transform = (nargin >= 4) && ~isempty(transform);

    if has_transform
        new_features = apply_saved_transform(new_data, transform);
    elseif ~strcmp(feature_method, 'None')
        warning('predict_new_samples:refitOnNewData', ...
            '未提供训练阶段的特征变换,将在新数据上重新拟合 "%s",结果可能与训练特征不一致', ...
            feature_method);
        new_features = feature_selection_hyperspectral(new_data, [], feature_method);
    else
        new_features = new_data;
    end

    [predicted_labels, scores] = predict(model, new_features);

    fprintf('完成 %d 个新样本的预测\n', size(new_data, 1));
end

function features = apply_saved_transform(data, transform)
% APPLY_SAVED_TRANSFORM  Apply a training-time feature reduction to DATA.
%   Accepts an index vector, or a struct with Coeff/Mean (projection) or
%   Indices (column selection).

    if isnumeric(transform)
        features = data(:, transform);
        return;
    end

    if isfield(transform, 'Coeff') && ~isempty(transform.Coeff)
        if isfield(transform, 'Mean') && ~isempty(transform.Mean)
            data = data - transform.Mean;
        end
        features = data * transform.Coeff;
    elseif isfield(transform, 'Indices') && ~isempty(transform.Indices)
        features = data(:, transform.Indices);
    else
        features = data;
    end
end
使用说明
- 运行完整系统:
hyperspectral_svm_classification_system();
- 单独训练模型:
[data, labels] = load_hyperspectral_data();
features = feature_selection_hyperspectral(data, labels, 'PCA');
model = train_svm_classifier(features, labels, 'rbf');
- 预测新数据:
new_labels = predict_new_samples(model, new_data, 'PCA');
参考代码 SVM分类用于高光谱遥感图像分类、预测 www.youwenfan.com/contentcnl/79960.html
关键改进建议
- 实际数据适配:
  - 替换 load_hyperspectral_data 函数以加载真实的高光谱数据
  - 调整特征选择参数以适应具体数据集
- 性能优化:
  - 对于大数据集,考虑使用随机子空间或特征bagging
  - 使用GPU加速SVM训练(如果可用)
- 高级技术:
  - 结合空间-光谱特征(如Gabor滤波、形态学剖面)
  - 使用深度特征提取器预处理数据