Is the weight initialization method not the same between PyTorch and JAX? I use the training config below, which gives better samples:
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'            # L1 or L2
)

trainer = Trainer(
    diffusion,
    path,
    train_batch_size = 16,
    train_lr = 8e-5,
    train_num_steps = 700000,        # total training steps
    gradient_accumulate_every = 2,   # gradient accumulation steps
    ema_decay = 0.995,               # exponential moving average decay
    amp_level = 'O1',                # turn on mixed precision
)

trainer.train()
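
On the initialization question: PyTorch's nn.Conv2d / nn.Linear layers default to Kaiming-uniform weights (with uniform, fan_in-scaled biases), while Flax/JAX layers typically default to LeCun-normal kernels and zero biases, so the two frameworks do not start from identical weight distributions. If you want to test whether that explains the gap, here is a minimal diagnostic sketch; the helper name lecun_normal_ is my own, overriding the repo's init is only meant as an experiment, and it assumes the model object from the config above:

import math
from torch import nn

def lecun_normal_(module):
    # Re-initialize Conv2d/Linear layers with a LeCun-normal scheme
    # (roughly the Flax default: truncated normal with std = sqrt(1 / fan_in))
    # and zero biases, so the PyTorch model starts closer to a Flax/JAX init.
    if isinstance(module, nn.Conv2d):
        fan_in = (module.in_channels // module.groups) * module.kernel_size[0] * module.kernel_size[1]
    elif isinstance(module, nn.Linear):
        fan_in = module.in_features
    else:
        return
    std = math.sqrt(1.0 / fan_in)
    nn.init.trunc_normal_(module.weight, std = std, a = -2 * std, b = 2 * std)
    if module.bias is not None:
        nn.init.zeros_(module.bias)

model.apply(lecun_normal_)  # call this on the Unet before training starts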
Why set lr = 2e-4 for the oxford102 flowers dataset? I've tried it on denoising-diffusion-pytorch and on my implementation, denoising-diffusion-mindspore; the loss hovers around 0.4 and the sampled images are always noisy.
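
Since the loss value by itself says little about sample quality, it can help to periodically draw a few images straight from the diffusion model and inspect them, rather than relying only on the loss curve. A minimal sketch, assuming the diffusion object from the config above and an at least partially trained run:

import torch
from torchvision.utils import save_image

# Draw a small batch of samples; in denoising-diffusion-pytorch, sample()
# returns a (batch, channels, image_size, image_size) tensor scaled to [0, 1].
with torch.no_grad():
    images = diffusion.sample(batch_size = 4)

save_image(images, 'check_samples.png', nrow = 2)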