From ca66e19195cec77a9e53c194f7de06ccbe5eafe7 Mon Sep 17 00:00:00 2001
From: apolinario
Date: Thu, 31 Mar 2022 22:36:03 +0200
Subject: [PATCH] Add AudioCLIP to Disco Diffusion v5.1

---
 Disco_Diffusion.ipynb | 94 +++++++++++++++++++++++++++-------------
 disco.py              | 99 ++++++++++++++++++++++++++++++-------------
 2 files changed, 134 insertions(+), 59 deletions(-)

diff --git a/Disco_Diffusion.ipynb b/Disco_Diffusion.ipynb
index babfd640..33b65a2d 100644
--- a/Disco_Diffusion.ipynb
+++ b/Disco_Diffusion.ipynb
@@ -60,7 +60,9 @@
 "\n",
 "3D animation implementation added by Adam Letts (https://twitter.com/gandamu_ml) in collaboration with Somnai.\n",
 "\n",
- "Turbo feature by Chris Allen (https://twitter.com/zippy731)"
+ "Turbo feature by Chris Allen (https://twitter.com/zippy731)\n",
+ "\n",
+ "AudioCLIP integration by Apolinário (https://twitter.com/multimodalart)"
 ]
 },
 {
@@ -463,12 +465,13 @@
 "\n",
 "if is_colab:\n",
 " gitclone(\"https://github.com/openai/CLIP\")\n",
+ " gitclone(\"https://github.com/russelldc/AudioCLIP.git\")\n",
 " #gitclone(\"https://github.com/facebookresearch/SLIP.git\")\n",
 " gitclone(\"https://github.com/crowsonkb/guided-diffusion\")\n",
 " gitclone(\"https://github.com/assafshocher/ResizeRight.git\")\n",
 " pipie(\"./CLIP\")\n",
 " pipie(\"./guided-diffusion\")\n",
- " multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
+ " multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'pytorch-ignite', 'visdom'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
 " print(multipip_res)\n",
 " subprocess.run(['apt', 'install', 'imagemagick'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
 " gitclone(\"https://github.com/isl-org/MiDaS.git\")\n",
@@ -578,6 +581,9 @@
 "import warnings\n",
 "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 "\n",
+ "sys.path.append('./AudioCLIP')\n",
+ "from audioclip import AudioCLIP\n",
+ "\n",
 "# AdaBins stuff\n",
 "if USE_ADABINS:\n",
 " if is_colab:\n",
@@ -1189,6 +1195,14 @@
 " else:\n",
 " image_prompt = []\n",
 "\n",
+ " print(args.audio_prompts_series)\n",
+ " if args.audio_prompts_series is not None and frame_num >= len(args.audio_prompts_series):\n",
+ " audio_prompt = args.audio_prompts_series[-1]\n",
+ " elif args.audio_prompts_series is not None:\n",
+ " audio_prompt = args.audio_prompts_series[frame_num]\n",
+ " else:\n",
+ " audio_prompt = []\n",
+ " \n",
 " print(f'Frame {frame_num} Prompt: {frame_prompt}')\n",
 "\n",
 " model_stats = []\n",
@@ -1197,35 +1211,49 @@
 " model_stat = {\"clip_model\":None,\"target_embeds\":[],\"make_cutouts\":None,\"weights\":[]}\n",
 " model_stat[\"clip_model\"] = clip_model\n",
 " \n",
- " \n",
- " for prompt in frame_prompt:\n",
- " txt, weight = parse_prompt(prompt)\n",
- " txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()\n",
- " \n",
- " if args.fuzzy_prompt:\n",
- " for i in range(25):\n",
- " model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))\n",
- " model_stat[\"weights\"].append(weight)\n",
- " else:\n",
- " model_stat[\"target_embeds\"].append(txt)\n",
- " model_stat[\"weights\"].append(weight)\n",
- " \n",
- " if image_prompt:\n",
- " model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs) \n",
- " for prompt in image_prompt:\n",
+ " isAudio = isinstance(clip_model,AudioCLIP)\n",
+ " #If it is AudioCLIP, process the Audio prompts. Otherwise process either image or text prompts\n",
+ " if isAudio:\n",
+ " if audio_prompt:\n",
+ " for prompt in audio_prompt:\n",
+ " torch.set_grad_enabled(False)\n",
 " path, weight = parse_prompt(prompt)\n",
- " img = Image.open(fetch(path)).convert('RGB')\n",
- " img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)\n",
- " batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n",
- " embed = clip_model.encode_image(normalize(batch)).float()\n",
- " if fuzzy_prompt:\n",
- " for i in range(25):\n",
- " model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n",
- " weights.extend([weight / cutn] * cutn)\n",
- " else:\n",
- " model_stat[\"target_embeds\"].append(embed)\n",
- " model_stat[\"weights\"].extend([weight / cutn] * cutn)\n",
- " \n",
+ " clip_model.eval()\n",
+ " audio_enc = clip_model.create_audio_encoding(path)\n",
+ " audio_enc = audio_enc / audio_enc.norm(dim=-1, keepdim=True)\n",
+ " embed = audio_enc.float()\n",
+ " model_stat[\"target_embeds\"].append(embed)\n",
+ " model_stat[\"weights\"].append(weight)\n",
+ " torch.set_grad_enabled(True) \n",
+ " else:\n",
+ " for prompt in frame_prompt:\n",
+ " txt, weight = parse_prompt(prompt)\n",
+ " txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()\n",
+ " \n",
+ " if args.fuzzy_prompt:\n",
+ " for i in range(25):\n",
+ " model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))\n",
+ " model_stat[\"weights\"].append(weight)\n",
+ " else:\n",
+ " model_stat[\"target_embeds\"].append(txt)\n",
+ " model_stat[\"weights\"].append(weight)\n",
+ " \n",
+ " if image_prompt:\n",
+ " model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs) \n",
+ " for prompt in image_prompt:\n",
+ " path, weight = parse_prompt(prompt)\n",
+ " img = Image.open(fetch(path)).convert('RGB')\n",
+ " img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)\n",
+ " batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n",
+ " embed = clip_model.encode_image(normalize(batch)).float()\n",
+ " if fuzzy_prompt:\n",
+ " for i in range(25):\n",
+ " model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n",
+ " weights.extend([weight / cutn] * cutn)\n",
+ " else:\n",
+ " model_stat[\"target_embeds\"].append(embed)\n",
+ " model_stat[\"weights\"].extend([weight / cutn] * cutn)\n",
+ "\n",
 " model_stat[\"target_embeds\"] = torch.cat(model_stat[\"target_embeds\"])\n",
 " model_stat[\"weights\"] = torch.tensor(model_stat[\"weights\"], device=device)\n",
 " if model_stat[\"weights\"].sum().abs() < 1e-3:\n",
@@ -2949,6 +2977,11 @@
 "\n",
 "image_prompts = {\n",
 " # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n",
+ "}\n",
+ "\n",
+ "#Audio prompts only work if the AudioCLIP model is activated\n",
+ "audio_prompts = { \n",
+ " #0: ['AudioCLIP/assets/bird_sounds.wav']\n",
 "}\n"
 ],
 "outputs": [],
@@ -3049,6 +3082,7 @@
 " 'batchNum': batchNum,\n",
 " 'prompts_series':split_prompts(text_prompts) if text_prompts else None,\n",
 " 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,\n",
+ " 'audio_prompts_series':split_prompts(audio_prompts) if audio_prompts else None,\n",
 " 'seed': seed,\n",
 " 'display_rate':display_rate,\n",
 " 'n_batches':n_batches if animation_mode == 'None' else 1,\n",
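Note on the encoding branch added above (an illustrative sketch, not an excerpt from the patch): when the active clip_model is an AudioCLIP instance, each per-frame audio prompt is encoded once with gradients disabled, and the resulting embedding is used as a guidance target exactly like a text or image embedding. The checkpoint and wav paths below are placeholders; create_audio_encoding is the helper exposed by the russelldc/AudioCLIP fork that the setup cell clones.

import sys
import torch

sys.path.append('./AudioCLIP')          # cloned by the setup cell above
from audioclip import AudioCLIP

def encode_audio_prompt(clip_model, path):
    # Encode one audio file into a unit-norm embedding with autograd off,
    # mirroring the isAudio branch added to do_run().
    torch.set_grad_enabled(False)
    clip_model.eval()
    audio_enc = clip_model.create_audio_encoding(path)            # fork-specific helper used by the patch
    audio_enc = audio_enc / audio_enc.norm(dim=-1, keepdim=True)  # L2-normalise like the CLIP embeddings
    torch.set_grad_enabled(True)
    return audio_enc.float()

# Example (placeholder paths):
# aclp = AudioCLIP(pretrained='AudioCLIP-Full-Training.pt').cuda()
# embed = encode_audio_prompt(aclp, 'AudioCLIP/assets/bird_sounds.wav')
# The embedding is then appended to model_stat["target_embeds"] with its parsed weight.

The same change is mirrored in disco.py below.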
diff --git a/disco.py b/disco.py
index 44b40946..f9aecbbd 100644
--- a/disco.py
+++ b/disco.py
@@ -388,12 +388,13 @@ def createPath(filepath):
 
 if is_colab:
   gitclone("https://github.com/openai/CLIP")
+  gitclone("https://github.com/russelldc/AudioCLIP.git")
   #gitclone("https://github.com/facebookresearch/SLIP.git")
   gitclone("https://github.com/crowsonkb/guided-diffusion")
   gitclone("https://github.com/assafshocher/ResizeRight.git")
   pipie("./CLIP")
   pipie("./guided-diffusion")
-  multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy'], stdout=subprocess.PIPE).stdout.decode('utf-8')
+  multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy', 'pytorch-ignite', 'visdom'], stdout=subprocess.PIPE).stdout.decode('utf-8')
   print(multipip_res)
   subprocess.run(['apt', 'install', 'imagemagick'], stdout=subprocess.PIPE).stdout.decode('utf-8')
   gitclone("https://github.com/isl-org/MiDaS.git")
@@ -503,6 +504,9 @@ def createPath(filepath):
 import warnings
 warnings.filterwarnings("ignore", category=UserWarning)
 
+sys.path.append('./AudioCLIP')
+from audioclip import AudioCLIP
+
 # AdaBins stuff
 if USE_ADABINS:
   if is_colab:
@@ -1096,6 +1100,14 @@ def do_run():
   else:
     image_prompt = []
 
+  print(args.audio_prompts_series)
+  if args.audio_prompts_series is not None and frame_num >= len(args.audio_prompts_series):
+    audio_prompt = args.audio_prompts_series[-1]
+  elif args.audio_prompts_series is not None:
+    audio_prompt = args.audio_prompts_series[frame_num]
+  else:
+    audio_prompt = []
+
   print(f'Frame {frame_num} Prompt: {frame_prompt}')
 
   model_stats = []
@@ -1103,36 +1115,50 @@ def do_run():
   cutn = 16
   model_stat = {"clip_model":None,"target_embeds":[],"make_cutouts":None,"weights":[]}
   model_stat["clip_model"] = clip_model
+
+  isAudio = isinstance(clip_model,AudioCLIP)
+  #If it is AudioCLIP, process the Audio prompts. Otherwise process either image or text prompts
+  if isAudio:
+    if audio_prompt:
+      for prompt in audio_prompt:
+        torch.set_grad_enabled(False)
+        path, weight = parse_prompt(prompt)
+        clip_model.eval()
+        audio_enc = clip_model.create_audio_encoding(path)
+        audio_enc = audio_enc / audio_enc.norm(dim=-1, keepdim=True)
+        embed = audio_enc.float()
+        model_stat["target_embeds"].append(embed)
+        model_stat["weights"].append(weight)
+        torch.set_grad_enabled(True)
+  else:
+    for prompt in frame_prompt:
+      txt, weight = parse_prompt(prompt)
+      txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()
+
+      if args.fuzzy_prompt:
+        for i in range(25):
+          model_stat["target_embeds"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))
+          model_stat["weights"].append(weight)
+      else:
+        model_stat["target_embeds"].append(txt)
+        model_stat["weights"].append(weight)
+
+    if image_prompt:
+      model_stat["make_cutouts"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs)
+      for prompt in image_prompt:
+        path, weight = parse_prompt(prompt)
+        img = Image.open(fetch(path)).convert('RGB')
+        img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)
+        batch = model_stat["make_cutouts"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))
+        embed = clip_model.encode_image(normalize(batch)).float()
+        if fuzzy_prompt:
+          for i in range(25):
+            model_stat["target_embeds"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))
+            weights.extend([weight / cutn] * cutn)
+        else:
+          model_stat["target_embeds"].append(embed)
+          model_stat["weights"].extend([weight / cutn] * cutn)
-  for prompt in frame_prompt:
-    txt, weight = parse_prompt(prompt)
-    txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()
-
-    if args.fuzzy_prompt:
-      for i in range(25):
-        model_stat["target_embeds"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))
-        model_stat["weights"].append(weight)
-    else:
-      model_stat["target_embeds"].append(txt)
-      model_stat["weights"].append(weight)
-
-  if image_prompt:
-    model_stat["make_cutouts"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs)
-    for prompt in image_prompt:
-      path, weight = parse_prompt(prompt)
-      img = Image.open(fetch(path)).convert('RGB')
-      img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)
-      batch = model_stat["make_cutouts"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))
-      embed = clip_model.encode_image(normalize(batch)).float()
-      if fuzzy_prompt:
-        for i in range(25):
-          model_stat["target_embeds"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))
-          weights.extend([weight / cutn] * cutn)
-      else:
-        model_stat["target_embeds"].append(embed)
-        model_stat["weights"].extend([weight / cutn] * cutn)
-
   model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"])
   model_stat["weights"] = torch.tensor(model_stat["weights"], device=device)
   if model_stat["weights"].sum().abs() < 1e-3:
@@ -2145,6 +2171,7 @@ def do_superres(img, filepath):
 RN50x64 = False #@param{type:"boolean"}
 SLIPB16 = False #@param{type:"boolean"}
 SLIPL16 = False #@param{type:"boolean"}
+AudioCLIP_model = False #@param {type:"boolean"}
 
 #@markdown If you're having issues with model downloads, check this to compare SHA's:
 check_model_SHA = False #@param{type:"boolean"}
@@ -2306,6 +2333,15 @@ def do_superres(img, filepath):
 
     clip_models.append(SLIPL16model)
 
+if AudioCLIP_model:
+  torch.set_grad_enabled(False)
+  if not os.path.exists(f'{model_path}/AudioCLIP-Full-Training.pt'):
+    wget("https://github.com/AndreyGuzhov/AudioCLIP/releases/download/v0.1/AudioCLIP-Full-Training.pt", model_path)
+
+  ac = AudioCLIP(pretrained=f'{model_path}/AudioCLIP-Full-Training.pt').cuda()
+  torch.set_grad_enabled(True)
+  clip_models.append(ac)
+
 normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
 lpips_model = lpips.LPIPS(net='vgg').to(device)
 
@@ -2785,6 +2821,10 @@ def split_prompts(prompts):
   # 0:['ImagePromptsWorkButArentVeryGood.png:2',],
 }
 
+#Audio prompts only work if the AudioCLIP model is activated
+audio_prompts = {
+  #0: ['AudioCLIP/assets/bird_sounds.wav']
+}
 
 # %%
 """
@@ -2872,6 +2912,7 @@ def move_files(start_num, end_num, old_folder, new_folder):
   'batchNum': batchNum,
   'prompts_series':split_prompts(text_prompts) if text_prompts else None,
   'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,
+  'audio_prompts_series':split_prompts(audio_prompts) if audio_prompts else None,
   'seed': seed,
   'display_rate':display_rate,
   'n_batches':n_batches if animation_mode == 'None' else 1,
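Usage note (illustrative sketch, not part of the patch): audio prompts follow the same keyframe-dictionary convention as text and image prompts, and because each entry is run through parse_prompt, a trailing ':weight' suffix should behave the same way it does for image prompts. The per-frame lookup added to do_run() clamps to the last scheduled entry once the animation runs past the series. A minimal sketch, assuming the AudioCLIP_model toggle is enabled so the model is appended to clip_models (the wav path is a placeholder):

audio_prompts = {
    0: ['AudioCLIP/assets/bird_sounds.wav'],   # keyed by start frame, like text_prompts / image_prompts
}

def prompt_for_frame(series, frame_num):
    # Mirrors the selection logic the patch adds to do_run().
    if series is not None and frame_num >= len(series):
        return series[-1]   # past the end of the schedule: keep the last prompt
    elif series is not None:
        return series[frame_num]
    return []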