Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Oct 9, 2024
1 parent b2d1c0a commit 61b513b
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions generation/maisi/scripts/diff_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def train_one_epoch(
device: torch.device,
logger: logging.Logger,
local_rank: int,
amp: bool = True,
) -> torch.Tensor:
"""
Train the model for one epoch.
Expand All @@ -212,6 +213,7 @@ def train_one_epoch(
device (torch.device): Device to use for training.
logger (logging.Logger): Logger for logging information.
local_rank (int): Local rank for distributed training.
amp (bool): Use automatic mixed precision training.
Returns:
torch.Tensor: Training loss for the epoch.
Expand All @@ -237,7 +239,7 @@ def train_one_epoch(

optimizer.zero_grad(set_to_none=True)

with autocast("cuda", enabled=True):
with autocast("cuda", enabled=amp):
noise = torch.randn(
(num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
)
Expand All @@ -256,9 +258,13 @@ def train_one_epoch(

loss = loss_pt(noise_pred.float(), noise.float())

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
if amp:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()

lr_scheduler.step()

Expand Down Expand Up @@ -312,14 +318,16 @@ def save_checkpoint(
)


def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True) -> None:
"""
Main function to train a diffusion model.
Args:
env_config_path (str): Path to the environment configuration file.
model_config_path (str): Path to the model configuration file.
model_def_path (str): Path to the model definition file.
num_gpus (int): Number of GPUs to use for training.
amp (bool): Use automatic mixed precision training.
"""
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed(num_gpus)
Expand Down Expand Up @@ -392,6 +400,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
device,
logger,
local_rank,
amp=amp,
)

loss_torch = loss_torch.tolist()
Expand Down Expand Up @@ -431,6 +440,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")

args = parser.parse_args()
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)

0 comments on commit 61b513b

Please sign in to comment.