1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
def main(epochs=20, batch_size=32, lr=1e-3, data_root='./data/cifar10'):
    """Train ResNet18 on CIFAR-10, reporting loss and test accuracy per epoch.

    Each epoch makes one shuffled pass over the training split using Adam
    and cross-entropy loss, then prints the mean training loss. It next
    evaluates top-1 accuracy on the test split and prints it.

    Args:
        epochs: Number of passes over the training split (default 20).
        batch_size: Mini-batch size for both loaders (default 32).
        lr: Adam learning rate (default 1e-3).
        data_root: Directory holding CIFAR-10. ``download=True`` fetches
            the dataset there if it is missing.
    """
    # Run on a GPU when one is available; otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # One shared preprocessing pipeline for both splits. ToTensor converts
    # the images to float tensors; Resize((32, 32)) is kept from the
    # original pipeline.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    train_loader = DataLoader(
        datasets.CIFAR10(data_root, train=True, transform=transform, download=True),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        datasets.CIFAR10(data_root, train=False, transform=transform, download=True),
        batch_size=batch_size,
        shuffle=False,
    )

    model = ResNet18().to(device)
    print(model)
    # Construct the optimizer after moving the model, as the torch.optim
    # docs recommend.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        print("loss:", epoch, train_loss)
        acc = _evaluate(model, test_loader, device)
        print("acc", epoch, acc)


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over *loader*; return the mean per-sample loss.

    Returns ``nan`` if *loader* yields no batches.
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for x, label in loader:
        x, label = x.to(device), label.to(device)
        logits = model(x)
        loss = criterion(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The criterion returns a batch mean (default reduction), so scale by
        # the batch size. A short final batch then gets its proper weight in
        # the epoch mean.
        batch_n = x.size(0)
        loss_sum += loss.item() * batch_n
        seen += batch_n
    return loss_sum / seen if seen else float('nan')


def _evaluate(model, loader, device):
    """Return top-1 accuracy of *model* over *loader*, or ``nan`` if it is empty."""
    model.eval()
    total_correct = 0
    total_num = 0
    with torch.no_grad():
        for x, label in loader:
            x, label = x.to(device), label.to(device)
            pred = model(x).argmax(dim=1)
            total_correct += (pred == label).sum().item()
            total_num += x.size(0)
    return total_correct / total_num if total_num else float('nan')
|