-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathengine.py
More file actions
113 lines (85 loc) · 3.52 KB
/
engine.py
File metadata and controls
113 lines (85 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch, os, time, classify, utils
import numpy as np
import torch.nn as nn
from copy import deepcopy
from torch.optim.lr_scheduler import MultiStepLR
# Checkpoint directory layout; train_reg saves periodic snapshots under model_path.
root_path = "./target_models"
model_path = os.path.join(root_path, "target_ckp")
# Fall back to CPU when CUDA is unavailable so the helpers below still run on
# CPU-only hosts (identical to the old hard-coded "cuda" on GPU machines).
device = "cuda" if torch.cuda.is_available() else "cpu"
def test(model, criterion, dataloader):
    """Return the top-1 accuracy (in percent) of ``model`` on ``dataloader``.

    Args:
        model: network to evaluate. The last element of ``model(img)`` is used
            as the class scores (train_reg unpacks ``model(img)`` as
            ``(feats, out_prob)``).
        criterion: unused; kept so existing call sites keep working.
        dataloader: iterable of ``(img, iden)`` batches, where ``iden`` holds
            the true class indices (flattened with ``view(-1)``).

    Returns:
        Accuracy as a float in [0, 100], or 0.0 if the loader yields no samples.

    Note:
        Leaves the model in eval mode; training loops must call
        ``model.train()`` again before the next epoch.
    """
    model.eval()
    correct, total = 0, 0
    # Inference only: avoid building the autograd graph (saves memory and time).
    with torch.no_grad():
        for img, iden in dataloader:
            img, iden = img.to(device), iden.to(device)
            iden = iden.view(-1)
            out_prob = model(img)[-1]
            out_iden = torch.argmax(out_prob, dim=1).view(-1)
            correct += torch.sum(iden == out_iden).item()
            total += img.size(0)
    if total == 0:
        # Empty loader: report 0% instead of raising ZeroDivisionError.
        return 0.0
    return correct * 100.0 / total
def train_reg(args, model, criterion, optimizer, trainloader, testloader, n_epochs):
    """Train ``model`` with a plain classification loss and keep the best snapshot.

    Each epoch trains on ``trainloader``, evaluates with ``test`` on
    ``testloader``, and deep-copies the model whenever test accuracy improves.
    Every 10th epoch the current weights are saved to ``model_path``.

    Args:
        args: experiment config; not read here, kept for interface compatibility.
        model: network whose forward pass returns ``(feats, out_prob)``.
        criterion: loss called as ``criterion(out_prob, iden)`` with class-index
            targets.
        optimizer: optimizer over ``model``'s parameters.
        trainloader: iterable of ``(img, iden)`` training batches.
        testloader: iterable of ``(img, iden)`` evaluation batches.
        n_epochs: number of passes over ``trainloader``.

    Returns:
        ``(best_model, best_ACC)``: a deep copy of the model from the epoch with
        the highest test accuracy, and that accuracy in percent. If no epoch
        beats 0.0 (or ``n_epochs`` is 0), a copy of the final model is returned.
    """
    best_ACC = 0.0
    best_model = None
    for epoch in range(n_epochs):
        tf = time.time()
        ACC, cnt, loss_tot = 0, 0, 0.0
        # test() leaves the model in eval mode, so re-enable training each epoch.
        model.train()
        for img, iden in trainloader:
            img, iden = img.to(device), iden.to(device)
            bs = img.size(0)
            iden = iden.view(-1)
            _, out_prob = model(img)
            loss = criterion(out_prob, iden)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            out_iden = torch.argmax(out_prob, dim=1).view(-1)
            ACC += torch.sum(iden == out_iden).item()
            loss_tot += loss.item() * bs
            cnt += bs
        if cnt:
            train_loss, train_acc = loss_tot / cnt, ACC * 100.0 / cnt
        else:
            # Empty training loader: avoid ZeroDivisionError in the report.
            train_loss, train_acc = 0.0, 0.0
        test_acc = test(model, criterion, testloader)
        interval = time.time() - tf
        if test_acc > best_ACC:
            best_ACC = test_acc
            best_model = deepcopy(model)
        if (epoch + 1) % 10 == 0:
            # The filename uses the 0-based epoch index (the 10th epoch writes
            # "allclass_epoch9.tar"). Format only the filename, not the joined
            # path, so braces in model_path cannot break the format call.
            ckpt_file = os.path.join(model_path, "allclass_epoch{}.tar".format(epoch))
            torch.save({'state_dict': model.state_dict()}, ckpt_file)
        print("Epoch:{}\tTime:{:.2f}\tTrain Loss:{:.2f}\tTrain Acc:{:.2f}\tTest Acc:{:.2f}".format(epoch, interval, train_loss, train_acc, test_acc))
    if best_model is None:
        # No improvement over 0.0 was ever recorded; return the final weights
        # instead of failing with UnboundLocalError.
        best_model = deepcopy(model)
    print("Best Acc:{:.2f}".format(best_ACC))
    return best_model, best_ACC
def train_vib(args, model, criterion, optimizer, trainloader, testloader, n_epochs, beta=1e-3):
    """Train a variational-information-bottleneck model and keep the best snapshot.

    The per-batch objective is ``criterion(out_prob, iden) + beta * info_loss``.
    ``info_loss`` is the closed-form KL divergence between N(mu, std^2) and
    N(0, 1), summed over dim 1 and averaged over the batch.

    Args:
        args: experiment config; not read here, kept for interface compatibility.
        model: network called as ``model(img, "train")`` that returns
            ``(feats, out_prob, mu, std)``.
        criterion: loss called with class-index targets, as in train_reg.
            NOTE(review): the original code passed an (undefined) one-hot
            tensor. For standard cross-entropy the loss is the same. Confirm
            that no caller relies on a criterion that requires one-hot targets.
        optimizer: optimizer over ``model``'s parameters.
        trainloader: iterable of ``(img, iden)`` training batches.
        testloader: iterable of ``(img, iden)`` evaluation batches.
        n_epochs: number of passes over ``trainloader``.
        beta: weight of the KL (information) term. The original referenced an
            undefined ``beta``. The 1e-3 default is a placeholder, so pass the
            value your experiment intends (TODO confirm).

    Returns:
        ``(best_model, best_ACC)``: a deep copy of the model from the epoch with
        the highest test accuracy, and that accuracy in percent. If no epoch
        beats 0.0 (or ``n_epochs`` is 0), a copy of the final model is returned.
    """
    best_ACC = 0.0
    best_model = None
    for epoch in range(n_epochs):
        tf = time.time()
        ACC, cnt, loss_tot = 0, 0, 0.0
        # test() leaves the model in eval mode, so re-enable training each epoch.
        model.train()
        for img, iden in trainloader:
            img, iden = img.to(device), iden.to(device)
            bs = img.size(0)
            iden = iden.view(-1)
            _, out_prob, mu, std = model(img, "train")
            cross_loss = criterion(out_prob, iden)
            # Assumes mu/std are (batch, latent_dim) — TODO confirm.
            info_loss = -0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2)).sum(dim=1).mean()
            loss = cross_loss + beta * info_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            out_iden = torch.argmax(out_prob, dim=1).view(-1)
            ACC += torch.sum(iden == out_iden).item()
            loss_tot += loss.item() * bs
            cnt += bs
        if cnt:
            train_loss, train_acc = loss_tot / cnt, ACC * 100.0 / cnt
        else:
            # Empty training loader: avoid ZeroDivisionError in the report.
            train_loss, train_acc = 0.0, 0.0
        # test() returns a single accuracy value, not a (loss, acc) pair.
        test_acc = test(model, criterion, testloader)
        interval = time.time() - tf
        if test_acc > best_ACC:
            best_ACC = test_acc
            best_model = deepcopy(model)
        print("Epoch:{}\tTime:{:.2f}\tTrain Loss:{:.2f}\tTrain Acc:{:.2f}\tTest Acc:{:.2f}".format(epoch, interval, train_loss, train_acc, test_acc))
    if best_model is None:
        # No improvement over 0.0 was ever recorded; return the final weights
        # instead of failing with UnboundLocalError.
        best_model = deepcopy(model)
    print("Best Acc:{:.2f}".format(best_ACC))
    return best_model, best_ACC