diff --git a/configs/pretrain/Darcy-pretrain.yaml b/configs/pretrain/Darcy-pretrain.yaml
index d5291ad..cc74317 100644
--- a/configs/pretrain/Darcy-pretrain.yaml
+++ b/configs/pretrain/Darcy-pretrain.yaml
@@ -14,6 +14,7 @@ model:
   modes2: [20, 20, 20, 20]
   fc_dim: 128
   act: gelu
+  pad_ratio: 0.0
 
 train:
   batchsize: 20
diff --git a/train_operator.py b/train_operator.py
index a6e761e..6cfb2ae 100644
--- a/train_operator.py
+++ b/train_operator.py
@@ -87,7 +87,7 @@ def train_2d(args, config):
                       nx=data_config['nx'],
                       sub=data_config['sub'],
                       pde_sub=data_config['pde_sub'],
-                      num=data_config['n_samples'],
+                      num=data_config['n_sample'],
                       offset=data_config['offset'])
     train_loader = DataLoader(dataset, batch_size=config['train']['batchsize'], shuffle=True)
     model = FNO2d(modes1=config['model']['modes1'],
@@ -121,7 +121,9 @@ def train_2d(args, config):
     # parse options
     parser = ArgumentParser(description='Basic paser')
-    parser.add_argument('--config_path', type=str, help='Path to the configuration file')
+    parser.add_argument('--config_path', type=str,
+                        default='configs/pretrain/Darcy-pretrain.yaml',
+                        help='Path to the configuration file')
     parser.add_argument('--log', action='store_true', help='Turn on the wandb')
     args = parser.parse_args()
diff --git a/train_utils/train_2d.py b/train_utils/train_2d.py
index 10037d3..1484642 100644
--- a/train_utils/train_2d.py
+++ b/train_utils/train_2d.py
@@ -74,18 +74,18 @@ def train_2d_operator(model,
         if data_weight > 0:
             pred = model(data_ic).squeeze(dim=-1)
             pred = pred * mollifier
-            data_loss = myloss(pred, y)
+            data_loss = myloss(pred, u)
 
-        a = x[..., 0]
+        a = data_ic[..., 0]
         f_loss = darcy_loss(pred, a)
         loss = data_weight * data_loss + f_weight * f_loss
         loss.backward()
         optimizer.step()
-        loss_dict['train_loss'] += loss.item() * y.shape[0]
-        loss_dict['f_loss'] += f_loss.item() * y.shape[0]
-        loss_dict['data_loss'] += data_loss.item() * y.shape[0]
+        loss_dict['train_loss'] += loss.item() * u.shape[0]
+        loss_dict['f_loss'] += f_loss.item() * u.shape[0]
+        loss_dict['data_loss'] += data_loss.item() * u.shape[0]
     scheduler.step()
     train_loss_val = loss_dict['train_loss'] / len(train_loader.dataset)