configs.py
import jsonargparse


def get_config_parser():
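    """Build the jsonargparse parser defining all training hyperparameters."""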
    parser = jsonargparse.ArgumentParser()
    parser.add_argument('--dataset', help='path to the dataset',
                        type=str, default='/notebooks/taxonomy.csv')
    parser.add_argument(
        '--batch_size', help='batch size used in train/val/test', type=int,
        default=16
    )
    parser.add_argument(
        '--num_workers', help='number of workers for the dataloader', type=int,
        default=4
    )
    parser.add_argument(
        '--negative_ratio',
        help='ratio of negative pairs to sample relative to positive pairs',
        type=float, default=1
    )
    parser.add_argument('--lr', type=float, help='learning rate', default=5e-5)
    parser.add_argument('--seed', type=int,
                        help='seed for reproducible training', default=123435)
    parser.add_argument('--precision', type=int,
                        help='precision of floating-point computation; '
                             'normally this should not be changed',
                        default=32)
    parser.add_argument('--check_val_every_n_epoch', type=int,
                        help='number of epochs between validation runs',
                        default=5)
    parser.add_argument('--max_epochs', type=int,
                        help='maximum number of epochs for training',
                        default=10)
    parser.add_argument(
        '--dev', help='whether to perform a fast dev run for debugging purposes',
        action='store_true'
    )
    parser.add_argument('--resume_from_checkpoint', type=str,
                        help='checkpoint to resume training from', default=None)
    parser.add_argument('--run_dir', type=str,
                        help='directory to store files related to the training run',
                        default='/notebooks/runs')
    parser.add_argument('--model_name', type=str,
                        help='name of the model file when saving',
                        default='model')
    parser.add_argument('--project_name', type=str,
                        help='project name for wandb logging',
                        default='taxonomy')
    parser.add_argument('--weight_decay', type=float,
                        help='weight decay (L2 regularization)', default=0.02)
    parser.add_argument('--prompt_learning', type=bool,
                        help='whether to use prompt learning', default=False)
    parser.add_argument('--learnable_layers', type=int,
                        help='number of transformer layers to unfreeze',
                        default=0)
    parser.add_argument('--prompt_tuning', type=bool,
                        help='whether to use prompt tuning', default=False)
    parser.add_argument('--base_model', type=str,
                        help='name of the pretrained base model',
                        default='EleutherAI/gpt-neo-1.3B')
    parser.add_argument('--logger', type=str,
                        help='logger backend to use', default='wandb')
    # TODO: add support for lora hyperparameters
    parser.add_argument('--lora', type=bool, help='whether to use lora',
                        default=False)
    return parser


def get_config(config_path):
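    """Parse the config file at config_path into an argument namespace."""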
    parser = get_config_parser()
    args = parser.parse_path(config_path)
    return args
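

if __name__ == '__main__':
    # Usage sketch (illustrative, not part of the original module): given a
    # YAML config file such as
    #
    #     batch_size: 32
    #     lr: 1.0e-4
    #
    # parse_path() loads it, and any keys not set in the file fall back to
    # the defaults declared above.
    import sys

    config = get_config(sys.argv[1])
    print(config)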