import torch
import torch.nn as nn
import math
# Paper: SUnet: A multi-organ segmentation network based on multiple attention
# Paper link: https://www.sciencedirect.com/science/article/abs/pii/S0010482523010612
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
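# SpatialAttention only produces a [B, 1, H, W] attention map; the caller is
# expected to multiply it with its own features, as EFF.forward does below.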
class Efficient_Attention_Gate(nn.Module):
    def __init__(self, F_g, F_l, F_int, num_groups=32):
        super(Efficient_Attention_Gate, self).__init__()
        self.num_groups = num_groups
        self.grouped_conv_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True, groups=num_groups),
            nn.BatchNorm2d(F_int),
            nn.ReLU(inplace=True)
        )
        self.grouped_conv_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True, groups=num_groups),
            nn.BatchNorm2d(F_int),
            nn.ReLU(inplace=True)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.grouped_conv_g(g)
        x1 = self.grouped_conv_x(x)
        psi = self.psi(self.relu(x1 + g1))
        out = x * psi
        out += x
        return out
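# Note: because the 1x1 convolutions use groups=num_groups, F_g, F_l and F_int
# must all be divisible by num_groups (default 32), e.g. 512 channels -> 16 per group.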
class EfficientChannelAttention(nn.Module):
    def __init__(self, channels, gamma=2, b=1):
        super(EfficientChannelAttention, self).__init__()
        # Adaptive kernel size for the subsequent 1D convolution
        kernel_size = int(abs((math.log(channels, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
        # Global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 1D convolution that models cross-channel interaction
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        # Activation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # First, global average pooling over the spatial dimensions: [b, c, h, w] -> [b, c, 1, 1]
        v = self.avg_pool(x)
        # Then, learn cross-channel interaction with the adaptive-size 1D convolution
        v = self.conv(v.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Finally, apply the sigmoid to obtain per-channel attention weights
        v = self.sigmoid(v)
        return v
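# Worked example of the adaptive kernel size above (gamma=2, b=1):
#   channels=1024: int(|log2(1024) + 1| / 2) = int(5.5) = 5, odd  -> kernel_size = 5
#   channels=256:  int(|log2(256)  + 1| / 2) = int(4.5) = 4, even -> kernel_size = 5
#   channels=64:   int(|log2(64)   + 1| / 2) = int(3.5) = 3, odd  -> kernel_size = 3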
class EFF(nn.Module):
    def __init__(self, in_dim, is_bottom=False):
        super().__init__()
        self.is_bottom = is_bottom
        if not is_bottom:
            self.EAG = Efficient_Attention_Gate(in_dim, in_dim, in_dim)
        else:
            self.EAG = nn.Identity()
        self.ECA = EfficientChannelAttention(in_dim * 2)
        self.SA = SpatialAttention()

    def forward(self, x, skip):
        if not self.is_bottom:
            EAG_skip = self.EAG(x, skip)
            x = torch.cat((EAG_skip, x), dim=1)
            # x = EAG_skip + x
        else:
            x = self.EAG(x)
        x = self.ECA(x) * x
        x = self.SA(x) * x
        return x
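# Fusion path for is_bottom=False (as in the demo below): EAG gates the skip
# features with x, the gated skip is concatenated with x (doubling the channel
# count, hence ECA is built with in_dim * 2), then ECA reweights channels and
# SpatialAttention reweights spatial positions.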
if __name__ == '__main__':
    block = EFF(in_dim=512, is_bottom=False)
    x1 = torch.randn(1, 512, 71, 71)
    x2 = torch.randn(1, 512, 71, 71)
    # Pass the tensors through the EFF module
    output = block(x1, x2)
    print(x1.size())
    print(x2.size())
    print(output.size())
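    # Optional quick shape checks of the individual attention blocks
    # (example shapes only; chosen to match the tensors above).
    sa = SpatialAttention(kernel_size=7)
    print(sa(x1).size())        # torch.Size([1, 1, 71, 71])   spatial attention map
    eca = EfficientChannelAttention(channels=512)
    print(eca(x1).size())       # torch.Size([1, 512, 1, 1])   per-channel weights
    eag = Efficient_Attention_Gate(F_g=512, F_l=512, F_int=512)
    print(eag(x1, x2).size())   # torch.Size([1, 512, 71, 71]) gated skip features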