diff --git a/README.md b/README.md
index eef66342..8db07a6b 100644
--- a/README.md
+++ b/README.md
@@ -98,22 +98,22 @@ To generate images, run the following command:
   Single and Multi host inference is supported with sharding annotations:

   ```bash
-  python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
+  python -m src.maxdiffusion.benchmarks.inference.sdxlbase.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run"
   ```

   Single host pmap version:

   ```bash
-  python -m src.maxdiffusion.generate_sdxl_replicated
+  python -m src.maxdiffusion.benchmarks.inference.sdxlbase.generate_sdxl_replicated
   ```

 - **Stable Diffusion 2 base**

   ```bash
-  python -m src.maxdiffusion.generate src/maxdiffusion/configs/base_2_base.yml run_name="my_run"
+  python -m src.maxdiffusion.benchmarks.inference.sd2.generate src/maxdiffusion/configs/base_2_base.yml run_name="my_run"

 - **Stable Diffusion 2.1**

   ```bash
-  python -m src.maxdiffusion.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
+  python -m src.maxdiffusion.benchmarks.inference.sd2.generate src/maxdiffusion/configs/base21.yml run_name="my_run"
   ```

 ## SDXL Lightning
@@ -121,7 +121,7 @@ To generate images, run the following command:
   Single and Multi host inference is supported with sharding annotations:

   ```bash
-  python -m src.maxdiffusion.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors"
+  python -m src.maxdiffusion.benchmarks.inference.sdxlbase.generate_sdxl src/maxdiffusion/configs/base_xl.yml run_name="my_run" lightning_repo="ByteDance/SDXL-Lightning" lightning_ckpt="sdxl_lightning_4step_unet.safetensors"
   ```

 ## ControlNet
@@ -131,13 +131,13 @@ To generate images, run the following command:
 - Stable Diffusion 1.4

   ```bash
-  python src/maxdiffusion/controlnet/generate_controlnet_replicated.py
+  python src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_replicated.py
   ```

 - Stable Diffusion XL

   ```bash
-  python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py
+  python src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_sdxl_replicated.py
   ```

diff --git a/metrics.json b/metrics.json
new file mode 100644
index 00000000..51f8fe9f
--- /dev/null
+++ b/metrics.json
@@ -0,0 +1 @@
+{"metrics": {"compile_time": 34.46660280227661, "inference_time_1": 5.136486291885376, "inference_time_2": 5.136287450790405, "inference_time_3": 5.136406660079956, "inference_time_4": 5.136098861694336, "gcs_metrics": true, "save_config_to_gcs": false, "log_period": 100, "from_pt": false, "split_head_dim": true, "norm_num_groups": 32, "train_new_unet": false, "dcn_data_parallelism": -1, "dcn_fsdp_parallelism": 1, "dcn_tensor_parallelism": 1, "ici_data_parallelism": -1, "ici_fsdp_parallelism": 1, "ici_tensor_parallelism": 1, "resolution": 1024, "center_crop": false, "random_flip": false, "tokenize_captions_num_proc": 4, "transform_images_num_proc": 4, "reuse_example_batch": false, "enable_data_shuffling": true, "cache_latents_text_encoder_outputs": true, "learning_rate": 4e-07, "scale_lr": false, "max_train_samples": -1, "max_train_steps": 200, "num_train_epochs": 1, "seed": 0, "per_device_batch_size": 2, "warmup_steps_fraction": 0.0, "learning_rate_schedule_steps": 200, "adam_b1": 0.9, "adam_b2": 0.999, "adam_eps": 1e-08, "adam_weight_decay": 0.01, "enable_profiler": true, "skip_first_n_steps_for_profiler": 1, "profiler_steps": 5, "guidance_scale": 9, "guidance_rescale": 0.0, "num_inference_steps": 20, "lightning_from_pt": true, "enable_mllog": false, "use_controlnet": false, "controlnet_from_pt": true, "controlnet_conditioning_scale": 0.5}, "dimensions": {"date": "20240531-215048", "run_name": "my_run", "metrics_file": "", "model_name": "SDXL-1.0", "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0", "revision": "refs/pr/95", "dtype": "bfloat16", "attention": "dot_product", "flash_block_sizes": "{}", "diffusion_scheduler_config": "{'_class_name': '', 'prediction_type': '', 'rescale_zero_terminal_snr': False, 'timestep_spacing': ''}", "base_output_directory": "", "mesh_axes": "['data', 'fsdp', 'tensor']", "logical_axis_rules": "(('batch', 'data'), ('activation_batch', 'data'), ('activation_length', 'fsdp'), ('out_channels', 'fsdp'), ('conv_out', 'fsdp'), ('length', 'fsdp'))", "data_sharding": "(('data', 'fsdp', 'tensor'),)", "dataset_name": "diffusers/pokemon-gpt4-captions", "dataset_save_location": "/tmp/pokemon-gpt4-captions_xl", "train_data_dir": "", "dataset_config_name": "", "cache_dir": "", "image_column": "image", "caption_column": "text", "output_dir": "sdxl-model-finetuned", "prompt": "A magical castle in the middle of a forest, artistic drawing", "negative_prompt": "purple, red", "lightning_repo": "", "lightning_ckpt": "", "controlnet_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0", "controlnet_image": "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/Google_%22G%22_logo.svg/1024px-Google_%22G%22_logo.svg.png", "tensorboard_dir": "sdxl-model-finetuned/my_run/tensorboard/", "checkpoint_dir": "sdxl-model-finetuned/my_run/checkpoints/", "metrics_dir": "sdxl-model-finetuned/my_run/metrics/"}}
\ No newline at end of file
diff --git a/src/maxdiffusion/controlnet/__init__.py b/src/maxdiffusion/benchmarks/__init__.py
similarity index 100%
rename from src/maxdiffusion/controlnet/__init__.py
rename to src/maxdiffusion/benchmarks/__init__.py
diff --git a/src/maxdiffusion/benchmarks/inference/__init__.py b/src/maxdiffusion/benchmarks/inference/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/maxdiffusion/benchmarks/inference/controlnet/__init__.py b/src/maxdiffusion/benchmarks/inference/controlnet/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/maxdiffusion/controlnet/generate_controlnet_replicated.py b/src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_replicated.py
similarity index 67%
rename from src/maxdiffusion/controlnet/generate_controlnet_replicated.py
rename to src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_replicated.py
index 6131abd4..84101618 100644
--- a/src/maxdiffusion/controlnet/generate_controlnet_replicated.py
+++ b/src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_replicated.py
@@ -14,7 +14,10 @@ limitations under the License.
 """
+import datetime
+import json
 import os
+import time
 from typing import Sequence

 from absl import app
@@ -30,6 +33,8 @@ cc.set_cache_dir(os.path.expanduser("~/jax_cache"))

+NUM_ITER = 5
+

 def run(config):
   rng = jax.random.PRNGKey(config.seed)
@@ -67,7 +72,11 @@ def run(config):
   negative_prompt_ids = shard(negative_prompt_ids)
   processed_image = shard(processed_image)

-  output = pipe(
+  metrics_dict = {}
+  for iter in range(NUM_ITER):
+    if iter == 0:
+      s = time.time()
+      output = pipe(
     prompt_ids=prompt_ids,
     image=processed_image,
     params=p_params,
@@ -76,7 +85,48 @@ def run(config):
     neg_prompt_ids=negative_prompt_ids,
     controlnet_conditioning_scale=controlnet_conditioning_scale,
     jit=True,
-  ).images
+      ).images
+
+      metrics_dict["compile_time"] = time.time() - s
+    else:
+      s = time.time()
+      output = pipe(
+        prompt_ids=prompt_ids,
+        image=processed_image,
+        params=p_params,
+        prng_seed=rng,
+        num_inference_steps=config.num_inference_steps,
+        neg_prompt_ids=negative_prompt_ids,
+        controlnet_conditioning_scale=controlnet_conditioning_scale,
+        jit=True,
+      ).images
+      inference_time = time.time() - s
+      metrics_dict[f"inference_time_{iter}"] = inference_time
+
+  dimensions_dict = {}
+  current_dt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+  dimensions_dict["date"] = current_dt
+
+
+  for dim in config.get_keys():
+    val = config.get(dim)
+    if isinstance(val, str):
+      if dim == "model_name":
+        dimensions_dict[dim] = "ControlNet" + str(config.get(dim))
+      else:
+        dimensions_dict[dim] = str(config.get(dim))
+    elif isinstance(val, int) or isinstance(val, float):  # noqa: E721
+      metrics_dict[dim] = val
+    else:
+      dimensions_dict[dim] = str(val)
+
+  final_dict = {}
+  final_dict["metrics"] = metrics_dict
+  final_dict["dimensions"] = dimensions_dict
+
+  with open("metrics.json", 'w') as f:
+    f.write(json.dumps(final_dict))
+
   output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
   output_images[0].save("generated_image.png")
diff --git a/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py b/src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_sdxl_replicated.py
similarity index 70%
rename from src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py
rename to src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_sdxl_replicated.py
index 4760f198..16038de9 100644
--- a/src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py
+++ b/src/maxdiffusion/benchmarks/inference/controlnet/generate_controlnet_sdxl_replicated.py
@@ -14,7 +14,10 @@ limitations under the License.
 """
+import datetime
+import json
 import os
+import time
 from typing import Sequence

 from absl import app
@@ -33,6 +36,8 @@
 )
 import cv2

+NUM_ITER = 5
+
 cc.set_cache_dir(os.path.expanduser("~/jax_cache"))

 def create_key(seed=0):
@@ -84,7 +89,25 @@ def run(config):
   negative_prompt_ids = shard(negative_prompt_ids)
   processed_image = shard(processed_image)

-  output = pipe(
+  metrics_dict = {}
+  for iter in range(NUM_ITER):
+    if iter == 0:
+      s = time.time()
+      output = pipe(
+        prompt_ids=prompt_ids,
+        image=processed_image,
+        params=p_params,
+        prng_seed=rng,
+        num_inference_steps=config.num_inference_steps,
+        neg_prompt_ids=negative_prompt_ids,
+        controlnet_conditioning_scale=controlnet_conditioning_scale,
+        jit=True,
+      ).images
+
+      metrics_dict["compile_time"] = time.time() - s
+    else:
+      s = time.time()
+      output = pipe(
     prompt_ids=prompt_ids,
     image=processed_image,
     params=p_params,
@@ -93,7 +116,33 @@ def run(config):
     neg_prompt_ids=negative_prompt_ids,
     controlnet_conditioning_scale=controlnet_conditioning_scale,
     jit=True,
-  ).images
+      ).images
+
+      inference_time = time.time() - s
+      metrics_dict[f"inference_time_{iter}"] = inference_time
+
+  dimensions_dict = {}
+  current_dt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+  dimensions_dict["date"] = current_dt
+
+  for dim in config.get_keys():
+    val = config.get(dim)
+    if isinstance(val, str):
+      if dim == "model_name":
+        dimensions_dict[dim] = "ControlNet" + str(config.get(dim))
+      else:
+        dimensions_dict[dim] = str(config.get(dim))
+    elif isinstance(val, int) or isinstance(val, float):
+      metrics_dict[dim] = val
+    else:
+      dimensions_dict[dim] = str(val)
+
+  final_dict = {}
+  final_dict["metrics"] = metrics_dict
+  final_dict["dimensions"] = dimensions_dict
+
+  with open("metrics.json", 'w') as f:
+    f.write(json.dumps(final_dict))

   output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
   output_images[0].save("generated_image.png")
diff --git a/src/maxdiffusion/benchmarks/inference/sd2/__init__.py b/src/maxdiffusion/benchmarks/inference/sd2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/benchmarks/inference/sd2/generate.py
similarity index 86%
rename from src/maxdiffusion/generate.py
rename to src/maxdiffusion/benchmarks/inference/sd2/generate.py
index 394625b7..0d483547 100644
--- a/src/maxdiffusion/generate.py
+++ b/src/maxdiffusion/benchmarks/inference/sd2/generate.py
@@ -14,7 +14,9 @@ limitations under the License.
 """
+import datetime
 import functools
+import json
 import os
 import time
 from typing import Sequence
@@ -45,6 +47,8 @@ cc.set_cache_dir(os.path.expanduser("~/jax_cache"))

+NUM_ITER = 5
+
 def loop_body(step, args, model, pipeline, prompt_embeds, guidance_scale, guidance_rescale):
   latents, scheduler_state, state = args
   latents_input = jnp.concatenate([latents] * 2)
@@ -193,13 +197,39 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli
     out_shardings=None
   )

-  s = time.time()
-  p_run_inference(unet_state, vae_state, params).block_until_ready()
-  print("compile time: ", (time.time() - s))
+  metrics_dict = {}
+  for iter in range(NUM_ITER):
+    if iter == 0:
+      s = time.time()
+      p_run_inference(unet_state, vae_state, params).block_until_ready()
+      metrics_dict["compile_time"] = time.time() - s
+    else:
+      s = time.time()
+      images = p_run_inference(unet_state, vae_state, params).block_until_ready()
+      images.block_until_ready()
+      inference_time = time.time() - s
+      metrics_dict[f"inference_time_{iter}"] = inference_time
+
+  dimensions_dict = {}
+  current_dt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+  dimensions_dict["date"] = current_dt
+
+  for dim in config.get_keys():
+    val = config.get(dim)
+    if isinstance(val, str):
+      dimensions_dict[str(dim)] = str(config.get(dim))
+    elif isinstance(val, int) or isinstance(val, float):
+      metrics_dict[dim] = val
+    else:
+      dimensions_dict[str(dim)] = str(val)
+
+  final_dict = {}
+  final_dict["metrics"] = metrics_dict
+  final_dict["dimensions"] = dimensions_dict
+
+  with open("metrics.json", 'w') as f:
+    f.write(json.dumps(final_dict))
-  s = time.time()
-  images = p_run_inference(unet_state, vae_state, params).block_until_ready()
-  print("inference time: ",(time.time() - s))
   numpy_images = np.array(images)
   images = VaeImageProcessor.numpy_to_pil(numpy_images)
   for i, image in enumerate(images):
diff --git a/src/maxdiffusion/benchmarks/inference/sdxlbase/__init__.py b/src/maxdiffusion/benchmarks/inference/sdxlbase/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/benchmarks/inference/sdxlbase/generate_sdxl.py
similarity index 89%
rename from src/maxdiffusion/generate_sdxl.py
rename to src/maxdiffusion/benchmarks/inference/sdxlbase/generate_sdxl.py
index 27627ff3..560a22a0 100644
--- a/src/maxdiffusion/generate_sdxl.py
+++ b/src/maxdiffusion/benchmarks/inference/sdxlbase/generate_sdxl.py
@@ -14,6 +14,8 @@ limitations under the License.
 """
+import datetime
+import json
 import os
 import functools
 from absl import app
@@ -42,8 +44,6 @@
     create_device_mesh,
     get_dtype,
     get_states,
-    activate_profiler,
-    deactivate_profiler,
     device_put_replicated,
     get_flash_block_sizes,
 )
@@ -55,6 +55,8 @@ cc.set_cache_dir(os.path.expanduser("~/jax_cache"))

+NUM_ITER = 5
+
 def loop_body(step, args, model, pipeline, added_cond_kwargs, prompt_embeds, guidance_scale, guidance_rescale):
   latents, scheduler_state, state = args
   latents_input = jnp.concatenate([latents] * 2)
@@ -254,27 +256,42 @@ def run_inference(unet_state, vae_state, params, rng, config, batch_size, pipeli
     out_shardings=None
   )

-  s = time.time()
-  p_run_inference(unet_state, vae_state, params).block_until_ready()
-  print("compile time: ", (time.time() - s))
-  s = time.time()
-  images = p_run_inference(unet_state, vae_state, params).block_until_ready()
-  images.block_until_ready()
-  print("inference time: ",(time.time() - s))
-  s = time.time()
-  images = p_run_inference(unet_state, vae_state, params).block_until_ready() #run_inference(unet_state, vae_state, latents, scheduler_state)
-  images.block_until_ready()
-  print("inference time: ",(time.time() - s))
-  s = time.time()
-  images = p_run_inference(unet_state, vae_state, params).block_until_ready() # run_inference(unet_state, vae_state, latents, scheduler_state)
-  images.block_until_ready()
-  print("inference time: ",(time.time() - s))
-  s = time.time()
-  activate_profiler(config)
-  images = p_run_inference(unet_state, vae_state, params).block_until_ready()
-  deactivate_profiler(config)
-  images.block_until_ready()
-  print("inference time: ",(time.time() - s))
+  metrics_dict = {}
+  for iter in range(NUM_ITER):
+    if iter == 0:
+      s = time.time()
+      p_run_inference(unet_state, vae_state, params).block_until_ready()
+      metrics_dict["compile_time"] = time.time() - s
+    else:
+      s = time.time()
+      images = p_run_inference(unet_state, vae_state, params).block_until_ready()
+      images.block_until_ready()
+      inference_time = time.time() - s
+      metrics_dict[f"inference_time_{iter}"] = inference_time
+
+  dimensions_dict = {}
+  current_dt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+  dimensions_dict["date"] = current_dt
+
+  for dim in config.get_keys():
+    val = config.get(dim)
+    if isinstance(val, str):
+      dimensions_dict[str(dim)] = str(config.get(dim))
+    elif isinstance(val, int) or isinstance(val, float):
+      metrics_dict[dim] = val
+    else:
+      dimensions_dict[str(dim)] = str(val)
+
+  final_dict = {}
+  final_dict["metrics"] = metrics_dict
+  final_dict["dimensions"] = dimensions_dict
+
+  print("final_dict is ", final_dict)
+
+  with open("metrics.json", 'w') as f:
+    f.write(json.dumps(final_dict))
+
+
   images = jax.experimental.multihost_utils.process_allgather(images)
   numpy_images = np.array(images)
   images = VaeImageProcessor.numpy_to_pil(numpy_images)
diff --git a/src/maxdiffusion/generate_sdxl_replicated.py b/src/maxdiffusion/benchmarks/inference/sdxlbase/generate_sdxl_replicated.py
similarity index 86%
rename from src/maxdiffusion/generate_sdxl_replicated.py
rename to src/maxdiffusion/benchmarks/inference/sdxlbase/generate_sdxl_replicated.py
index 0a4bb742..feba7d4e 100644
--- a/src/maxdiffusion/generate_sdxl_replicated.py
+++ b/src/maxdiffusion/benchmarks/inference/sdxlbase/generate_sdxl_replicated.py
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
+import datetime
+import json
 import time

 import jax
@@ -30,6 +32,8 @@
 NUM_DEVICES = jax.device_count()

+NUM_ITER = 5
+
 # 1. Let's start by downloading the model and loading it into our pipeline class
 # Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
 # will have to be passed to the pipeline during inference
@@ -138,19 +142,37 @@ def generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_
 # Fills the C++ dispatch cache.
 # When using jit, this extra step is done automatically, but when using AOT compilation,
 # it doesn't happen until the function call is made.
-start = time.time()
-prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
-neg_prompt = "cartoon, illustration, animation. face. male, female"
-images = generate(prompt, neg_prompt)
-print(f"First inference in {time.time() - start}")
-# 9. From this point forward, any calls to generate should result in a faster inference
-# time and it won't change.
-start = time.time()
 prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
 neg_prompt = "cartoon, illustration, animation. face. male, female"
-images = generate(prompt, neg_prompt)
-print(f"Inference in {time.time() - start}")
+
+metrics_dict = {}
+for iter in range(NUM_ITER):
+  if iter == 0:
+    s = time.time()
+    images = generate(prompt, neg_prompt)
+    metrics_dict["compile_time"] = time.time() - s
+  else:
+    # 9. From this point forward, any calls to generate should result in a faster inference
+    # time and it won't change.
+    s = time.time()
+    images = generate(prompt, neg_prompt)
+    inference_time = time.time() - s
+    metrics_dict[f"inference_time_{iter}"] = inference_time
+
+dimensions_dict = {}
+current_dt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+dimensions_dict["date"] = current_dt
+dimensions_dict["model_name"] = "SDXL-1.0-Replicated"
+
+
+final_dict = {}
+final_dict["metrics"] = metrics_dict
+final_dict["dimensions"] = dimensions_dict
+
+
+with open("metrics.json", 'w') as f:
+  f.write(json.dumps(final_dict))

 for i, image in enumerate(images):
   image.save(f"castle_{i}.png")
diff --git a/src/maxdiffusion/configs/base15.yml b/src/maxdiffusion/configs/base15.yml
index 3ece19c6..a97e21fe 100644
--- a/src/maxdiffusion/configs/base15.yml
+++ b/src/maxdiffusion/configs/base15.yml
@@ -22,6 +22,7 @@ gcs_metrics: True
 save_config_to_gcs: False
 log_period: 100
+model_name: 'SDXL-1-5'
 pretrained_model_name_or_path: 'runwayml/stable-diffusion-v1-5'
 revision: 'flax'
 dtype: 'bfloat16'
@@ -150,6 +151,7 @@ num_inference_steps: 30
 enable_mllog: False

 # controlnet
+use_controlnet: False
 controlnet_model_name_or_path: 'lllyasviel/sd-controlnet-canny'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 1.0
diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml
index 5eba6b80..42e9862a 100644
--- a/src/maxdiffusion/configs/base21.yml
+++ b/src/maxdiffusion/configs/base21.yml
@@ -22,6 +22,7 @@ gcs_metrics: True
 save_config_to_gcs: False
 log_period: 100
+model_name: 'SDXL-2-1'
 pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-1'
 revision: 'bf16'
 dtype: 'bfloat16'
diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml
index dcdb7288..f042b69a 100644
--- a/src/maxdiffusion/configs/base_xl.yml
+++ b/src/maxdiffusion/configs/base_xl.yml
@@ -22,6 +22,7 @@ gcs_metrics: True
 save_config_to_gcs: False
 log_period: 100
+model_name: 'SDXL-1.0'
 pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
 revision: 'refs/pr/95'
 dtype: 'bfloat16'
@@ -157,6 +158,7 @@ lightning_ckpt: ""
 enable_mllog: False

 #controlnet
+use_controlnet: False
 controlnet_model_name_or_path: 'diffusers/controlnet-canny-sdxl-1.0'
 controlnet_from_pt: True
 controlnet_conditioning_scale: 0.5
diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml
index d8b69ca4..f0bbd46b 100644
--- a/src/maxdiffusion/configs/base_xl_lightning.yml
+++ b/src/maxdiffusion/configs/base_xl_lightning.yml
@@ -22,6 +22,7 @@ gcs_metrics: True
 save_config_to_gcs: False
 log_period: 100
+model_name: 'SDXL-Lightning'
 pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
 revision: 'refs/pr/95'
 dtype: 'bfloat16'
diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index eed27f76..02c7d297 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -136,6 +136,11 @@ def __setattr__(self, attr, value):
   def get_keys(self):
     return _config.keys

+  def get(self, attr):
+    if attr not in _config.keys:
+      raise ValueError(f"Requested key {attr}, not in config")
+    return _config.keys[attr]
+
 def initialize(argv, **kwargs):
   global _config, config
   _config = _HyperParameters(argv, **kwargs)
diff --git a/src/maxdiffusion/tests/controlnet_tests.py b/src/maxdiffusion/tests/controlnet_tests.py
index 122bb12b..7499e97e 100644
--- a/src/maxdiffusion/tests/controlnet_tests.py
+++ b/src/maxdiffusion/tests/controlnet_tests.py
@@ -5,8 +5,8 @@ from ..import pyconfig
 from absl.testing import absltest

-from maxdiffusion.controlnet.generate_controlnet_replicated import run as generate_run
-from maxdiffusion.controlnet.generate_controlnet_sdxl_replicated import run as generate_run_sdxl
+from maxdiffusion.benchmarks.inference.controlnet.generate_controlnet_replicated import run as generate_run
+from maxdiffusion.benchmarks.inference.controlnet.generate_controlnet_sdxl_replicated import run as generate_run_sdxl

 from PIL import Image
 from skimage.metrics import structural_similarity as ssim
diff --git a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py
index 9f6788b4..c66f367a 100644
--- a/src/maxdiffusion/tests/generate_sdxl_smoke_test.py
+++ b/src/maxdiffusion/tests/generate_sdxl_smoke_test.py
@@ -5,7 +5,7 @@ from ..import pyconfig
 from absl.testing import absltest

-from maxdiffusion.generate_sdxl import run as generate_run_xl
+from maxdiffusion.benchmarks.inference.sdxlbase.generate_sdxl import run as generate_run_xl

 from PIL import Image
 from skimage.metrics import structural_similarity as ssim
diff --git a/src/maxdiffusion/tests/generate_smoke_test.py b/src/maxdiffusion/tests/generate_smoke_test.py
index 646cf275..02c88e7f 100644
--- a/src/maxdiffusion/tests/generate_smoke_test.py
+++ b/src/maxdiffusion/tests/generate_smoke_test.py
@@ -5,7 +5,7 @@ from ..import pyconfig
 from absl.testing import absltest

-from maxdiffusion.generate import run as generate_run
+from maxdiffusion.benchmarks.inference.sd2.generate import run as generate_run

 from PIL import Image
 from skimage.metrics import structural_similarity as ssim
diff --git a/~/jax_cache/pmap__generate-549e878e7619a99860fec014d5e293107470fd768b3583a67718f7764640b2b4 b/~/jax_cache/pmap__generate-549e878e7619a99860fec014d5e293107470fd768b3583a67718f7764640b2b4
new file mode 100644
index 00000000..b076e7b4
Binary files /dev/null and b/~/jax_cache/pmap__generate-549e878e7619a99860fec014d5e293107470fd768b3583a67718f7764640b2b4 differ
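
Note on the output format (not part of the patch): every benchmark script in this change writes a metrics.json of the shape {"metrics": {...}, "dimensions": {...}}, where "metrics" holds compile_time, inference_time_1..N, and any numeric config values, and "dimensions" holds the string-valued config entries plus the run date. The snippet below is a minimal sketch of how a downstream job could summarize such a file; the file name and key layout come from the scripts above, while the summarization itself (key filtering, statistics.mean) is an illustrative addition rather than something this change provides.

```python
import json
from statistics import mean

# Load the report written by one of the benchmark scripts in this change.
with open("metrics.json") as f:
  report = json.load(f)

metrics = report["metrics"]
dimensions = report["dimensions"]

# inference_time_1 .. inference_time_N hold the timed iterations;
# compile_time covers the first (compile + run) iteration.
inference_times = [v for k, v in metrics.items() if k.startswith("inference_time_")]

print("model:", dimensions.get("model_name", "<unknown>"))
print("compile time (s):", metrics["compile_time"])
if inference_times:
  print("mean inference time (s):", mean(inference_times))
```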