【nnUNetv2 Advanced】VI. Modifying the nnUNetv2 Network - A Warm-Up Exercise - Adding the CBAM Attention Mechanism

nnUNet is a self-configuring deep learning framework designed specifically for medical image segmentation. Its main characteristics are:

Self-configuring framework: nnUNet automatically adapts the model architecture and training parameters to each segmentation task, avoiding tedious manual hyperparameter tuning.
Automated pipeline: nnUNet covers the full workflow from data preprocessing to training, validation, and testing, which greatly reduces the effort of applying deep learning to medical image segmentation.
Adaptive network configuration: based on the properties of the input dataset, nnUNet automatically chooses a suitable network depth, width, and other hyperparameters, balancing model complexity and performance.
Patch-based training and inference: nnUNet trains on patches and traverses the whole image with a sliding window; inference uses the same sliding-window scheme to assemble the full-image segmentation. This is effective for large images and limited GPU memory (a minimal sliding-window sketch follows this overview).
Ensembling and cross-validation: nnUNet uses cross-validation to make the most of limited data and ensembles the resulting models to improve the stability and accuracy of its predictions.
In addition, nnUNet ships with extensive documentation and examples. To use it, install Python and the required deep learning framework (PyTorch), then follow the steps in the official documentation.

In short, nnUNet is a powerful, easy-to-use framework that is particularly well suited to medical image segmentation. Its self-configuring behavior, automated pipeline, and training strategies let users build and train models efficiently while achieving strong performance.
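To make the sliding-window idea concrete, here is a minimal, hypothetical sketch of patch-based 2D inference. It is not nnUNet's actual implementation (which adds Gaussian weighting, test-time mirroring, and proper border handling); the function name and parameters are illustrative only.

import torch

def sliding_window_predict(model, image, num_classes, patch_size=(128, 128), step=(64, 64)):
    # image: (1, C, H, W); assumes H and W are at least patch_size and, purely
    # for brevity, ignores borders that do not align with the step
    _, _, H, W = image.shape
    logits = torch.zeros((1, num_classes, H, W))
    counts = torch.zeros((1, 1, H, W))
    for y in range(0, H - patch_size[0] + 1, step[0]):
        for x in range(0, W - patch_size[1] + 1, step[1]):
            patch = image[:, :, y:y + patch_size[0], x:x + patch_size[1]]
            with torch.no_grad():
                logits[:, :, y:y + patch_size[0], x:x + patch_size[1]] += model(patch)
            counts[:, :, y:y + patch_size[0], x:x + patch_size[1]] += 1
    return logits / counts.clamp_min(1)  # average overlapping patch predictions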

Previous posts covered installing nnUNet, using it, and defining a custom network. This post shows how to add CBAM to nnUNet. Before reading on, please make sure you have worked through the following:

【nnUNetv2 Practice】I. Installing nnUNetv2

【nnUNetv2 Practice】II. nnUNetv2 Quick Start - Training, Validation, Inference, and Ensembling in One Pass

【nnUNetv2 Advanced】III. nnUNetv2 Custom Networks - A Must for Publishing Papers - CSDN Blog

The earlier ChannelAttention and SpatialAttention modifications are described in:

【nnUNetv2 Advanced】IV. Modifying the nnUNetv2 Network - A Warm-Up Exercise - Adding the ChannelAttention Mechanism - CSDN Blog

【nnUNetv2 Advanced】V. Modifying the nnUNetv2 Network - A Warm-Up Exercise - Adding the SpatialAttention Mechanism - CSDN Blog

This post adds CBAM to nnUNet. CBAM is a very simple attention mechanism and is well suited as a warm-up exercise for modifying the network; more advanced modification tutorials will follow.

I. CBAM

CBAM has a simple structure: it combines Channel Attention and Spatial Attention. In the original paper the two modules are applied in series; in this post they are fused in parallel (see the note below). The structure figure from the original post is not reproduced here.

Its 2D reference implementation is as follows:

import torch
from torch import nn


class ChannelAttention(nn.Module):
    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    def __init__(self, c1, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))
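A quick shape check (the input tensor below is a made-up example; attention only rescales features, so the output shape matches the input):

x = torch.randn(2, 64, 32, 32)   # dummy (batch, channels, H, W) feature map
cbam = CBAM(c1=64)
print(cbam(x).shape)             # torch.Size([2, 64, 32, 32])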

Note: in my tests, applying CA and SA in series made the nnUNet training loss go to NaN, so the serial arrangement was changed to a parallel one.
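For clarity, here is a minimal 2D sketch of the parallel fusion (reusing the ChannelAttention and SpatialAttention classes above); the 3D-capable CBAM block used inside nnUNet below differs slightly in how the attended features are combined:

class ParallelCBAM(nn.Module):
    # channel and spatial attention are applied to the same input and their
    # outputs are summed, instead of being chained one after the other
    def __init__(self, c1, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.channel_attention(x) + self.spatial_attention(x)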

II. Adding CBAM to nnUNet

As mentioned in the previous tutorials, nnUNet's networks live in dynamic-network-architectures; to train your own network you modify the code there and then point the dataset's plans file at the new class.

1. Network architecture changes

Create a new file cbamunet.py in the architectures directory of dynamic-network-architectures (screenshot not reproduced here):

The file contents are as follows (CA and SA are fused in parallel, see the code):


from typing import Union, Type, List, Tuple

import torch
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
import numpy as np

from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from dynamic_network_architectures.building_blocks.helper import maybe_convert_scalar_to_list, get_matching_pool_op
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp
from dynamic_network_architectures.initialization.weight_init import InitWeights_He


class CBAMPlainConvUNet(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 nonlin_first: bool = False
                 ):
        """
        nonlin_first: if True you get conv -> nonlin -> norm. Else it's conv -> norm -> nonlin
        """
        super().__init__()
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have " \
                                                  f"resolution stages. here: {n_stages}. " \
                                                  f"n_conv_per_stage: {n_conv_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
                                                                f"as we have resolution stages. here: {n_stages} " \
                                                                f"stages, so it should have {n_stages - 1} entries. " \
                                                                f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = CBAMPlainConvEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
                                            n_conv_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                                            dropout_op_kwargs, nonlin, nonlin_kwargs, return_skips=True,
                                            nonlin_first=nonlin_first)
        self.decoder = CBAMUNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision,
                                       nonlin_first=nonlin_first)
        print('using cbam unet...')

    def forward(self, x):
        skips = self.encoder(x)
        return self.decoder(skips)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
                                                                                "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                                                "Give input_size=(x, y(, z))!"
        return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)

    @staticmethod
    def initialize(module):
        InitWeights_He(1e-2)(module)


class CBAMPlainConvEncoder(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_conv_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 return_skips: bool = False,
                 nonlin_first: bool = False,
                 pool: str = 'conv'
                 ):
        super().__init__()
        if isinstance(kernel_sizes, int):
            kernel_sizes = [kernel_sizes] * n_stages
        if isinstance(features_per_stage, int):
            features_per_stage = [features_per_stage] * n_stages
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * n_stages
        if isinstance(strides, int):
            strides = [strides] * n_stages
        assert len(kernel_sizes) == n_stages, "kernel_sizes must have as many entries as we have resolution stages (n_stages)"
        assert len(n_conv_per_stage) == n_stages, "n_conv_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(features_per_stage) == n_stages, "features_per_stage must have as many entries as we have resolution stages (n_stages)"
        assert len(strides) == n_stages, "strides must have as many entries as we have resolution stages (n_stages). " \
                                         "Important: first entry is recommended to be 1, else we run strided conv drectly on the input"

        stages = []
        for s in range(n_stages):
            stage_modules = []
            if pool == 'max' or pool == 'avg':
                if (isinstance(strides[s], int) and strides[s] != 1) or \
                        isinstance(strides[s], (tuple, list)) and any([i != 1 for i in strides[s]]):
                    stage_modules.append(get_matching_pool_op(conv_op, pool_type=pool)(kernel_size=strides[s], stride=strides[s]))
                conv_stride = 1
            elif pool == 'conv':
                conv_stride = strides[s]
            else:
                raise RuntimeError()
            stage_modules.append(CBAMStackedConvBlocks(
                n_conv_per_stage[s], conv_op, input_channels, features_per_stage[s], kernel_sizes[s], conv_stride,
                conv_bias, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ))
            stages.append(nn.Sequential(*stage_modules))
            input_channels = features_per_stage[s]

        self.stages = nn.Sequential(*stages)
        self.output_channels = features_per_stage
        self.strides = [maybe_convert_scalar_to_list(conv_op, i) for i in strides]
        self.return_skips = return_skips

        # we store some things that a potential decoder needs
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.conv_bias = conv_bias
        self.kernel_sizes = kernel_sizes

    def forward(self, x):
        ret = []
        for s in self.stages:
            x = s(x)
            ret.append(x)
        if self.return_skips:
            return ret
        else:
            return ret[-1]

    def compute_conv_feature_map_size(self, input_size):
        output = np.int64(0)
        for s in range(len(self.stages)):
            if isinstance(self.stages[s], nn.Sequential):
                for sq in self.stages[s]:
                    if hasattr(sq, 'compute_conv_feature_map_size'):
                        output += self.stages[s][-1].compute_conv_feature_map_size(input_size)
            else:
                output += self.stages[s].compute_conv_feature_map_size(input_size)
            input_size = [i // j for i, j in zip(input_size, self.strides[s])]
        return output


class CBAMUNetDecoder(nn.Module):
    def __init__(self,
                 encoder: Union[CBAMPlainConvEncoder],
                 num_classes: int,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision,
                 nonlin_first: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 conv_bias: bool = None
                 ):
        """
        This class needs the skips of the encoder as input in its forward.

        the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
        are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
        features and the lowest skip as inputs
        the decoder has two (three) parts in each stage:
        1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
        2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
        3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?
        :param encoder:
        :param num_classes:
        :param n_conv_per_stage:
        :param deep_supervision:
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        self.encoder = encoder
        self.num_classes = num_classes
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)

        assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
                                                              "resolution stages - 1 (n_stages in encoder - 1), " \
                                                              "here: %d" % n_stages_encoder

        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
        conv_bias = encoder.conv_bias if conv_bias is None else conv_bias
        norm_op = encoder.norm_op if norm_op is None else norm_op
        norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs
        dropout_op = encoder.dropout_op if dropout_op is None else dropout_op
        dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs
        nonlin = encoder.nonlin if nonlin is None else nonlin
        nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs

        # we start with the bottleneck and work out way up
        stages = []
        transpconvs = []
        seg_layers = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=conv_bias
            ))
            # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
            stages.append(CBAMStackedConvBlocks(
                n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1,
                conv_bias,
                norm_op,
                norm_op_kwargs,
                dropout_op,
                dropout_op_kwargs,
                nonlin,
                nonlin_kwargs,
                nonlin_first
            ))

            # we always build the deep supervision outputs so that we can always load parameters. If we don't do this
            # then a model trained with deep_supervision=True could not easily be loaded at inference time where
            # deep supervision is not needed. It's just a convenience thing
            seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        self.seg_layers = nn.ModuleList(seg_layers)

    def forward(self, skips):
        """
        we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s+2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision:
                seg_outputs.append(self.seg_layers[s](x))
            elif s == (len(self.stages) - 1):
                seg_outputs.append(self.seg_layers[-1](x))
            lres_input = x

        # invert seg outputs so that the largest segmentation prediction is returned first
        seg_outputs = seg_outputs[::-1]
        if not self.deep_supervision:
            r = seg_outputs[0]
        else:
            r = seg_outputs
        return r

    def compute_conv_feature_map_size(self, input_size):
        """
        IMPORTANT: input_size is the input_size of the encoder!
        :param input_size:
        :return:
        """
        # first we need to compute the skip sizes. Skip bottleneck because all output feature maps of our ops will at
        # least have the size of the skip above that (therefore -1)
        skip_sizes = []
        for s in range(len(self.encoder.strides) - 1):
            skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
            input_size = skip_sizes[-1]
        # print(skip_sizes)

        assert len(skip_sizes) == len(self.stages)

        # our ops are the other way around, so let's match things up
        output = np.int64(0)
        for s in range(len(self.stages)):
            # print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)])
            # conv blocks
            output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)])
            # trans conv
            output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64)
            # segmentation
            if self.deep_supervision or (s == (len(self.stages) - 1)):
                output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64)
        return output


class CBAMStackedConvBlocks(nn.Module):
    def __init__(self,
                 num_convs: int,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: Union[int, List[int], Tuple[int, ...]],
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 initial_stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        """
        :param conv_op:
        :param num_convs:
        :param input_channels:
        :param output_channels: can be int or a list/tuple of int. If list/tuple are provided, each entry is for
        one conv. The length of the list/tuple must then naturally be num_convs
        :param kernel_size:
        :param initial_stride:
        :param conv_bias:
        :param norm_op:
        :param norm_op_kwargs:
        :param dropout_op:
        :param dropout_op_kwargs:
        :param nonlin:
        :param nonlin_kwargs:
        """
        super().__init__()
        if not isinstance(output_channels, (tuple, list)):
            output_channels = [output_channels] * num_convs

        self.convs = nn.Sequential(
            ConvDropoutNormReLU(
                conv_op, input_channels, output_channels[0], kernel_size, initial_stride, conv_bias, norm_op,
                norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            ),
            *[
                ConvDropoutNormReLU(
                    conv_op, output_channels[i - 1], output_channels[i], kernel_size, 1, conv_bias, norm_op,
                    norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
                )
                for i in range(1, num_convs - 1)
            ],
            CBAM(
                conv_op, output_channels[-2], output_channels[-1], kernel_size, 1, conv_bias, norm_op,
                norm_op_kwargs, dropout_op, dropout_op_kwargs, nonlin, nonlin_kwargs, nonlin_first
            )
        )

        self.act = nonlin(**nonlin_kwargs)
        self.output_channels = output_channels[-1]
        self.initial_stride = maybe_convert_scalar_to_list(conv_op, initial_stride)

    def forward(self, x):
        out = self.convs(x)
        out = self.act(out)
        return out

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.initial_stride), "just give the image size without color/feature channels or " \
                                                            "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                            "Give input_size=(x, y(, z))!"
        output = self.convs[0].compute_conv_feature_map_size(input_size)
        size_after_stride = [i // j for i, j in zip(input_size, self.initial_stride)]
        for b in self.convs[1:]:
            output += b.compute_conv_feature_map_size(size_after_stride)
        return output


class ConvDropoutNormReLU(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNormReLU, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []
        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        if nonlin is not None:
            self.nonlin = nonlin(**nonlin_kwargs)
            ops.append(self.nonlin)

        if nonlin_first and (norm_op is not None and nonlin is not None):
            ops[-1], ops[-2] = ops[-2], ops[-1]

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class ConvDropoutNorm(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(ConvDropoutNorm, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []
        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)

    def forward(self, x):
        return self.all_modules(x)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class CBAM(nn.Module):
    def __init__(self,
                 conv_op: Type[_ConvNd],
                 input_channels: int,
                 output_channels: int,
                 kernel_size: Union[int, List[int], Tuple[int, ...]],
                 stride: Union[int, List[int], Tuple[int, ...]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 nonlin_first: bool = False
                 ):
        super(CBAM, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        stride = maybe_convert_scalar_to_list(conv_op, stride)
        self.stride = stride
        kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size)
        if norm_op_kwargs is None:
            norm_op_kwargs = {}
        if nonlin_kwargs is None:
            nonlin_kwargs = {}

        ops = []
        self.conv = conv_op(
            input_channels,
            output_channels,
            kernel_size,
            stride,
            padding=[(i - 1) // 2 for i in kernel_size],
            dilation=1,
            bias=conv_bias,
        )
        ops.append(self.conv)

        if dropout_op is not None:
            self.dropout = dropout_op(**dropout_op_kwargs)
            ops.append(self.dropout)

        if norm_op is not None:
            self.norm = norm_op(output_channels, **norm_op_kwargs)
            ops.append(self.norm)

        self.all_modules = nn.Sequential(*ops)

        self.ca = ChannelAttention(conv_op=conv_op, channels=output_channels)
        self.sa = SpatialAttention(conv_op=conv_op, kernel_size=5)

    def forward(self, x):
        x = self.all_modules(x)
        ca = self.ca(x) * x
        sa = self.sa(x) * x
        return ca + sa

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == len(self.stride), "just give the image size without color/feature channels or " \
                                                    "batch channel. Do not give input_size=(b, c, x, y(, z)). " \
                                                    "Give input_size=(x, y(, z))!"
        output_size = [i // j for i, j in zip(input_size, self.stride)]  # we always do same padding
        return np.prod([self.output_channels, *output_size], dtype=np.int64)


class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, conv_op, channels: int) -> None:
        """Initializes the class and sets the basic configurations and instance variables required."""
        super().__init__()
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        elif conv_op == torch.nn.modules.conv.Conv3d:
            self.pool = nn.AdaptiveAvgPool3d(1)
            self.fc = nn.Conv3d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies forward pass using activation on convolutions of the input, optionally using batch normalization."""
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, conv_op, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 5, 7), "kernel size must be 3 or 7"
        if kernel_size in [5, 7]:
            padding = kernel_size // 2
        else:
            padding = 1
        if conv_op == torch.nn.modules.conv.Conv2d:
            self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        elif conv_op == torch.nn.modules.conv.Conv3d:
            self.cv1 = nn.Conv3d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))

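Before touching the plans file, it is worth checking that the new network builds and a forward pass runs. The snippet below is a small smoke test with made-up toy hyperparameters (nnUNet will plan its own values for Hippocampus); it assumes the classes above are importable from cbamunet.py.

import torch
from torch import nn

# toy configuration: 4 stages, 2D convs, 1 input channel, 3 classes
net = CBAMPlainConvUNet(
    input_channels=1,
    n_stages=4,
    features_per_stage=(32, 64, 128, 256),
    conv_op=nn.Conv2d,
    kernel_sizes=3,
    strides=(1, 2, 2, 2),
    n_conv_per_stage=2,
    num_classes=3,
    n_conv_per_stage_decoder=2,
    conv_bias=True,
    norm_op=nn.InstanceNorm2d,
    norm_op_kwargs={'eps': 1e-5, 'affine': True},
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={'inplace': True},
    deep_supervision=True,
)
with torch.no_grad():
    outputs = net(torch.randn(2, 1, 64, 64))
print([o.shape for o in outputs])  # one prediction per decoder resolution (deep supervision)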
2. Plans file changes

After modifying the model, we again use the Task04_Hippocampus dataset from the previous tutorial for verification (if you have not worked through that tutorial, prepare the data yourself). Edit the plans file nnUNet\nnUNet_preprocessed\Dataset004_Hippocampus\nnUNetPlans.json and change network_class_name to dynamic_network_architectures.architectures.cbamunet.CBAMPlainConvUNet (screenshot not reproduced here).
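If you prefer to patch the plans file from a script instead of editing it by hand, a minimal sketch along these lines works. The exact location of network_class_name inside nnUNetPlans.json depends on the nnUNet version, so this simply rewrites the key wherever it occurs; the path below assumes the standard nnUNet_preprocessed layout and should be adjusted to your environment.

import json
from pathlib import Path

plans_path = Path("nnUNet_preprocessed/Dataset004_Hippocampus/nnUNetPlans.json")  # adjust to your setup
new_class = "dynamic_network_architectures.architectures.cbamunet.CBAMPlainConvUNet"

def set_network_class(node):
    # recursively rewrite every 'network_class_name' entry in the plans dict
    if isinstance(node, dict):
        for key, value in node.items():
            if key == "network_class_name":
                node[key] = new_class
            else:
                set_network_class(value)
    elif isinstance(node, list):
        for item in node:
            set_network_class(item)

plans = json.loads(plans_path.read_text())
set_network_class(plans)
plans_path.write_text(json.dumps(plans, indent=4))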

III. Model training

With the model code and the plans file modified, we can start training. The dataset is still Task04_Hippocampus. The code above supports both 2d and 3d models; the 3d_fullres configuration can be trained with the following commands (replace 3d_fullres with 2d for the 2D configuration):

nnUNetv2_train 4 3d_fullres 0 
nnUNetv2_train 4 3d_fullres 1
nnUNetv2_train 4 3d_fullres 2 
nnUNetv2_train 4 3d_fullres 3 
nnUNetv2_train 4 3d_fullres 4 

Because nnUNet training takes a long time and my compute resources are limited, I did not run the full training; I only verified that the modified code runs end to end.
