-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathadalib.py
131 lines (114 loc) · 4.5 KB
/
adalib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# 2022.09.29-Implementation for building AdaBin model
# Huawei Technologies Co., Ltd. <[email protected]>
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
class BinaryQuantize(Function):
    '''
    Binary quantize function: sign() in the forward pass, a tanh-shaped
    surrogate gradient in the backward pass.
    (https://github.com/htqin/IR-Net/blob/master/CIFAR-10/ResNet20/1w1a/modules/binaryfunction.py)

    Invoke through ``BinaryQuantize.apply(input, k, t)``.
    '''

    @staticmethod
    def forward(ctx, input, k, t):
        '''
        Return ``torch.sign(input)``.

        ``input``, ``k`` and ``t`` are saved on ``ctx`` because the backward
        pass needs all three to build the surrogate gradient.
        '''
        ctx.save_for_backward(input, k, t)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        Return ``k * t * (1 - tanh(input * t) ** 2) * grad_output`` as the
        gradient for ``input``; ``k`` and ``t`` receive no gradient (None).
        '''
        input, k, t = ctx.saved_tensors
        # Follow the device of the saved input instead of hard-coding CUDA:
        # ``.cuda()`` raises on CPU-only machines and always targets the
        # current CUDA device, which may not be where ``input`` lives.
        k = k.to(input.device)
        t = t.to(input.device)
        grad_input = k * t * (1 - torch.pow(torch.tanh(input * t), 2)) * grad_output
        return grad_input, None, None
class Maxout(nn.Module):
    '''
    Nonlinear function with separate learnable per-channel slopes for the
    positive and the negative part of the input.
    '''

    def __init__(self, channel, neg_init=0.25, pos_init=1.0):
        super().__init__()
        ones = torch.ones(channel)
        # neg_scale is registered before pos_scale, so parameter order is fixed.
        self.neg_scale = nn.Parameter(neg_init * ones)
        self.pos_scale = nn.Parameter(pos_init * torch.ones(channel))
        self.relu = nn.ReLU()

    def forward(self, x):
        '''
        Return ``pos_scale * relu(x) - neg_scale * relu(-x)``, with both
        scales reshaped to (1, C, 1, 1) so they broadcast over dim 1 of x.
        '''
        pos_slope = self.pos_scale.view(1, -1, 1, 1)
        neg_slope = self.neg_scale.view(1, -1, 1, 1)
        positive_part = pos_slope * self.relu(x)
        negative_part = neg_slope * self.relu(-x)
        return positive_part - negative_part
class BinaryActivation(nn.Module):
    '''
    Activation binarizer with a learnable center (shift) and a learnable
    output scale / offset (alpha_a, beta_a).
    '''

    def __init__(self):
        super().__init__()
        self.shift = nn.Parameter(torch.tensor(0.0))
        self.alpha_a = nn.Parameter(torch.tensor(1.0))
        self.beta_a = nn.Parameter(torch.tensor(0.0))
        # k and t are plain CPU tensors (not parameters or buffers); no
        # method of this class reads them.
        self.k = torch.tensor([1]).float().cpu()
        self.t = torch.tensor([1]).float().cpu()

    def gradient_approx(self, x):
        '''
        Sign with a piecewise-polynomial surrogate gradient, from Bi-Real Net
        (https://github.com/liuzechun/Bi-Real-net/blob/master/pytorch_implementation/BiReal18_34/birealnet.py)

        The returned value equals ``torch.sign(x)``; the gradient is that of
        the surrogate: -1 for x < -1, x*x + 2x on [-1, 0), -x*x + 2x on
        [0, 1), and 1 for x >= 1.
        '''
        hard_sign = torch.sign(x)

        # Region indicators as float32 so they can gate terms arithmetically.
        below_neg_one = (x < -1).type(torch.float32)
        below_zero = (x < 0).type(torch.float32)
        below_one = (x < 1).type(torch.float32)

        surrogate = (-1) * below_neg_one + (x * x + 2 * x) * (1 - below_neg_one)
        surrogate = surrogate * below_zero + (-x * x + 2 * x) * (1 - below_zero)
        surrogate = surrogate * below_one + 1 * (1 - below_one)

        # Value comes from hard_sign; only the trailing surrogate term is
        # attached to the graph, so it alone supplies the gradient.
        return hard_sign.detach() - surrogate.detach() + surrogate

    def forward(self, x):
        '''
        Subtract ``shift``, binarize, then return
        ``alpha_a * (binarized + beta_a)``.
        '''
        centered = x - self.shift
        binarized = self.gradient_approx(centered)
        return self.alpha_a * (binarized + self.beta_a)
class LambdaLayer(nn.Module):
    '''
    Wraps an arbitrary callable as a module (for DownSample).
    '''

    def __init__(self, lambd):
        super().__init__()
        # Stored as-is and invoked on every forward call.
        self.lambd = lambd

    def forward(self, x):
        '''Apply the wrapped callable to x and return its result.'''
        result = self.lambd(x)
        return result
class AdaBin_Conv2d(nn.Conv2d):
    '''
    AdaBin Convolution

    A Conv2d whose input is optionally binarized by a BinaryActivation
    (when a_bit == 1) and whose weights are optionally binarized around a
    per-filter center beta_w with a per-filter scale alpha_w (when
    w_bit == 1). With other a_bit / w_bit values the corresponding tensor is
    used unchanged.
    '''

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, a_bit=1, w_bit=1):
        super(AdaBin_Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.a_bit = a_bit
        self.w_bit = w_bit
        # k and t are passed to BinaryQuantize, whose backward uses them to
        # shape the surrogate gradient. They are plain CPU tensors, not
        # buffers, so they do not follow the module on .to()/.cuda().
        self.k = torch.tensor([1]).float().cpu()
        self.t = torch.tensor([1]).float().cpu()
        self.binary_a = BinaryActivation()
        # Divisor used when computing alpha_w in forward().
        # NOTE(review): this uses in_channels, while each filter in
        # self.weight spans in_channels // groups channels; for groups > 1
        # the divisor exceeds the number of summed elements -- confirm
        # whether that is intended before changing it.
        self.filter_size = self.kernel_size[0] * self.kernel_size[1] * self.in_channels

    def forward(self, inputs):
        '''
        Convolve inputs with the (optionally binarized) weights.

        When w_bit == 1 the effective weight is
        ``sign(w - beta_w) * alpha_w + beta_w`` where beta_w is the mean of
        each filter (over dims 1, 2, 3 of the weight) and alpha_w is
        ``sqrt(sum((w - beta_w) ** 2) / filter_size)`` per filter.
        '''
        if self.a_bit == 1:
            inputs = self.binary_a(inputs)

        if self.w_bit == 1:
            w = self.weight
            beta_w = w.mean((1, 2, 3)).view(-1, 1, 1, 1)
            alpha_w = torch.sqrt(((w - beta_w) ** 2).sum((1, 2, 3)) / self.filter_size).view(-1, 1, 1, 1)
            w = w - beta_w
            # Call apply on the class: instantiating an autograd Function is
            # deprecated and slated to become an error in PyTorch.
            wb = BinaryQuantize.apply(w, self.k, self.t)
            weight = wb * alpha_w + beta_w
        else:
            weight = self.weight

        output = F.conv2d(inputs, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return output