-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathquantize_module_.py
More file actions
144 lines (109 loc) · 4.89 KB
/
quantize_module_.py
File metadata and controls
144 lines (109 loc) · 4.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# coding=utf-8
import torch
import torch.nn as nn
from quantize.quantize_method import quantize_weights_bias_gemm, quantize_activations_gemm
import torch.nn.functional as F
class QWConv2D(torch.nn.Conv2d):
    """2-D convolution whose weight and bias are quantized on every forward pass.

    The constructor takes the same positional parameters as ``nn.Conv2d``
    (the first one is named ``n_channels`` here). Only the parameters are
    quantized; the input activations are passed through unchanged.
    """

    def __init__(self, n_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            n_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, input):
        """Convolve ``input`` with the quantized weight and (optional) bias.

        The functional ``F.conv2d`` is used instead of the ``nn.Conv2d``
        module so that the freshly quantized tensors, not the stored
        full-precision parameters, take part in the convolution.
        """
        w_q = quantize_weights_bias_gemm(self.weight)
        b_q = None if self.bias is None else quantize_weights_bias_gemm(self.bias)
        return F.conv2d(
            input,
            w_q,
            b_q,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
class QWAConv2D(torch.nn.Conv2d):
    """2-D convolution that quantizes weight, bias and input activations.

    Same constructor parameters as ``QWConv2D`` / ``nn.Conv2d`` (first one
    named ``n_channels``).
    """

    def __init__(self, n_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super().__init__(
            n_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, input):
        """Quantize parameters and activations, then run ``F.conv2d``."""
        # Parameters first, then the activations (same order as before).
        w_q = quantize_weights_bias_gemm(self.weight)
        b_q = None if self.bias is None else quantize_weights_bias_gemm(self.bias)
        x_q = quantize_activations_gemm(input)
        return F.conv2d(
            x_q,
            w_q,
            b_q,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
class QWLinear(nn.Linear):
    """Fully connected layer whose weight and bias are quantized per forward pass.

    ``num_bits``, ``num_bits_weight``, ``num_bits_grad`` and ``biprecision``
    are accepted for call-site compatibility but are currently ignored: the
    quantizer is invoked without any bit-width argument.
    """

    def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=None,
                 num_bits_grad=None, biprecision=False):
        # NOTE(review): the num_bits* / biprecision arguments are not used.
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input):
        """Apply ``F.linear`` with quantized weight and (optional) bias."""
        w_q = quantize_weights_bias_gemm(self.weight)
        b_q = None if self.bias is None else quantize_weights_bias_gemm(self.bias)
        return F.linear(input, w_q, b_q)
class QWALinear(nn.Linear):
    """Fully connected layer that quantizes its input, weight and bias."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)

    def forward(self, input):
        """Quantize the activations, then the parameters, and apply ``F.linear``."""
        # Activations are quantized before the parameters (original order).
        x_q = quantize_activations_gemm(input)
        w_q = quantize_weights_bias_gemm(self.weight)
        b_q = None if self.bias is None else quantize_weights_bias_gemm(self.bias)
        return F.linear(x_q, w_q, b_q)


# Scalar layer design from the paper (multi-GPU).
class Scalar(nn.Module):
    """Scalar layer from the paper: multiply the input by one learnable scale.

    The scale is a 1-element trainable ``nn.Parameter`` (float32), so it is
    registered with the module and trained like any other parameter.

    Args:
        init_value: initial value of the scale. Defaults to 1.0, the value
            this layer has always used. An earlier note recommended 2.5 from
            training experience; pass ``init_value=2.5`` to follow it.
    """

    def __init__(self, init_value=1.0):
        # Must run before any nn.Parameter is assigned as an attribute.
        super(Scalar, self).__init__()
        # Pitfalls hit while making this work on multiple GPUs:
        # 1. plain tensor attribute, torch.tensor([0.01], requires_grad=True)
        #    -> "Expected object of type torch.FloatTensor but found type
        #       torch.cuda.FloatTensor".
        # 2. the same tensor with .cuda()
        #    -> "arguments are located on different GPUs".
        # 3. 0-dim parameter, nn.Parameter(torch.tensor(0.01))
        #    -> "slice() cannot be applied to a 0-dim tensor"; a 1-dim tensor
        #       (square brackets) works.
        # 4. register_buffer(...) -> gradient is always None; buffers hold
        #    non-trainable state (e.g. BatchNorm running statistics).
        # Working options: assign an nn.Parameter as an attribute (used here)
        # or call self.register_parameter("scalar", param).
        self.scalar = nn.Parameter(torch.tensor([init_value], dtype=torch.float))

    def forward(self, i):
        """Return ``i`` scaled (broadcast) by the learnable scalar."""
        return self.scalar * i
if __name__ == "__main__":
    # 1) Gradient of the conv weight through the quantized convolution layer.
    qconv = QWConv2D(1, 1, 3)
    qconv.zero_grad()
    x = torch.ones(1, 1, 3, 3, requires_grad=True).float()
    y = qconv(x)
    # Reduce to a scalar so backward() works for any output shape
    # (a 3x3 kernel over a 3x3 input happens to give a single element).
    y.sum().backward()
    print("QConv2D 权重梯度", qconv.weight.grad)

    # 2) Gradient w.r.t. the full-precision weight, computed directly.
    a = torch.ones(3, 3, requires_grad=True).float()
    w = nn.init.constant_(torch.empty(3, 3, requires_grad=True), 1)
    qw = quantize_weights_bias_gemm(w)
    z = (qw * a).sum()
    z.backward()
    print("求权重梯度", w.grad)

    # 3) Gradient of the quantizer itself w.r.t. its input.
    # z.backward() above already accumulated dz/da into a.grad; clear it so
    # the printout reflects only the gradient of quantize(a).sum().
    a.grad = None
    qa = quantize_weights_bias_gemm(a).sum()
    qa.backward()
    print("直接求量化权重梯度", a.grad)