Skip to content

Commit

Permalink
Add hotdog classification training code
Browse files Browse the repository at this point in the history
  • Loading branch information
wtffqbpl committed Feb 7, 2025
1 parent 871200f commit d9b353c
Showing 1 changed file with 73 additions and 1 deletion.
74 changes: 73 additions & 1 deletion computer_vision/image_augmation.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def training_batch(model: nn.Module, x, y, loss_fn, optimizer, devices=None):
def train(model: nn.Module,
data_iter: data.DataLoader,
test_iter: data.DataLoader,
loss_fn, optimizer,
loss_fn,
optimizer,
num_epochs: int,
devices=None):

Expand Down Expand Up @@ -305,6 +306,77 @@ def test_cifar10_with_resnet18(self):

train(model, train_iter, test_iter, loss_fn, optimizer, num_epochs, devices)

def test_hotdog_classification(self):
dlf.DATA_HUB['hotdog'] = (dlf.DATA_URL + 'hotdog.zip', 'fba480ffa8aa7e0febbb511d181409f899b9baa5')
data_dir = dlf.download_extract('hotdog')
print(data_dir)

train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs = [train_imgs[-i - 1][0] for i in range(8)]

# Show images
Image.show_images(hotdogs + not_hotdogs, 2, 8, scale=1.4)

# Using mean and variance of the RGB channels for each channel normalization.
normalize = torchvision.transforms.Normalize(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

train_augs = torchvision.transforms.Compose([
torchvision.transforms.RandomResizedCrop(224),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
normalize])

test_augs = torchvision.transforms.Compose([
torchvision.transforms.Resize([256, 256]),
torchvision.transforms.CenterCrop(224),
torchvision.transforms.ToTensor(),
normalize])

finetune_net = torchvision.models.resnet18(pretrained=True)

print(finetune_net.fc)
# Linear(in_features=512, out_features=1000, bias=True)

# hyperparameters
batch_size = 256
param_group = True
learning_rate = 5e-5
num_epochs = 5

finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

train_iter = data.DataLoader(torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'train'), transform=train_augs),
batch_size=batch_size, shuffle=True)
test_iter = data.DataLoader(torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'test'), transform=test_augs),
batch_size=batch_size)

devices = dlf.devices()

loss_fn = nn.CrossEntropyLoss()

if param_group:
params_1x = [param for name, param in finetune_net.named_parameters()
if name not in ['fc.weight', 'fc.bias']]
trainer = torch.optim.SGD(
[
{'params': params_1x},
{'params': finetune_net.fc.parameters(), 'lr': learning_rate * 10}
],
lr=learning_rate, weight_decay=0.001)
else:
trainer = torch.optim.SGD(finetune_net.parameters(), lr=learning_rate,
weight_decay=0.001)

train(finetune_net, train_iter, test_iter, loss_fn, trainer, num_epochs, devices)
self.assertTrue(True)


if __name__ == "__main__":
unittest.main(verbosity=True)

0 comments on commit d9b353c

Please sign in to comment.