fixes
loubnabnl committed Aug 24, 2023
1 parent f16d804 commit 60ad952
Showing 2 changed files with 36 additions and 7 deletions.
17 changes: 10 additions & 7 deletions pii/ner/train.py
@@ -57,19 +57,21 @@ def get_args():
         type=str,
         default="bigcode/pii-annotated-toloka-donwsample-emails"
     )
-    parser.add_argument("--batch_size", type=int, default=16)
+    parser.add_argument("--train_batch_size", type=int, default=4)
+    parser.add_argument("--eval_batch_size", type=int, default=1)
     parser.add_argument("--learning_rate", type=float, default=1e-5)
     parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
     parser.add_argument("--num_train_epochs", type=int, default=20)
     parser.add_argument("--weight_decay", type=float, default=0.01)
     parser.add_argument("--warmup_steps", type=int, default=100)
     parser.add_argument("--gradient_checkpointing", action="store_true")
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument("--eval_accumulation_steps", type=int, default=4)
     parser.add_argument("--num_proc", type=int, default=8)
     parser.add_argument("--bf16", action="store_true")
     parser.add_argument("--fp16", action="store_true")
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--num_workers", type=int, default=16)
+    parser.add_argument("--num_workers", type=int, default=8)
     parser.add_argument("--eval_freq", type=int, default=100)
     parser.add_argument("--save_freq", type=int, default=1000)
     parser.add_argument("--debug", action="store_true")
@@ -104,8 +106,8 @@ def run_training(args, ner_dataset, model, tokenizer):
         output_dir=args.output_dir,
         evaluation_strategy="steps",
         num_train_epochs=args.num_train_epochs,
-        per_device_train_batch_size=args.batch_size,
-        per_device_eval_batch_size=args.batch_size,
+        per_device_train_batch_size=args.train_batch_size,
+        per_device_eval_batch_size=args.eval_batch_size,
         eval_steps=args.eval_freq,
         save_steps=args.save_freq,
         logging_steps=10,
@@ -117,9 +119,10 @@ def run_training(args, ner_dataset, model, tokenizer):
         warmup_steps=args.warmup_steps,
         gradient_checkpointing=args.gradient_checkpointing,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        eval_accumulation_steps=args.eval_accumulation_steps,
         fp16=args.fp16,
         bf16=args.bf16,
-        run_name=f"pii-bs{args.batch_size}-lr{args.learning_rate}-wd{args.weight_decay}-epochs{args.num_train_epochs}",
+        run_name=f"pii-bs{args.train_batch_size}-lr{args.learning_rate}-wd{args.weight_decay}-epochs{args.num_train_epochs}",
         report_to="wandb",
     )

@@ -179,9 +182,9 @@ def main(args):
     )

     # split to train and test
-    data = data.train_test_split(test_size=0.2, shuffle=True, seed=args.seed)
+    data = data.train_test_split(test_size=0.1, shuffle=True, seed=args.seed)
     test_valid = data["test"].train_test_split(
-        test_size=0.6, shuffle=True, seed=args.seed
+        test_size=0.85, shuffle=True, seed=args.seed
     )
     train_data = data["train"]
     valid_data = test_valid["train"]
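The new defaults shrink the held-out pool but give most of it to the final test split. A quick check of the resulting proportions, assuming the final test split is taken as test_valid["test"] (that line falls outside the visible hunk):

    # Old defaults gave 80% train / 8% valid / 12% test; new defaults below.
    train = 1 - 0.1             # 0.90  train
    valid = 0.1 * (1 - 0.85)    # 0.015 validation (test_valid["train"])
    test  = 0.1 * 0.85          # 0.085 test       (test_valid["test"], assumed)
    assert abs(train + valid + test - 1.0) < 1e-9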
26 changes: 26 additions & 0 deletions pii/ner/utils/eval.py
@@ -7,6 +7,30 @@
 _seqeval_metric = load("seqeval")


+# NER tags
+CATEGORIES = [
+    "NAME",
+    "NAME_LICENSE",
+    "NAME_EXAMPLE",
+    "EMAIL",
+    "EMAIL_LICENSE",
+    "EMAIL_EXAMPLE",
+    "USERNAME",
+    "USERNAME_LICENSE",
+    "USERNAME_EXAMPLE",
+    "KEY",
+    "IP_ADDRESS",
+    "PASSWORD",
+]
+IGNORE_CLASS = ["AMBIGUOUS", "ID"]
+
+LABEL2ID = {"O": 0}
+for cat in CATEGORIES:
+    LABEL2ID[f"B-{cat}"] = len(LABEL2ID)
+    LABEL2ID[f"I-{cat}"] = len(LABEL2ID)
+ID2LABEL = {v: k for k, v in LABEL2ID.items()}
+
+
 def compute_ap(pred, truth):
     pred_proba = 1 - softmax(pred, axis=-1)[..., 0]
     pred_proba, truth = pred_proba.flatten(), np.array(truth).flatten()
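The loop above yields 1 + 2*12 = 25 labels ("O" plus a B-/I- pair for each category). A minimal reconstruction of the mapping it produces:

    CATEGORIES = ["NAME", "NAME_LICENSE", "NAME_EXAMPLE", "EMAIL", "EMAIL_LICENSE",
                  "EMAIL_EXAMPLE", "USERNAME", "USERNAME_LICENSE", "USERNAME_EXAMPLE",
                  "KEY", "IP_ADDRESS", "PASSWORD"]
    LABEL2ID = {"O": 0}
    for cat in CATEGORIES:
        LABEL2ID[f"B-{cat}"] = len(LABEL2ID)
        LABEL2ID[f"I-{cat}"] = len(LABEL2ID)
    # {"O": 0, "B-NAME": 1, "I-NAME": 2, "B-NAME_LICENSE": 3, ..., "I-PASSWORD": 24}
    assert len(LABEL2ID) == 25 and LABEL2ID["I-PASSWORD"] == 24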
@@ -18,6 +42,8 @@ def compute_ap(pred, truth):

 def compute_metrics(p):
     predictions, labels = p
+    print(f"predictions.shape: {predictions.shape} and type {type(predictions)}")
+    print(f"labels.shape: {labels.shape} and type {type(labels)}")
     avg_prec = compute_ap(predictions, labels)
     predictions = np.argmax(predictions, axis=2)
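The two prints are shape checks for the tensors handed to compute_ap above: predictions are per-token logits of shape (num_examples, seq_len, num_labels) and labels are per-token label ids. The visible part of compute_ap reduces the logits to a per-token probability of carrying any PII tag, i.e. 1 - P("O"). A toy example of that reduction, assuming index 0 is "O" as in LABEL2ID and that a binary average-precision score is computed downstream (the actual metric call is outside the visible hunk):

    import numpy as np
    from scipy.special import softmax
    from sklearn.metrics import average_precision_score  # assumed downstream metric

    # One sequence, three tokens, three labels (index 0 = "O").
    pred = np.array([[[4.0, 0.0, 0.0],    # confident "O"      -> PII score ~0.04
                      [0.0, 3.0, 0.0],    # confident PII tag  -> PII score ~0.96
                      [1.0, 1.0, 1.0]]])  # uncertain          -> PII score ~0.67
    truth = np.array([[0, 1, 2]])
    pred_proba = 1 - softmax(pred, axis=-1)[..., 0]   # same reduction as compute_ap
    y_true = (truth.flatten() != 0).astype(int)       # PII vs non-PII (assumption)
    print(average_precision_score(y_true, pred_proba.flatten()))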
