
why set lr = 2e-4 for oxford102 flowers dataset? #3

Open
lvyufeng opened this issue Nov 27, 2022 · 0 comments
lvyufeng commented Nov 27, 2022

Why set lr = 2e-4 for the Oxford 102 Flowers dataset? I've tried it with denoising-diffusion-pytorch and with my own implementation, denoising-diffusion-mindspore; the loss hovers around 0.4 and the sampled images are always noisy.

Is the weight initialization method different between PyTorch and JAX? I use the training config below, which samples better images:

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'            # L1 or L2
)

trainer = Trainer(
    diffusion,
    path,                             # path to the training image folder
    train_batch_size = 16,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp_level = 'O1',                 # turn on mixed precision
)

trainer.train()
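For context on the `ema_decay = 0.995` setting above: it controls an exponential moving average of the model weights, which is typically the copy of the model used for sampling. A minimal sketch of the update rule (illustrative only, not the repository's implementation):

```python
def ema_update(ema_value, new_value, decay=0.995):
    """One EMA step: keep `decay` of the running average and
    blend in (1 - decay) of the latest model value."""
    return decay * ema_value + (1 - decay) * new_value

# With decay = 0.995 the average reacts slowly, smoothing out
# noisy per-step weight updates over training.
avg = 0.0
for step_value in [1.0, 1.0, 1.0]:
    avg = ema_update(avg, step_value)
```

A high decay like 0.995 means each training step only nudges the averaged weights by 0.5% toward the current weights, so short-term oscillations from a large learning rate are damped in the sampling model.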