min-fid-jax

Minimal parallel FID (Fréchet Inception Distance) computation with JAX (Flax). See main.py for an example of integrating it into a training loop with the diffusers package.
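FID compares two Gaussians fitted to Inception activations: FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2}). The sketch below illustrates only that final step in plain `jax.numpy`. It is not the code in this repository. The Inception feature extractor and the multi-device sharding are not shown, and the function names (`activation_statistics`, `partial_sums`, `stats_from_sums`, `frechet_distance`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def activation_statistics(acts: jnp.ndarray):
    """Mean and unbiased covariance of an [N, D] array of activations."""
    return jnp.mean(acts, axis=0), jnp.cov(acts, rowvar=False)


def partial_sums(acts: jnp.ndarray):
    """Sufficient statistics for one batch (or one device's shard).

    Sums from several batches can be added together before calling
    stats_from_sums, so the full activation set never has to sit in memory.
    """
    return acts.shape[0], jnp.sum(acts, axis=0), acts.T @ acts


def stats_from_sums(n, s, ss):
    """Turn (count, sum, sum of outer products) into mean and unbiased covariance."""
    mu = s / n
    # sum_i (x_i - mu)(x_i - mu)^T = sum_i x_i x_i^T - n * mu mu^T
    sigma = (ss - n * jnp.outer(mu, mu)) / (n - 1)
    return mu, sigma


def _sqrtm_psd(mat: jnp.ndarray) -> jnp.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    eigvals = jnp.maximum(eigvals, 0.0)  # clamp tiny negatives caused by round-off
    return (eigvecs * jnp.sqrt(eigvals)) @ eigvecs.T


@jax.jit
def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^{1/2})."""
    diff = mu1 - mu2
    sqrt_s1 = _sqrtm_psd(sigma1)
    # S1^{1/2} S2 S1^{1/2} is symmetric PSD and has the same eigenvalues as S1 S2,
    # so Tr((S1 S2)^{1/2}) is the sum of the square roots of its eigenvalues.
    middle = sqrt_s1 @ sigma2 @ sqrt_s1
    middle = 0.5 * (middle + middle.T)  # re-symmetrize against round-off
    tr_covmean = jnp.sum(jnp.sqrt(jnp.maximum(jnp.linalg.eigvalsh(middle), 0.0)))
    return diff @ diff + jnp.trace(sigma1) + jnp.trace(sigma2) - 2.0 * tr_covmean
```

The symmetric eigendecomposition avoids a general (complex-valued) matrix square root, which keeps the computation real and jit-friendly. Inside `jax.pmap`, the per-device outputs of `partial_sums` could be combined with `jax.lax.psum` over the device axis before calling `stats_from_sums`. That is one plausible way to parallelize the statistics and not necessarily how this repository does it.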