-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
269 additions
and
298 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# Copyright (C) 2018 Elvis Yu-Jing Lin <[email protected]> | ||
# | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# visit https://opensource.org/licenses/MIT. | ||
|
||
"""AttGAN, generator, and discriminator.""" | ||
|
@@ -16,8 +16,8 @@ | |
MAX_DIM = 64 * 16 # 1024 | ||
|
||
class Generator(nn.Module): | ||
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn='batchnorm', enc_acti_fn='lrelu', | ||
dec_dim=64, dec_layers=5, dec_norm_fn='batchnorm', dec_acti_fn='relu', | ||
def __init__(self, enc_dim=64, enc_layers=5, enc_norm_fn='batchnorm', enc_acti_fn='lrelu', | ||
dec_dim=64, dec_layers=5, dec_norm_fn='batchnorm', dec_acti_fn='relu', | ||
n_attrs=13, shortcut_layers=1, inject_layers=0, img_size=128): | ||
super(Generator, self).__init__() | ||
self.shortcut_layers = min(shortcut_layers, dec_layers - 1) | ||
|
@@ -85,7 +85,7 @@ def forward(self, x, a=None, mode='enc-dec'): | |
|
||
class Discriminators(nn.Module): | ||
# No instancenorm in fcs in source code, which is different from paper. | ||
def __init__(self, dim=64, norm_fn='instancenorm', acti_fn='lrelu', | ||
def __init__(self, dim=64, norm_fn='instancenorm', acti_fn='lrelu', | ||
fc_dim=1024, fc_norm_fn='none', fc_acti_fn='lrelu', n_layers=5, img_size=128): | ||
super(Discriminators, self).__init__() | ||
self.f_size = img_size // 2**n_layers | ||
|
@@ -100,11 +100,11 @@ def __init__(self, dim=64, norm_fn='instancenorm', acti_fn='lrelu', | |
n_in = n_out | ||
self.conv = nn.Sequential(*layers) | ||
self.fc_adv = nn.Sequential( | ||
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn), | ||
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn), | ||
LinearBlock(fc_dim, 1, 'none', 'none') | ||
) | ||
self.fc_cls = nn.Sequential( | ||
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn), | ||
LinearBlock(1024 * self.f_size * self.f_size, fc_dim, fc_norm_fn, fc_acti_fn), | ||
LinearBlock(fc_dim, 13, 'none', 'none') | ||
) | ||
|
||
|
@@ -122,32 +122,32 @@ def forward(self, x): | |
|
||
# multilabel_soft_margin_loss = sigmoid + binary_cross_entropy | ||
|
||
l1 = 100.0 | ||
l2 = 10.0 | ||
l3 = 1.0 | ||
|
||
class AttGAN(): | ||
def __init__(self, args): | ||
self.mode = args.mode | ||
self.gpu = args.gpu | ||
self.multi_gpu = args.multi_gpu if 'multi_gpu' in args else False | ||
self.lambda_1 = args.lambda_1 | ||
self.lambda_2 = args.lambda_2 | ||
self.lambda_3 = args.lambda_3 | ||
self.lambda_gp = args.lambda_gp | ||
|
||
self.G = Generator( | ||
args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, | ||
args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, | ||
args.enc_dim, args.enc_layers, args.enc_norm, args.enc_acti, | ||
args.dec_dim, args.dec_layers, args.dec_norm, args.dec_acti, | ||
args.n_attrs, args.shortcut_layers, args.inject_layers, args.img_size | ||
) | ||
self.G.train() | ||
if self.gpu: self.G.cuda() | ||
summary(self.G, [(3, args.img_size, args.img_size), (args.n_attrs,)], batch_size=4, use_gpu=self.gpu) | ||
summary(self.G, [(3, args.img_size, args.img_size), (args.n_attrs, 1, 1)], batch_size=4, device='cuda' if args.gpu else 'cpu') | ||
|
||
self.D = Discriminators( | ||
args.dis_dim, args.dis_norm, args.dis_acti, | ||
args.dis_dim, args.dis_norm, args.dis_acti, | ||
args.dis_fc_dim, args.dis_fc_norm, args.dis_fc_acti, args.dis_layers, args.img_size | ||
) | ||
self.D.train() | ||
if self.gpu: self.D.cuda() | ||
summary(self.D, [(3, args.img_size, args.img_size)], batch_size=4, use_gpu=self.gpu) | ||
summary(self.D, [(3, args.img_size, args.img_size)], batch_size=4, device='cuda' if args.gpu else 'cpu') | ||
|
||
if self.multi_gpu: | ||
self.G = nn.DataParallel(self.G) | ||
|
@@ -179,14 +179,14 @@ def trainG(self, img_a, att_a, att_a_, att_b, att_b_): | |
gf_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake)) | ||
gc_loss = F.binary_cross_entropy_with_logits(dc_fake, att_b) | ||
gr_loss = F.l1_loss(img_recon, img_a) | ||
g_loss = gf_loss + l2 * gc_loss + l1 * gr_loss | ||
g_loss = gf_loss + self.lambda_2 * gc_loss + self.lambda_1 * gr_loss | ||
|
||
self.optim_G.zero_grad() | ||
g_loss.backward() | ||
self.optim_G.step() | ||
|
||
errG = { | ||
'g_loss': g_loss.item(), 'gf_loss': gf_loss.item(), | ||
'g_loss': g_loss.item(), 'gf_loss': gf_loss.item(), | ||
'gc_loss': gc_loss.item(), 'gr_loss': gr_loss.item() | ||
} | ||
return errG | ||
|
@@ -235,7 +235,7 @@ def interpolate(a, b=None): | |
F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)) | ||
df_gp = gradient_penalty(self.D, img_a) | ||
dc_loss = F.binary_cross_entropy_with_logits(dc_real, att_a) | ||
d_loss = df_loss + 10 * df_gp + l3 * dc_loss | ||
d_loss = df_loss + self.lambda_gp * df_gp + self.lambda_3 * dc_loss | ||
|
||
self.optim_D.zero_grad() | ||
d_loss.backward() | ||
|
@@ -257,9 +257,9 @@ def eval(self): | |
|
||
def save(self, path): | ||
states = { | ||
'G': self.G.state_dict(), | ||
'D': self.D.state_dict(), | ||
'optim_G': self.optim_G.state_dict(), | ||
'G': self.G.state_dict(), | ||
'D': self.D.state_dict(), | ||
'optim_G': self.optim_G.state_dict(), | ||
'optim_D': self.optim_D.state_dict() | ||
} | ||
torch.save(states, path) | ||
|
@@ -302,6 +302,10 @@ def saveG(self, path): | |
parser.add_argument('--dec_acti', dest='dec_acti', type=str, default='relu') | ||
parser.add_argument('--dis_acti', dest='dis_acti', type=str, default='lrelu') | ||
parser.add_argument('--dis_fc_acti', dest='dis_fc_acti', type=str, default='relu') | ||
parser.add_argument('--lambda_1', dest='lambda_1', type=float, default=100.0) | ||
parser.add_argument('--lambda_2', dest='lambda_2', type=float, default=10.0) | ||
parser.add_argument('--lambda_3', dest='lambda_3', type=float, default=1.0) | ||
parser.add_argument('--lambda_gp', dest='lambda_gp', type=float, default=10.0) | ||
parser.add_argument('--mode', dest='mode', default='wgan', choices=['wgan', 'lsgan', 'dcgan']) | ||
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='learning rate') | ||
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5) | ||
|
@@ -310,4 +314,4 @@ def saveG(self, path): | |
args = parser.parse_args() | ||
args.n_attrs = 13 | ||
args.betas = (args.beta1, args.beta2) | ||
attgan = AttGAN(args) | ||
attgan = AttGAN(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# Copyright (C) 2018 Elvis Yu-Jing Lin <[email protected]> | ||
# | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# visit https://opensource.org/licenses/MIT. | ||
|
||
"""Custom datasets for CelebA and CelebA-HQ.""" | ||
|
@@ -10,9 +10,31 @@ | |
import torch | ||
import torch.utils.data as data | ||
import torchvision.transforms as transforms | ||
from skimage import io | ||
from PIL import Image | ||
|
||
|
||
class Custom(data.Dataset):
    """Dataset of user-supplied images with CelebA-style attribute labels.

    The attribute file is parsed as follows: the first line is skipped, the
    second line holds whitespace-separated attribute names, and every later
    line holds an image filename followed by one integer per attribute.

    Args:
        data_path: Directory that contains the image files.
        attr_path: Path to the attribute list file.
        image_size: Size passed to ``transforms.Resize``.
        selected_attrs: Attribute names to load; each must appear in the
            header line of the attribute file.

    Raises:
        ValueError: If a name in ``selected_attrs`` is not in the header.
    """

    def __init__(self, data_path, attr_path, image_size, selected_attrs):
        super(Custom, self).__init__()
        self.data_path = data_path

        # Close the file deterministically instead of leaking the handle.
        with open(attr_path, 'r', encoding='utf-8') as attr_file:
            att_list = attr_file.readlines()[1].split()
        # +1 because column 0 of every data row is the image filename.
        atts = [att_list.index(att) + 1 for att in selected_attrs]

        # Builtin str/int replace the np.str/np.int aliases, which were
        # removed from NumPy. ndmin keeps a stable (N,) / (N, K) layout even
        # when the file lists one image or only one attribute is selected.
        self.images = np.loadtxt(
            attr_path, skiprows=2, usecols=[0], dtype=str, ndmin=1)
        self.labels = np.loadtxt(
            attr_path, skiprows=2, usecols=atts, dtype=int, ndmin=2)

        self.tf = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        """Return ``(image_tensor, attribute_tensor)`` for sample ``index``."""
        path = os.path.join(self.data_path, self.images[index])
        # Force three channels so grayscale/RGBA files match the 3-channel
        # Normalize above; the context manager releases the file handle.
        with Image.open(path) as image:
            img = self.tf(image.convert('RGB'))
        # (x + 1) // 2 maps labels stored as -1/1 to 0/1.
        att = torch.tensor((self.labels[index] + 1) // 2)
        return img, att

    def __len__(self):
        """Return the number of images listed in the attribute file."""
        return len(self.images)
|
||
class CelebA(data.Dataset): | ||
def __init__(self, data_path, attr_path, image_size, mode, selected_attrs): | ||
super(CelebA, self).__init__() | ||
|
@@ -33,16 +55,15 @@ def __init__(self, data_path, attr_path, image_size, mode, selected_attrs): | |
self.labels = labels[182637:] | ||
|
||
self.tf = transforms.Compose([ | ||
transforms.ToPILImage(), | ||
transforms.CenterCrop(170), | ||
transforms.Resize(image_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
transforms.CenterCrop(170), | ||
transforms.Resize(image_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
]) | ||
|
||
self.length = len(self.images) | ||
def __getitem__(self, index): | ||
img = self.tf(io.imread(os.path.join(self.data_path, self.images[index]))) | ||
img = self.tf(Image.open(os.path.join(self.data_path, self.images[index]))) | ||
att = torch.tensor((self.labels[index] + 1) // 2) | ||
return img, att | ||
def __len__(self): | ||
|
@@ -72,15 +93,14 @@ def __init__(self, data_path, attr_path, image_list_path, image_size, mode, sele | |
self.labels = labels[28500:] | ||
|
||
self.tf = transforms.Compose([ | ||
transforms.ToPILImage(), | ||
transforms.Resize(image_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
transforms.Resize(image_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
]) | ||
|
||
self.length = len(self.images) | ||
def __getitem__(self, index): | ||
img = self.tf(io.imread(os.path.join(self.data_path, self.images[index]))) | ||
img = self.tf(Image.open(os.path.join(self.data_path, self.images[index]))) | ||
att = torch.tensor((self.labels[index] + 1) // 2) | ||
return img, att | ||
def __len__(self): | ||
|
@@ -125,7 +145,7 @@ def _set(att, value, att_name): | |
import torchvision.utils as vutils | ||
|
||
attrs_default = [ | ||
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', | ||
'Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Bushy_Eyebrows', | ||
'Eyeglasses', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'No_Beard', 'Pale_Skin', 'Young' | ||
] | ||
|
||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
6 | ||
Bald Bangs Black_Hair Blond_Hair Brown_Hair Bushy_Eyebrows Eyeglasses Male Mouth_Slightly_Open Mustache No_Beard Pale_Skin Young | ||
donald_trump.jpg -1 -1 -1 1 -1 1 -1 1 1 -1 1 -1 -1 | ||
emma_watson.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1 | ||
jay_chou.jpeg -1 1 1 -1 -1 1 -1 1 1 -1 1 -1 1 | ||
ji-eun_lee.jpg -1 1 -1 -1 1 -1 -1 -1 1 -1 1 -1 1 | ||
tom_cruise.jpg -1 -1 -1 -1 1 1 -1 1 1 -1 1 -1 -1 | ||
yui_aragaki.jpg -1 1 -1 -1 1 -1 -1 -1 -1 -1 1 -1 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# Copyright (C) 2018 Elvis Yu-Jing Lin <[email protected]> | ||
# | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# This work is licensed under the MIT License. To view a copy of this license, | ||
# visit https://opensource.org/licenses/MIT. | ||
|
||
"""Helper functions for training.""" | ||
|
Oops, something went wrong.