Building the Residual Block

  • The residual unit consists of two parts: the residual mapping and the shortcut connection.
    • stride=2 reduces the amount of computation; this post only implements the core idea of ResNet and does not cover engineering-level implementation details.
    • The extra branch guarantees that $f(x)+x$ can be added element-wise, i.e. both tensors have matching shapes (see the small shape sketch after this list).
    • CIFAR-10 data: 3 channels, image size $32 \times 32$.
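
A quick way to see why the extra branch is needed (a minimal sketch of my own, not part of the original post): with stride=2 and more output channels, the main path changes both the spatial size and the channel count, so the identity x can no longer be added directly; a 1x1 convolution with the same stride brings x to the matching shape.

import torch
from torch import nn

x = torch.randn(4, 64, 32, 32)                             # [batch, channels, height, width]

main_path = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
shortcut = nn.Conv2d(64, 128, kernel_size=1, stride=2)     # the "extra" 1x1 conv

f_x = main_path(x)                                         # [4, 128, 16, 16]
x_ = shortcut(x)                                           # [4, 128, 16, 16] -- shapes now match
print((f_x + x_).shape)                                    # torch.Size([4, 128, 16, 16])
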
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torch.nn import functional as F


class ResBlock(nn.Module):
    """
    ResNet basic block: two 3x3 convolutions plus a shortcut connection.
    """
    def __init__(self, inputChannels, outputChannels, stride=2):
        super(ResBlock, self).__init__()
        # residual path: conv-bn-relu -> conv-bn
        self.conv1 = nn.Conv2d(in_channels=inputChannels, out_channels=outputChannels,
                               kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(outputChannels)
        self.conv2 = nn.Conv2d(in_channels=outputChannels, out_channels=outputChannels,
                               kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(outputChannels)

        # shortcut path: identity by default, 1x1 conv when the shapes differ
        self.extra = nn.Sequential()
        if inputChannels != outputChannels:
            # [b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra = nn.Sequential(
                nn.Conv2d(inputChannels, outputChannels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(outputChannels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # element-wise addition of the shortcut and the residual path
        out = self.extra(x) + out
        return out
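
A minimal shape check (a usage sketch of my own, not from the original post) confirms that with stride=2 the block halves the spatial size while doubling the channels:

block = ResBlock(64, 128, stride=2)
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)   # torch.Size([2, 128, 16, 16])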

ResNet18

class ResNet18(nn.Module):
    def __init__(self):
        super(ResNet18, self).__init__()
        # stem: 3x3 conv keeps the 32x32 resolution of CIFAR-10
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64)
        )

        # followed by 4 residual blocks
        self.block1 = ResBlock(64, 128, stride=2)    # [b, 64, 32, 32] => [b, 128, 16, 16]
        self.block2 = ResBlock(128, 256, stride=2)   # => [b, 256, 8, 8]
        self.block3 = ResBlock(256, 512, stride=2)   # => [b, 512, 4, 4]
        self.block4 = ResBlock(512, 512, stride=1)   # => [b, 512, 4, 4]
        self.outlayer = nn.Linear(512, 10)           # 10 CIFAR-10 classes

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        # global average pooling to [b, 512, 1, 1], then flatten for the classifier
        x = F.adaptive_avg_pool2d(x, [1, 1])
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x
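
As a quick sanity check (a sketch of my own, not part of the original code), feeding a CIFAR-10 sized batch through the network should yield one logit per class:

model = ResNet18()
x = torch.randn(2, 3, 32, 32)                        # a fake CIFAR-10 batch
logits = model(x)
print(logits.shape)                                  # torch.Size([2, 10])
print(sum(p.numel() for p in model.parameters()))    # total parameter count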

Training on CIFAR-10

def main():
    batch_size = 32

    # training set: resize to 32x32 (the native CIFAR-10 size) and convert to tensors
    cifar_train = datasets.CIFAR10('./data/cifar10',
                                   train=True,
                                   transform=transforms.Compose([
                                       transforms.Resize((32, 32)),
                                       transforms.ToTensor()
                                   ]),
                                   download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    # test set
    cifar_test = datasets.CIFAR10('./data/cifar10',
                                  train=False,
                                  transform=transforms.Compose([
                                      transforms.Resize((32, 32)),
                                      transforms.ToTensor()
                                  ]),
                                  download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

    model = ResNet18()
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criteon = nn.CrossEntropyLoss()

    for epoch in range(20):
        # train
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            logits = model(x)
            loss = criteon(logits, label)
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # loss of the last batch in this epoch
        print("loss:", epoch, loss.item())

        # test
        model.eval()
        with torch.no_grad():  # no gradients are needed during evaluation
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
            acc = total_correct / total_num
            print("acc", epoch, acc)


if __name__ == '__main__':
    main()