train.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
import torchvision
from PIL import Image
from bdjscc import BDJSCC_ada

writer = SummaryWriter()
torch.cuda.set_device(1)
# torch.cuda.set_per_process_memory_fraction(0.3, 2)


# Save the model, optimizer state, current epoch and loss to a checkpoint file
def save_model(model, optimizer, epoch, loss, filename):
    save_dict = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }
    torch.save(save_dict, filename)
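

# Sketch (not part of the original script): a minimal counterpart loader for the
# checkpoint written by save_model above. It assumes the same dict keys
# ('model', 'optimizer', 'epoch', 'loss') and mirrors the resume logic used in
# __main__ below, which loads only the model weights.
def load_model(model, optimizer, filename):
    checkpoint = torch.load(filename, map_location='cuda')
    model.load_state_dict(checkpoint['model'], strict=False)
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch'], checkpoint['loss']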

# Save 3x3 grids of input and reconstructed images to the images/ directory
def store_test_image(input, output, epoch, i):
    images_in = input[:9]
    images_out = output[:9]
    nrow = int(images_in.shape[0] ** 0.5)
    grid_in = torchvision.utils.make_grid(images_in, nrow=nrow, normalize=True)
    grid_out = torchvision.utils.make_grid(images_out, nrow=nrow, normalize=True)
    # images_in = input.cpu().detach().numpy()
    # images_out = output.cpu().detach().numpy()
    # writer.add_image('Test Image', grid, epoch)
    # Convert to PIL images
    ndarr_in = grid_in.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    im_in = Image.fromarray(ndarr_in)
    ndarr_out = grid_out.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
    im_out = Image.fromarray(ndarr_out)
    images_path = os.path.join(os.getcwd(), 'images')
    os.makedirs(images_path, exist_ok=True)
    im_in.save(os.path.join(images_path, f'test_image_{epoch}_{i}_in.png'))
    im_out.save(os.path.join(images_path, f'test_image_{epoch}_{i}_out.png'))
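

# Alternative sketch (not in the original): the commented-out writer.add_image call
# above suggests the grids could also be logged to TensorBoard instead of being
# written to disk. A minimal version, reusing the module-level writer, might be:
def log_test_image(input, output, step):
    grid_in = torchvision.utils.make_grid(input[:9], nrow=3, normalize=True)
    grid_out = torchvision.utils.make_grid(output[:9], nrow=3, normalize=True)
    writer.add_image('Test/in', grid_in, step)
    writer.add_image('Test/out', grid_out, step)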


if __name__ == '__main__':
    # Hyperparameters
    batch_size = 32
    epochs = 5
    learning_rate = 1e-4
    checkpoint_path = os.path.join(os.getcwd(), 'checkpoints')
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoint_tar = os.path.join(checkpoint_path, 'checkpoint_ada_thick_rprelu_omini.tar')
    checkpoint_tar_store = os.path.join(checkpoint_path, 'checkpoint_ada_thick_rprelu_omini.tar')
    # checkpoint_tar = os.path.join(checkpoint_path, 'checkpoint_ada_thick_rprelu.tar')
    # checkpoint_tar_store = os.path.join(checkpoint_path, 'checkpoint_ada_thick_rprelu_omini-.tar')
    # Load the model
    model = BDJSCC_ada().cuda()
    # Load the optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs*10, eta_min=1e-5)
    # Load the loss function
    criterion = nn.MSELoss()
    # Load the dataset
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    train_dataset = datasets.ImageFolder('/data/Users/lanli/ReActNet-master/dataset/imagenet/train', transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    # Load the test image
    # image_test = Image.open('/data/Users/lanli/ReActNet-master/dataset/imagenet/val/n01440764/ILSVRC2012_val_00000293.JPEG')
    # transforms = transforms.Compose([
    #     transforms.Resize((256, 256)),
    #     transforms.ToTensor()
    # ])
    # image_test = transforms(image_test).unsqueeze(0).cuda()
    # Train the model (autoencoder)
    if os.path.exists(checkpoint_tar):
        checkpoint = torch.load(checkpoint_tar)
        model.load_state_dict(checkpoint['model'], strict=False)
        # optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Resuming training from epoch {start_epoch} with loss {loss}")
    else:
        start_epoch = 0
        loss = 0
    # Restart the epoch counter even when resuming from a checkpoint
    start_epoch = 0
    loss_best = 1
    for epoch in range(start_epoch, epochs):
        for i, (images, _) in enumerate(train_loader):
            images = images.cuda()
            # Forward pass
            output = model(images)
            # Crop the input and output
            # images = images[:, :, 4:252, 4:252]
            # output = output[:, :, 4:252, 4:252]
            # Compute the loss
            loss = criterion(output, images)
            # Zero the gradients
            optimizer.zero_grad()
            # Backward pass
            loss.backward()
            # Update the weights
            optimizer.step()
            if i % 10 == 0 and i != 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(train_loader)}], Loss: {loss.item()}")
                mse = loss.item()
                # PSNR for images in [0, 1]: 10 * log10(1 / MSE)
                psnr = 10 * torch.log10(torch.tensor(1.0)) - 10 * torch.log10(torch.tensor(mse))
                print(f"PSNR: {psnr}")
                # Test the model
                # model.eval()
                # with torch.no_grad():
                #     output_test = model(image_test)
                #     loss_test = criterion(output_test[:, :, 4:252, 4:252], image_test[:, :, 4:252, 4:252])
                #     print(f"Test Loss: {loss_test.item()}")
                # model.train()
            if i % 500 == 0 and i != 0:
                if loss.item() < loss_best:
                    loss_best = loss.item()
                    print(f"Saving the model with loss {loss_best}")
                    save_model(model, optimizer, epoch, loss, checkpoint_tar_store)
                    store_test_image(images, output, epoch, i)
            # lr_scheduler.step()
            writer.add_scalar('Loss/train1', loss.item(), epoch * len(train_loader) + i)
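
    # Note (added, not part of the original script): SummaryWriter() with no
    # arguments writes event files under the default ./runs/ directory, so the
    # 'Loss/train1' curve logged above can be inspected with, e.g.:
    #   tensorboard --logdir runs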