import warnings
from functools import partial

try:
    from queue import Queue
except ImportError:
    from Queue import Queue

import torch
from torch import nn
from torch.autograd import Function, Variable
from torch.nn import Module, parameter
from torch.nn import functional as F

from timm.layers import DropPath, trunc_normal_

# from torch.nn.modules.batchnorm import _BatchNorm
# from timm import register_model
# from timm.layers.helpers import to_2tuple
# from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# LVC
class Encoding(nn.Module):
    """Soft-assignment encoding of local features onto a learnable codebook.

    Each spatial position of an NCHW feature map is treated as a feature
    vector, softly assigned to ``num_codes`` learnable codewords, and the
    weighted residuals are aggregated over all positions.

    Args:
        in_channels: channel dimension of the input feature map.
        num_codes: number of learnable codewords.
    """

    def __init__(self, in_channels, num_codes):
        super(Encoding, self).__init__()
        # init codewords and smoothing factor
        self.in_channels, self.num_codes = in_channels, num_codes
        # BUG FIX: the original code overwrote `num_codes` with a hard-coded
        # 64 here, silently ignoring the constructor argument and leaving
        # self.num_codes inconsistent with the actual codebook size.
        std = 1. / ((num_codes * in_channels) ** 0.5)
        # codewords: [num_codes, in_channels]
        self.codewords = nn.Parameter(
            torch.empty(num_codes, in_channels, dtype=torch.float).uniform_(-std, std),
            requires_grad=True)
        # per-codeword smoothing factor, initialized in (-1, 0): [num_codes]
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x, codewords, scale):
        """Scaled squared L2 distance between features and codewords.

        Args:
            x: [batch, N, in_channels] flattened features.
            codewords: [num_codes, in_channels] codebook.
            scale: [num_codes] per-codeword smoothing factors.

        Returns:
            [batch, N, num_codes] scaled squared distances.
        """
        num_codes, in_channels = codewords.size()
        b = x.size(0)
        # [b, N, num_codes, in_channels]
        expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))

        # broadcastable codebook: [1, 1, num_codes, in_channels]
        reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))

        # broadcastable scale: [1, 1, num_codes]
        reshaped_scale = scale.view((1, 1, num_codes))

        # r_ik = s_k * ||x_i - c_k||^2  -> [b, N, num_codes]
        scaled_l2_norm = reshaped_scale * (expanded_x - reshaped_codewords).pow(2).sum(dim=3)
        return scaled_l2_norm

    @staticmethod
    def aggregate(assignment_weights, x, codewords):
        """Aggregate assignment-weighted residuals over all positions.

        Args:
            assignment_weights: [batch, N, num_codes] softmax assignments.
            x: [batch, N, in_channels] flattened features.
            codewords: [num_codes, in_channels] codebook.

        Returns:
            [batch, num_codes, in_channels] encoded features.
        """
        num_codes, in_channels = codewords.size()

        # broadcastable codebook: [1, 1, num_codes, in_channels]
        reshaped_codewords = codewords.view((1, 1, num_codes, in_channels))
        b = x.size(0)

        # [b, N, num_codes, in_channels]
        expanded_x = x.unsqueeze(2).expand((b, x.size(1), num_codes, in_channels))

        # add trailing singleton dim so weights broadcast over channels
        assignment_weights = assignment_weights.unsqueeze(3)  # [b, N, num_codes, 1]

        # e_k = sum_i a_ik * (x_i - c_k), summed over positions (dim 1)
        encoded_feat = (assignment_weights * (expanded_x - reshaped_codewords)).sum(1)
        return encoded_feat

    def forward(self, x):
        """Encode an NCHW feature map into [batch, num_codes, in_channels]."""
        assert x.dim() == 4 and x.size(1) == self.in_channels
        b = x.size(0)

        # flatten spatial dims: [b, C, H, W] -> [b, H*W, C]
        x = x.view(b, self.in_channels, -1).transpose(1, 2).contiguous()

        # soft-assignment over codewords: [b, H*W, num_codes]
        assignment_weights = F.softmax(self.scaled_l2(x, self.codewords, self.scale), dim=2)

        # aggregate residuals: [b, num_codes, in_channels]
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat

# bottleneck conv block: 1x1 reduce -> 3x3 -> 1x1 expand
class ConvBlock(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The hidden width is ``out_channels // 4``. When ``res_conv`` is True a
    1x1 projection shortcut maps the input to ``out_channels``; otherwise
    the raw input is added, so the caller must ensure shapes match.

    ``forward`` returns ``(out, x2)`` by default, where ``x2`` is the
    activated output of the 3x3 stage; pass ``return_x_2=False`` for the
    final output alone.
    """

    def __init__(self, in_channels, out_channels, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1,
                 norm_layer=partial(nn.BatchNorm2d, eps=1e-6), drop_block=None, drop_path=None):
        super(ConvBlock, self).__init__()
        self.in_channels = in_channels
        hidden = out_channels // 4  # bottleneck width (expansion factor 4)

        # stage 1: 1x1 channel reduction
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = norm_layer(hidden)
        self.act1 = act_layer(inplace=True)

        # stage 2: 3x3 spatial conv (carries the stride and grouping)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, groups=groups, padding=1, bias=False)
        self.bn2 = norm_layer(hidden)
        self.act2 = act_layer(inplace=True)

        # stage 3: 1x1 channel expansion
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = norm_layer(out_channels)
        self.act3 = act_layer(inplace=True)

        if res_conv:
            # projection shortcut for the channel mismatch
            self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
            self.residual_bn = norm_layer(out_channels)

        self.res_conv = res_conv
        self.drop_block = drop_block
        self.drop_path = drop_path

    def zero_init_last_bn(self):
        """Zero the last norm's scale so the block starts as identity + shortcut."""
        nn.init.zeros_(self.bn3.weight)

    def _maybe_drop(self, t):
        # apply the optional DropBlock regularizer, if configured
        return t if self.drop_block is None else self.drop_block(t)

    def forward(self, x, return_x_2=True):
        shortcut = x

        out = self.act1(self._maybe_drop(self.bn1(self.conv1(x))))
        x2 = self.act2(self._maybe_drop(self.bn2(self.conv2(out))))
        out = self._maybe_drop(self.bn3(self.conv3(x2)))

        if self.drop_path is not None:
            out = self.drop_path(out)

        if self.res_conv:
            shortcut = self.residual_bn(self.residual_conv(shortcut))

        out = self.act3(out + shortcut)

        return (out, x2) if return_x_2 else out
class Mean(nn.Module):
    """Module wrapper around ``torch.mean`` over a fixed dimension.

    Useful wherever an ``nn.Module`` is required (e.g. inside
    ``nn.Sequential``) but the operation is just a mean reduction.
    """

    def __init__(self, dim, keep_dim=False):
        super(Mean, self).__init__()
        self.dim = dim            # dimension to reduce over
        self.keep_dim = keep_dim  # retain the reduced dim as size 1 if True

    def forward(self, input):
        """Return the mean of ``input`` along ``self.dim``."""
        return torch.mean(input, self.dim, self.keep_dim)
class Mlp(nn.Module):
    """
    Two-layer MLP implemented with 1*1 convolutions. Input: tensor with
    shape [B, C, H, W].

    Pipeline: fc1 (1x1 conv) -> activation -> dropout -> fc2 (1x1 conv)
    -> dropout. Hidden and output widths default to ``in_features``.
    """
    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for conv weights, zeros for biases
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the input through each stage of the MLP in order."""
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
class LayerNormChannel(nn.Module):
    """
    LayerNorm only for Channel Dimension.
    Input: tensor in shape [B, C, H, W]

    Normalizes each spatial position across channels using the biased
    variance, then applies a learnable per-channel scale and shift.
    """
    def __init__(self, num_channels, eps=1e-05):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))  # per-channel scale
        self.bias = nn.Parameter(torch.zeros(num_channels))   # per-channel shift
        self.eps = eps  # numerical-stability constant added to the variance

    def forward(self, x):
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        gamma = self.weight.view(1, -1, 1, 1)
        beta = self.bias.view(1, -1, 1, 1)
        return gamma * normed + beta
class GroupNorm(nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]

    With a single group this normalizes jointly over all channels and
    spatial positions of each sample, while keeping GroupNorm's learnable
    per-channel affine parameters.
    """
    def __init__(self, num_channels, **kwargs):
        # one group == normalize across all channels together
        super().__init__(1, num_channels, **kwargs)