forked from NVIDIA/Cosmos
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgeneral.py
353 lines (298 loc) · 14 KB
/
general.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import numpy as np
import torch
from huggingface_hub import snapshot_download
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from nemo import lightning as nl
from nemo.lightning.megatron_parallel import MegatronParallel
MegatronParallel.init_ddp = lambda self: None
from nemo.collections.diffusion.mcore_parallel_utils import Utils
from nemo.collections.diffusion.sampler.conditioner import VideoConditioner
from nemo.collections.diffusion.sampler.conditioner_configs import (
FPSConfig,
ImageSizeConfig,
NumFramesConfig,
PaddingMaskConfig,
TextConfig,
)
from nemo.collections.diffusion.sampler.cosmos.cosmos_diffusion_pipeline import CosmosDiffusionPipeline
from transformers import T5EncoderModel, T5TokenizerFast
from cosmos1.models.diffusion.nemo.inference.inference_utils import process_prompt, save_video
from cosmos1.utils import log
EXAMPLE_PROMPT = (
"The teal robot is cooking food in a kitchen. Steam rises from a simmering pot "
"as the robot chops vegetables on a worn wooden cutting board. Copper pans hang "
"from an overhead rack, catching glints of afternoon light, while a well-loved "
"cast iron skillet sits on the stovetop next to scattered measuring spoons and "
"a half-empty bottle of olive oil."
)
def parse_args():
parser = argparse.ArgumentParser(description="Video foundation model inference")
parser.add_argument(
"--model",
type=str,
default="Cosmos-1.0-Diffusion-7B-Text2World",
choices=["Cosmos-1.0-Diffusion-7B-Text2World", "Cosmos-1.0-Diffusion-14B-Text2World"],
)
parser.add_argument(
"--prompt",
type=str,
default=EXAMPLE_PROMPT,
help="Prompt which the sampled video condition on",
)
# We turn on negative prompt by default. set to "" to turn it off.
parser.add_argument(
"--negative_prompt",
type=str,
default=(
"The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
"over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
"underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
"jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, "
"fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. "
"Overall, the video is of poor quality."
),
help="Negative prompt which the sampled video condition on",
)
parser.add_argument("--subject_name", type=str, default="", help="Name of fine-tuned subject")
parser.add_argument("--guidance", type=float, default=7, help="Classifier-free guidance scale")
parser.add_argument("--sampler", type=str, default="RES", help="Currently only supports RES sampler.")
parser.add_argument("--video_save_path", type=str, default="outputs", help="Path to save the video")
parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video")
parser.add_argument("--height", type=int, default=704, help="Height of image to sample")
parser.add_argument("--width", type=int, default=1280, help="Width of image to sample")
parser.add_argument("--seed", type=int, default=1, help="Random seed")
parser.add_argument("--num_devices", type=int, default=1, help="Number of devices for inference")
parser.add_argument("--cp_size", type=int, default=1, help="Number of cp ranks for multi-gpu inference.")
parser.add_argument("--num_steps", type=float, default=35, help="Number of diffusion sampling steps")
parser.add_argument("--num_video_frames", type=int, default=121, help="Number of video frames to sample")
parser.add_argument("--tokenizer_dir", type=str, default="", help="Directory for video tokenizer")
parser.add_argument("--cosmos_assets_dir", type=str, default="", help="Directory containing cosmos assets")
parser.add_argument("--prompt_upsampler_dir", type=str, default="", help="Prompt upsampler weights directory")
parser.add_argument("--guardrail_dir", type=str, default="", help="Guardrails weights directory")
parser.add_argument("--nemo_checkpoint", type=str, default="", help="Video diffusion model nemo weights")
parser.add_argument("--t5_cache_dir", type=str, default=None, help="Path to T5 model")
parser.add_argument(
"--enable_prompt_upsampler", action="store_true", help="Whether to use prompt upsampling before generation"
)
args = parser.parse_args()
return args
def print_rank_0(string: str):
rank = torch.distributed.get_rank()
if rank == 0:
log.info(string)
@torch.no_grad()
def encode_for_batch(tokenizer: T5TokenizerFast, encoder: T5EncoderModel, prompts: list[str], max_length: int = 512):
"""
Encode a batch of text prompts to a batch of T5 embeddings.
Parameters:
tokenizer: T5 embedding tokenizer.
encoder: T5 embedding text encoder.
prompts: A batch of text prompts.
max_length: Sequence length of text embedding (defaults to 512).
"""
batch_encoding = tokenizer.batch_encode_plus(
prompts,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=max_length,
return_length=True,
return_offsets_mapping=False,
)
# We expect all the processing is done on GPU.
input_ids = batch_encoding.input_ids.cuda()
attn_mask = batch_encoding.attention_mask.cuda()
outputs = encoder(input_ids=input_ids, attention_mask=attn_mask)
encoded_text = outputs.last_hidden_state
lengths = attn_mask.sum(dim=1).cpu()
for batch_id in range(encoded_text.shape[0]):
encoded_text[batch_id][lengths[batch_id] :] = 0
return encoded_text
def init_video_tokenizer(args):
"""
Initializes video tokenizer based on specified video tokenizer config / path.
"""
from nemo.collections.diffusion.models.model import DiT7BConfig, DiT14BConfig
vae_path = os.path.join(args.cosmos_assets_dir, args.tokenizer_dir)
if "7b" in args.nemo_checkpoint.lower():
dit_config = DiT7BConfig(vae_path=vae_path)
if "14b" in args.nemo_checkpoint.lower():
dit_config = DiT14BConfig(vae_path=vae_path)
vae = dit_config.configure_vae()
return vae
def check_prompt(args):
prompt = args.prompt
subject_string = None
if args.subject_name:
subject_string = f"A video of sks {args.subject_name}"
prompt = process_prompt(
prompt=prompt,
checkpoint_dir=args.cosmos_assets_dir,
prompt_upsampler_dir=args.prompt_upsampler_dir,
guardrails_dir=args.guardrail_dir,
enable_prompt_upsampler=args.enable_prompt_upsampler,
)
if subject_string:
prompt = f"{subject_string}. {prompt}"
return prompt
def prepare_data_batch(args, vae, t5_embeding_max_length=512):
tokenizer = T5TokenizerFast.from_pretrained("google-t5/t5-11b", cache_dir=args.t5_cache_dir)
text_encoder = T5EncoderModel.from_pretrained("google-t5/t5-11b", cache_dir=args.t5_cache_dir)
text_encoder.to("cuda")
text_encoder.eval()
# Encode text to T5 embedding
out = encode_for_batch(tokenizer, text_encoder, [args.prompt])[0]
encoded_text = torch.tensor(out, dtype=torch.bfloat16)
# Padding T5 embedding to t5_embeding_max_length
L, C = encoded_text.shape
t5_embed = torch.zeros(1, t5_embeding_max_length, C, dtype=torch.bfloat16)
t5_embed[0, :L] = encoded_text
if args.negative_prompt:
out = encode_for_batch(tokenizer, text_encoder, [args.negative_prompt])[0]
encoded_text = torch.tensor(out, dtype=torch.bfloat16)
# Padding T5 embedding to t5_embeding_max_length
L, C = encoded_text.shape
neg_t5_embed = torch.zeros(1, t5_embeding_max_length, C, dtype=torch.bfloat16)
neg_t5_embed[0, :L] = encoded_text
else:
neg_t5_embed = None
# Prepare data sample
t, h, w = args.num_video_frames, args.height, args.width
state_shape = [
vae.channel,
vae.get_latent_num_frames(t),
h // vae.spatial_compression_factor,
w // vae.spatial_compression_factor,
]
data_batch = {
"video": torch.zeros((1, 3, t, h, w), dtype=torch.uint8).cuda(),
"t5_text_embeddings": t5_embed,
"t5_text_mask": torch.ones(1, t5_embeding_max_length, dtype=torch.bfloat16).cuda(),
# other conditions
"image_size": torch.tensor(
[[args.height, args.width, args.height, args.width]] * 1, dtype=torch.bfloat16
).cuda(),
"fps": torch.tensor([args.fps] * 1, dtype=torch.bfloat16).cuda(),
"num_frames": torch.tensor([args.num_video_frames] * 1, dtype=torch.bfloat16).cuda(),
"padding_mask": torch.zeros((1, 1, args.height, args.width), dtype=torch.bfloat16).cuda(),
}
if args.negative_prompt:
data_batch["neg_t5_text_embeddings"] = neg_t5_embed
data_batch["neg_t5_text_mask"] = torch.ones(1, t5_embeding_max_length, dtype=torch.bfloat16)
return data_batch, state_shape
def setup_diffusion_pipeline(args):
"""
Initialize DiT model, parallel strategy, and diffusion pipeline for inference.
"""
# Initialize DiT model
from nemo.collections.diffusion.models.model import DiT7BConfig, DiT14BConfig, DiTModel
if "7b" in args.nemo_checkpoint.lower():
dit_config = DiT7BConfig()
if "14b" in args.nemo_checkpoint.lower():
dit_config = DiT14BConfig()
dit_model = DiTModel(dit_config)
# Initialize model parallel strategy. Here, we only use context parallel.
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
context_parallel_size=args.cp_size,
pipeline_dtype=torch.bfloat16,
)
# Initialize ptl trainer
trainer = nl.Trainer(
devices=args.num_devices, # you can change the numebr of devices to suit your setup
max_steps=1,
accelerator="gpu",
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)
# Convert trainer to fabric for inference
fabric = trainer.to_fabric()
fabric.strategy.checkpoint_io.save_ckpt_format = "zarr"
fabric.strategy.checkpoint_io.validate_access_integrity = False
model = fabric.load_model(args.nemo_checkpoint, dit_model).to(device="cuda", dtype=torch.bfloat16)
# Set up diffusion pipeline
conditioner = VideoConditioner(
text=TextConfig(),
fps=FPSConfig(),
num_frames=NumFramesConfig(),
image_size=ImageSizeConfig(),
padding_mask=PaddingMaskConfig(),
)
diffusion_pipeline = CosmosDiffusionPipeline(
net=model.module, conditioner=conditioner, sampler_type=args.sampler, seed=args.seed
)
return diffusion_pipeline
def run_diffusion_inference(args, data_batch, state_shape, vae, diffusion_pipeline):
# prepare data
data_batch = {k: v.cuda() if torch.is_tensor(v) else v for k, v in data_batch.items()}
data_batch["inference_fwd"] = True
sample = diffusion_pipeline.generate_samples_from_batch(
data_batch,
guidance=args.guidance,
state_shape=state_shape,
num_steps=args.num_steps,
is_negative_prompt=True if "neg_t5_text_embeddings" in data_batch else False,
)
rank = torch.distributed.get_rank()
if rank == 0:
# Post-processing and save video
sigma_data = 0.5
grid = (1.0 + vae.decode(sample / sigma_data)).clamp(0, 2) / 2
grid = (grid[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)
save_video(
grid=grid,
fps=args.fps,
H=args.height,
W=args.width,
video_save_quality=5,
video_save_path=args.video_save_path,
checkpoint_dir=args.cosmos_assets_dir,
guardrails_dir=args.guardrail_dir,
)
print_rank_0(f"saved video to {args.video_save_path}!")
def main(args):
if args.guardrail_dir == "":
args.guardrail_dir = snapshot_download("nvidia/Cosmos-1.0-Guardrail")
if args.tokenizer_dir == "":
args.tokenizer_dir = snapshot_download("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8")
if args.prompt_upsampler_dir == "" and args.enable_prompt_upsampler:
args.prompt_upsampler_dir = snapshot_download("nvidia/Cosmos-1.0-Prompt-Upsampler-12B-Text2World")
if args.nemo_checkpoint == "":
args.nemo_checkpoint = snapshot_download(f"nvidia/{args.model}", allow_patterns=["nemo/*"])
args.nemo_checkpoint = os.path.join(args.nemo_checkpoint, "nemo")
# Initialize megatron model parallel environment
Utils.initialize_distributed(1, 1, context_parallel_size=args.cp_size)
model_parallel_cuda_manual_seed(args.seed)
args.prompt = check_prompt(args)
# Load video tokenizer
print_rank_0("initializing video tokenizer...")
vae = init_video_tokenizer(args)
# Prepare data batch
print_rank_0("preparing data batch...")
data_batch, state_shape = prepare_data_batch(args, vae)
# Setup model / diffusion pipeline
print_rank_0("setting up diffusion pipeline...")
diffusion_pipeline = setup_diffusion_pipeline(args)
# Generate video from prompt
print_rank_0("generating video...")
run_diffusion_inference(args, data_batch, state_shape, vae, diffusion_pipeline)
if __name__ == "__main__":
args = parse_args()
main(args)