[nnUNetv2 Advanced] V. Modifying the nnUNetv2 Network - A Warm-Up Exercise - Adding the SpatialAttention Module

nnUNet is a self-configuring deep learning framework designed specifically for medical image segmentation. Its main characteristics are summarized below:

Self-configuring framework: nnUNet automatically adapts the model architecture, training parameters, and so on to the specific medical image segmentation task, avoiding tedious manual hyperparameter tuning.
Automated pipeline: nnUNet automates the whole workflow from data preprocessing to model training, validation, and testing, which greatly reduces the effort of applying deep learning to medical image segmentation.
Adaptive network configuration: based on the characteristics of the input dataset, nnUNet automatically selects and configures hyperparameters such as network depth and width, balancing model complexity against performance.
Patch-based training and inference: nnUNet trains on patches and traverses the whole image with a sliding window; at inference time the same sliding-window scheme is used to assemble the full-image segmentation. This is very effective for large images or limited GPU memory (a minimal sliding-window sketch follows this list).
Ensembling and cross-validation: nnUNet uses cross-validation to make the most of limited data and combines it with ensembling to improve the stability and accuracy of its predictions.
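To make the sliding-window idea concrete, here is a minimal 2D sketch. It is purely illustrative and not nnUNet's actual implementation; the patch size, step size, and the predict_patch callable are assumptions made for the example.

import torch

def sliding_window_predict(image, predict_patch, patch_size=(64, 64), step=(32, 32), num_classes=2):
    # image: (C, H, W) tensor; predict_patch maps a (1, C, h, w) patch to (1, num_classes, h, w) logits
    # assumes H >= patch_size[0] and W >= patch_size[1]
    c, h, w = image.shape
    logits = torch.zeros((num_classes, h, w))
    counts = torch.zeros((1, h, w))
    # slide over the image, always including a final patch aligned with the border
    ys = sorted(set(list(range(0, h - patch_size[0], step[0])) + [h - patch_size[0]]))
    xs = sorted(set(list(range(0, w - patch_size[1], step[1])) + [w - patch_size[1]]))
    for y in ys:
        for x in xs:
            patch = image[:, y:y + patch_size[0], x:x + patch_size[1]].unsqueeze(0)
            logits[:, y:y + patch_size[0], x:x + patch_size[1]] += predict_patch(patch)[0]
            counts[:, y:y + patch_size[0], x:x + patch_size[1]] += 1
    return logits / counts  # average overlapping patch predictions

# toy usage with a random "model"
img = torch.randn(1, 96, 96)
pred = sliding_window_predict(img, lambda p: torch.randn(1, 2, *p.shape[-2:]))
print(pred.shape)  # torch.Size([2, 96, 96])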
In addition, nnUNet ships with extensive documentation and examples to help users understand and use the framework. To use nnUNet, install Python and the corresponding deep learning framework, then follow the steps in the official documentation.

In short, nnUNet is a powerful, easy-to-use deep learning framework that is particularly well suited to medical image segmentation. Its self-configuring behavior, automated pipeline, and solid training strategies let users build and train models more efficiently while achieving strong performance.

Earlier posts covered installing nnUNet, using it, and defining custom networks. This post explains how to add SpatialAttention to nnUNet. Before reading on, make sure you have worked through the following:

[nnUNetv2 Practice] I. Installing nnUNetv2

[nnUNetv2 Practice] II. nnUNetv2 Quick Start - Training, Validation, Inference, and Ensembling in One Tutorial

[nnUNetv2 Advanced] III. nnUNetv2 Custom Networks - A Must for Publishing Papers (CSDN blog)

The ChannelAttention modification is described here:

[nnUNetv2 Advanced] IV. Modifying the nnUNetv2 Network - A Warm-Up Exercise - Adding the ChannelAttention Mechanism (CSDN blog)

This post shows how to add SpatialAttention to nnUNet. SpatialAttention is a very simple attention mechanism and therefore a good first exercise in modifying the network; more advanced modifications will be covered in later posts.

I. SpatialAttention

SpatialAttention is a spatial attention mechanism (not to be confused with channel attention). Its code is very short; the underlying principle is not covered in detail here and can easily be looked up.


class SpatialAttention(nn.Module):
    def __init__(self, conv_op, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        elif conv_op == torch.nn.modules.conv.Conv3d:
            self.cv1 = nn.Conv3d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        # concatenate the channel-wise mean and max, then conv + sigmoid -> spatial attention map, applied to x
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
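A quick sanity check of the module on dummy 2D and 3D tensors (the shapes below are arbitrary examples; it assumes the class above is defined and torch/torch.nn are imported) could look like this:

import torch
import torch.nn as nn

sa2d = SpatialAttention(conv_op=nn.Conv2d, kernel_size=7)
x2d = torch.randn(2, 32, 64, 64)            # (batch, channels, H, W)
print(sa2d(x2d).shape)                       # torch.Size([2, 32, 64, 64]) - the shape is unchanged

sa3d = SpatialAttention(conv_op=nn.Conv3d, kernel_size=7)
x3d = torch.randn(2, 32, 16, 64, 64)         # (batch, channels, D, H, W)
print(sa3d(x3d).shape)                       # torch.Size([2, 32, 16, 64, 64])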

II. Adding SpatialAttention to nnUNet

As mentioned in the earlier tutorials, nnUNet networks live in dynamic-network-architectures; to train with your own network you modify the code there and then point the dataset's plans file at the new class.
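The reason editing the plans file is enough to switch networks is that nnUNet resolves the network class at runtime from the dotted class path stored in the plans. The snippet below only illustrates that general mechanism; locate_class is a made-up helper for this post, not nnUNet's actual loader.

import importlib

def locate_class(dotted_path: str):
    # e.g. "dynamic_network_architectures.architectures.saunet.SAPlainConvUNet"
    module_path, class_name = dotted_path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# network_cls = locate_class("dynamic_network_architectures.architectures.saunet.SAPlainConvUNet")
# network = network_cls(**arch_kwargs_from_plans)  # the kwargs come from the plans file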

1. Modifying the network architecture

Create a new file saunet.py in the architectures directory of dynamic-network-architectures.

Its contents are as follows:


from typing import Union, Type, List, Tuple

import torch
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from dynamic_network_architectures.initialization.weight_init import InitWeights_He
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
import numpy as np
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp


# Identical to nnU-Net's PlainConvUNet except that every stacked conv block ends with a spatial-attention block
class SAPlainConvUNet(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
                                                  f"resolution stages. here: {n_stages}. " \
                                                  f"n_conv_per_stage: {n_conv_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
                                                                f"as we have resolution stages. here: {n_stages} " \
                                                                f"stages, so it should have {n_stages - 1} entries. " \
                                                                f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = SAPlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
                                          n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                                          dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True,
                                          nonlin_first=nonlin_first)
        self.decoder = SAUNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
                                     nonlin_first=nonlin_first)
        print('using sa unet...')

    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(skips)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or batch channel. " \
            "Do not give input_size=(b, c, x, y(, z)). Give input_size=(x, y(, z))!"
        return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)

    @staticmethod
    def initialize(module):
        InitWeights_He(1e-2)(module)


class SAPlainConvEncoder(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'):
        super().__init__()
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv directly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
                conv_stride = 1
            elif pool == 'conv':
                conv_stride = strides[s]
            else:
                raise RuntimeError()
            stage_modules.append(SAStackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first))
            stages.append(nn.Sequential(*stage_modules))
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return ret[-1]

    def compute_conv_feature_map_size(self, input_size):
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        output += self.stages[s][-1].compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output


class SAUNetDecoder(nn.Module):
    def __init__(self,
                 encoder: Union[SAPlainConvEncoder],
                 num_classes: int,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision,
                 nonlin_first: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 conv_bias: bool = None):
        """
        This class needs the skips of the encoder as input in its forward.

        the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs
        the decoder has two (three) parts in each stage:
        1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
        2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
        3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?
        :param encoder:
        :param num_classes:
        :param n_conv_per_stage:
        :param deep_supervision:
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        self.encoder = encoder
        self.num_classes = num_classes
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
                                                              "resolution stages - 1 (n_stages in encoder - 1), " \
                                                              "here: %d" % n_stages_encoder

        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
        conv_bias = encoder.conv_bias if conv_bias is None else conv_bias
        norm_op = encoder.norm_op if norm_op is None else norm_op
        norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs
        dropout_op = encoder.dropout_op if dropout_op is None else dropout_op
        dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs
        nonlin = encoder.nonlin if nonlin is None else nonlin
        nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs

        # we start with the bottleneck and work our way up
        stages = []
        transpconvs = []
        seg_layers = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=conv_bias))
            # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
            stages.append(SAStackedConvBlocks(
                n_conv_per_stage[s - 1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1,
                conv_bias,
                norm_op,
                norm_op_kwargs,
                dropout_op,
                dropout_op_kwargs,
                nonlin,
                nonlin_kwargs,
                nonlin_first))

            # we always build the deep supervision outputs so that we can always load parameters. If we don't do this
            # then a model trained with deep_supervision=True could not easily be loaded at inference time where
            # deep supervision is not needed. It's just a convenience thing
            seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        self.seg_layers = nn.ModuleList(seg_layers)

    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s + 2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision:
                seg_outputs.append(self.seg_layers[s](x))
            elif s == (len(self.stages) - 1):
                seg_outputs.append(self.seg_layers[-1](x))
            lres_input = x

        # invert seg outputs so that the largest segmentation prediction is returned first
        seg_outputs = seg_outputs[::-1]

        if not self.deep_supervision:
            r = seg_outputs[0]
        else:
            r = seg_outputs
        return r

    def compute_conv_feature_map_size(self, input_size):
        """
        IMPORTANT: input_size is the input_size of the encoder!
        :param input_size:
        :return:
        """
        # first we need to compute the skip sizes. Skip bottleneck because all output feature maps of our ops will at
        # least have the size of the skip above that (therefore -1)
        skip_sizes = []
        for s in range(len(self.encoder.strides) - 1):
            skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
            input_size = skip_sizes[-1]
        # print(skip_sizes)

        assert len(skip_sizes) == len(self.stages)

        # our ops are the other way around, so let's match things up
        output = np.int64(0)
        for s in range(len(self.stages)):
            # print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)])
            # conv blocks
            output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s + 1)])
            # trans conv
            output += np.prod([self.encoder.output_channels[-(s + 2)], *skip_sizes[-(s + 1)]], dtype=np.int64)
            # segmentation
            if self.deep_supervision or (s == (len(self.stages) - 1)):
                output += np.prod([self.num_classes, *skip_sizes[-(s + 1)]], dtype=np.int64)
        return output


class SAStackedConvBlocks(nn.Module):
    def __init__(self,
                 num_convs: int,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: Union[int, List[int], Tuple[int, ...]],
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 initial_stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False):
        """
        :param conv_op:
        :param num_convs:
        :param input_channels:
        :param output_channels: can be int or a list/tuple of int. If list/tuple are provided, each entry is for
        one conv. The length of the list/tuple must then naturally be num_convs
        :param kernel_size:
        :param initial_stride:
        :param conv_bias:
        :param norm_op:
        :param norm_op_kwargs:
        :param dropout_op:
        :param dropout_op_kwargs:
        :param nonlin:
        :param nonlin_kwargs:
        """
        super().__init__()
        if not isinstance(output_channels, (tuple, list)):
            output_channels = [output_channels] * num_convs

        # the last conv of the stack is replaced by an SA block (conv + norm followed by spatial attention)
        self.convs = nn.Sequential(
            ConvDropoutNormReLU(
                conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias, norm_op,
                norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first),
            *[
                ConvDropoutNormReLU(
                    conv_op, output_channels[i - 1], output_channels[i], kernel_size, 1, conv_bias, norm_op,
                    norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first)
                for i in range(1, num_convs - 1)
            ],
            SA(
                conv_op, output_channels[-2], output_channels[-1], kernel_size, 1, conv_bias, norm_op,
                norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first)
        )

        self.act = nonlin(**nonlin_kwargs)
        self.output_channels = output_channels[-1]
        self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)

    def forward(self, x):
        out = self.convs(x)
        out = self.act(out)
        return out

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.initial_stride), \
            "just give the image size without color/feature channels or batch channel. " \
            "Do not give input_size=(b, c, x, y(, z)). Give input_size=(x, y(, z))!"
        output = self.convs[0].compute_conv_feature_map_size(input_size)
        size_after_stride = [i // j for i, j in zip(input_size, self.initial_stride)]
        for b in self.convs[1:]:
            output += b.compute_conv_feature_map_size(size_after_stride)
        return output


class ConvDropoutNormReLU(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False):
        super(ConvDropoutNormReLU, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        if nonlin is not None:
            self.nonlin = nonlin(**nonlin_kwargs)
            ops.append(self.nonlin)

        if nonlin_first and (norm_op is not None and nonlin is not None):
            ops[-1], ops[-2] = ops[-2], ops[-1]

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), \
            "just give the image size without color/feature channels or batch channel. " \
            "Do not give input_size=(b, c, x, y(, z)). Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class ConvDropoutNorm(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False):
        super(ConvDropoutNorm, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), \
            "just give the image size without color/feature channels or batch channel. " \
            "Do not give input_size=(b, c, x, y(, z)). Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class SA(nn.Module):
    # conv + (dropout) + norm followed by SpatialAttention; used as the last block of SAStackedConvBlocks
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False):
        super(SA, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []

        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)
        self.sa = SpatialAttention(conv_op=conv_op, kernel_size=7)

    def forward(self, x):
        x = self.all_modules(x)
        # SpatialAttention already returns the attention-weighted features (x * attention map),
        # so the result must not be multiplied by x a second time
        x = self.sa(x)
        return x

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), \
            "just give the image size without color/feature channels or batch channel. " \
            "Do not give input_size=(b, c, x, y(, z)). Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, conv_op, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        elif conv_op == torch.nn.modules.conv.Conv3d:
            self.cv1 = nn.Conv3d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply spatial attention to the input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))

The idea of the modification in a nutshell: inside PlainConvUNet's StackedConvBlocks, the last convolution of each stack is replaced by an SA block, i.e. a convolution (plus norm) followed by the SpatialAttention module.
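Before launching a full nnUNet run, it can be worth instantiating the new network directly and pushing a dummy batch through it; the constructor also prints 'using sa unet...', which later confirms in the training log that the custom class is being used. The hyperparameters below are arbitrary example values, not the ones nnUNet derives from your dataset (nnUNet itself builds the network with deep supervision enabled during training):

import torch
from torch import nn
from dynamic_network_architectures.architectures.saunet import SAPlainConvUNet

net = SAPlainConvUNet(
    input_channels=1,
    n_stages=4,
    features_per_stage=(32, 64, 128, 256),
    conv_op=nn.Conv2d,
    kernel_sizes=3,
    strides=(1, 2, 2, 2),
    n_conv_per_stage=2,
    num_classes=3,
    n_conv_per_stage_decoder=2,
    conv_bias=True,
    norm_op=nn.InstanceNorm2d,
    norm_op_kwargs={'eps': 1e-5, 'affine': True},
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={'inplace': True},
    deep_supervision=False,
)
x = torch.randn(2, 1, 64, 64)   # (batch, channels, H, W)
print(net(x).shape)              # expected: torch.Size([2, 3, 64, 64])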

2. Modifying the configuration file

With the model code in place, we again use the Task04_Hippocampus dataset from the previous tutorial for verification (if you skipped that tutorial, run the data preparation yourself). Edit the plans file nnUNet\nnUNet_preprocessed\Dataset004_Hippocampus\nnUNetPlans.json and change network_class_name to dynamic_network_architectures.architectures.saunet.SAPlainConvUNet.
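You can make the edit by hand or script it as sketched below. This assumes a plans layout in which each configuration carries an "architecture" block containing network_class_name (as in recent nnU-Net versions); adjust the key path if your version stores network_class_name elsewhere, and leave all other generated values unchanged.

import json

plans_path = r"nnUNet_preprocessed/Dataset004_Hippocampus/nnUNetPlans.json"
with open(plans_path) as f:
    plans = json.load(f)

# point both configurations at the custom network class
for cfg in ("2d", "3d_fullres"):
    plans["configurations"][cfg]["architecture"]["network_class_name"] = \
        "dynamic_network_architectures.architectures.saunet.SAPlainConvUNet"

with open(plans_path, "w") as f:
    json.dump(plans, f, indent=4)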

III. Model training

With the model and the dataset's plans file modified, training can begin. The dataset is still Task04_Hippocampus, and the code above supports both 2d and 3d models, so the following training commands can be used:

nnUNetv2_train 4 2d 0  
nnUNetv2_train 4 2d 1 
nnUNetv2_train 4 2d 2  
nnUNetv2_train 4 2d 3 
nnUNetv2_train 4 2d 4
nnUNetv2_train 4 3d_fullres 0
nnUNetv2_train 4 3d_fullres 1
nnUNetv2_train 4 3d_fullres 2 
nnUNetv2_train 4 3d_fullres 3 
nnUNetv2_train 4 3d_fullres 4 

As the training log shows, the modified model starts up and trains successfully.

Because nnUNet training takes a very long time and compute was limited, the full training runs were not completed; this post only covers the code changes and verifying that training runs.
