import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function, Variable
from torch.nn import Module, parameter
import warnings

try:
    from queue import Queue
except ImportError:  # Python 2 fallback
    from Queue import Queue

# from torch.nn.modules.batchnorm import _BatchNorm
from functools import partial

from timm.layers import DropPath, trunc_normal_
# from timm import register_model
# from timm.layers.helpers import to_2tuple
# from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


# LVC
class Encoding(nn.Module):
    """Learnable visual-codeword encoding layer (EncNet-style).

    Soft-assigns every spatial feature vector of a [B, C, H, W] input to a
    set of learnable codewords and aggregates the scale-weighted residuals.

    Args:
        in_channels (int): channels C of the input feature map.
        num_codes (int): number of learnable codewords K.

    Forward input:  [B, C, H, W] with C == in_channels.
    Forward output: [B, num_codes, C] aggregated residual features.
    """

    def __init__(self, in_channels, num_codes):
        super().__init__()
        self.in_channels, self.num_codes = in_channels, num_codes
        # BUG FIX: the original reassigned the local ``num_codes = 64`` here,
        # silently ignoring the constructor argument when sizing the
        # parameters (while self.num_codes kept the passed value). The
        # codewords now honour the argument; default usage (num_codes=64)
        # is unchanged.
        std = 1. / ((num_codes * in_channels) ** 0.5)
        # codewords: [num_codes, in_channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, in_channels, dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # scale: [num_codes] — per-codeword smoothing factor, init in (-1, 0)
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Return scaled squared L2 distances between features and codewords.

        Args:
            x: [b, N, in_channels] flattened spatial features.
            codewords: [num_codes, in_channels].
            scale: [num_codes].

        Returns:
            [b, N, num_codes] tensor r_ik = s_k * ||x_i - c_k||^2.
        """
        num_codes, in_channels = codewords.size()
        b = x.size(0)
        # expand features to [b, N, num_codes, in_channels]
        expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))
        # broadcast the codebook over batch and spatial dims
        reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))
        # broadcast scale to [1, 1, num_codes]
        reshaped_scale = scale.view((1, 1, num_codes))
        # r_ik = s_k * ||x_i - c_k||^2  ->  [b, N, num_codes]
        scaled_l2_norm = reshaped_scale * (expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        """Aggregate assignment-weighted residuals over the spatial dim.

        Args:
            assignment_weights: [b, N, num_codes] softmax assignments.
            x: [b, N, in_channels] flattened spatial features.
            codewords: [num_codes, in_channels].

        Returns:
            [b, num_codes, in_channels] tensor e_k = sum_i a_ik * (x_i - c_k).
        """
        num_codes, in_channels = codewords.size()
        # broadcast the codebook over batch and spatial dims
        reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))
        b = x.size(0)
        # expand features to [b, N, num_codes, in_channels]
        expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))
        # [b, N, num_codes, 1] so weights broadcast across channels
        assignment_weights = assignment_weights.unsqueeze(3)
        # e_k must be computed after the assignments: sum over the N dim
        encoded_feat = (assignment_weights * (expanded_x - reshaped_codewords)).sum(1)
        return encoded_feat

    def forward(self, x):
        assert x.dim() == 4 and x.size(1) == self.in_channels
        b, in_channels, w, h = x.size()
        # flatten to [batch_size, height x width, channels]
        x = x.view(b, self.in_channels, -1).transpose(1, 2).contiguous()
        # assignment_weights: [batch_size, N, num_codes]
        # (the original comment said [batch_size, channels, num_codes] — wrong)
        assignment_weights = F.softmax(self.scaled_l2(x, self.codewords, self.scale), dim=2)
        # aggregate residuals -> [batch_size, num_codes, channels]
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat


# 1*1 3*3 1*1
class ConvBlock(nn.Module):
    """ResNet-style bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand + residual.

    Args:
        in_channels (int): input channels.
        out_channels (int): output channels; the 3x3 conv runs at
            out_channels // 4 (expansion = 4).
        stride (int): stride of the 3x3 conv.
        res_conv (bool): if True, project the residual with a 1x1 conv + norm
            (needed when shape/stride changes).
        act_layer: activation class (default nn.ReLU).
        groups (int): groups for the 3x3 conv.
        norm_layer: normalization factory (default BatchNorm2d, eps=1e-6).
        drop_block: optional regularizer applied after bn1/bn2/bn3.
        drop_path: optional stochastic-depth module applied before the add.
    """

    def __init__(self, in_channels, out_channels, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
                 norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        expansion = 4
        c = out_channels // expansion

        self.conv1 = nn.Conv2d(in_channels, c, kernel_size=1, stride=1, padding=0, bias=False)  # [64, 256, 1, 1]
        self.bn1 = norm_layer(c)
        self.act1 = act_layer(inplace=True)

        self.conv2 = nn.Conv2d(c, c, kernel_size=3, stride=stride, groups=groups, padding=1, bias=False)
        self.bn2 = norm_layer(c)
        self.act2 = act_layer(inplace=True)

        self.conv3 = nn.Conv2d(c, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = norm_layer(out_channels)
        self.act3 = act_layer(inplace=True)

        if res_conv:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
            self.residual_bn = norm_layer(out_channels)

        self.res_conv = res_conv
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        # Zero-init the last BN so the block starts as identity (common
        # ResNet training trick).
        nn.init.zeros_(self.bn3.weight)

    def forward(self, x, return_x_2=True):
        """Run the bottleneck.

        Args:
            x: [B, in_channels, H, W].
            return_x_2: if True, also return the intermediate activation
                after the 3x3 stage (used by callers that tap mid features).

        Returns:
            (out, x2) if return_x_2 else out.
        """
        residual = x

        x = self.conv1(x)
        x = self.bn1(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)

        x = self.conv2(x)  # if x_t_r is None else self.conv2(x + x_t_r)
        x = self.bn2(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x2 = self.act2(x)

        x = self.conv3(x2)
        x = self.bn3(x)
        if self.drop_block is not None:
            x = self.drop_block(x)

        if self.drop_path is not None:
            x = self.drop_path(x)

        if self.res_conv:
            residual = self.residual_conv(residual)
            residual = self.residual_bn(residual)

        x += residual
        x = self.act3(x)

        if return_x_2:
            return x, x2
        else:
            return x


class Mean(nn.Module):
    """Reduce the input by taking the mean over ``dim``."""

    def __init__(self, dim, keep_dim=False):
        super(Mean, self).__init__()
        self.dim = dim
        self.keep_dim = keep_dim

    def forward(self, input):
        return input.mean(self.dim, self.keep_dim)


class Mlp(nn.Module):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for conv weights, zero bias.
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LayerNormChannel(nn.Module):
    """
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, eps=1e-05):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x):
        # Normalize across the channel dim only, then affine-transform.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
            + self.bias.unsqueeze(-1).unsqueeze(-1)
        return x


class GroupNorm(nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)