
Commit 087e441
Merge pull request #3602 from ZipRecruiter/mattb.fix.proper-default-eval-metric-text-regression

fix: use proper eval default main eval metrics for text regression model
alanakbik authored Jan 27, 2025
2 parents 30974f2 + 7c302c6 commit 087e441
Showing 2 changed files with 21 additions and 14 deletions.
8 changes: 4 additions & 4 deletions flair/models/pairwise_regression_model.py
@@ -345,7 +345,7 @@ def evaluate(
             f"spearman: {metric.spearmanr():.4f}"
         )
 
-        scores = {
+        eval_metrics = {
             "loss": eval_loss.item(),
             "mse": metric.mean_squared_error(),
             "mae": metric.mean_absolute_error(),
@@ -354,12 +354,12 @@ def evaluate(
         }
 
         if main_evaluation_metric[0] in ("correlation", "other"):
-            main_score = scores[main_evaluation_metric[1]]
+            main_score = eval_metrics[main_evaluation_metric[1]]
         else:
-            main_score = scores["spearman"]
+            main_score = eval_metrics["spearman"]
 
         return Result(
             main_score=main_score,
             detailed_results=detailed_result,
-            scores=scores,
+            scores=eval_metrics,
         )
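
The hunks above rename the metrics dict to eval_metrics and select the reported main score from it: the caller's main_evaluation_metric tuple is honored when its first element is "correlation" or "other", and Spearman correlation is the fallback. The text regression model below adopts the same pattern. A minimal standalone sketch of that selection logic, with made-up placeholder metric values (not output from this commit):

# Standalone sketch of the main-score selection introduced in this commit.
# The metric values are placeholders, not real evaluation output.
eval_metrics = {
    "loss": 0.12,
    "mse": 0.05,
    "mae": 0.18,
    "pearson": 0.91,
    "spearman": 0.89,
}

main_evaluation_metric = ("correlation", "pearson")  # new default for text regression (second file below)

if main_evaluation_metric[0] in ("correlation", "other"):
    # honor the requested metric name, e.g. "pearson" or "spearman"
    main_score = eval_metrics[main_evaluation_metric[1]]
else:
    # any other category falls back to Spearman correlation
    main_score = eval_metrics["spearman"]

print(main_score)  # 0.91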
27 changes: 17 additions & 10 deletions flair/models/text_regression_model.py
@@ -137,7 +137,7 @@ def evaluate(
         out_path: Optional[Union[str, Path]] = None,
         embedding_storage_mode: EmbeddingStorageMode = "none",
         mini_batch_size: int = 32,
-        main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
+        main_evaluation_metric: tuple[str, str] = ("correlation", "pearson"),
         exclude_labels: Optional[list[str]] = None,
         gold_label_dictionary: Optional[Dictionary] = None,
         return_loss: bool = True,
@@ -195,16 +195,23 @@ def evaluate(
             f"spearman: {metric.spearmanr():.4f}"
         )
 
-        result: Result = Result(
-            main_score=metric.pearsonr(),
+        eval_metrics = {
+            "loss": eval_loss.item(),
+            "mse": metric.mean_squared_error(),
+            "mae": metric.mean_absolute_error(),
+            "pearson": metric.pearsonr(),
+            "spearman": metric.spearmanr(),
+        }
+
+        if main_evaluation_metric[0] in ("correlation", "other"):
+            main_score = eval_metrics[main_evaluation_metric[1]]
+        else:
+            main_score = eval_metrics["spearman"]
+
+        result = Result(
+            main_score=main_score,
             detailed_results=detailed_result,
-            scores={
-                "loss": eval_loss.item(),
-                "mse": metric.mean_squared_error(),
-                "mae": metric.mean_absolute_error(),
-                "pearson": metric.pearsonr(),
-                "spearman": metric.spearmanr(),
-            },
+            scores=eval_metrics,
         )
 
         return result
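
With this change the signature advertises a default that matches regression output, and the main_evaluation_metric parameter is actually honored when picking main_score (the old code always reported Pearson, while advertising the classification default ("micro avg", "f1-score")). A hedged usage sketch, assuming a trained flair TextRegressor at a placeholder path and a gold label type named "score"; none of those specifics come from this commit:

# Hedged sketch: the model path, label type name, and example sentences are
# placeholders; only the main_evaluation_metric behaviour comes from this commit.
from flair.data import Sentence
from flair.models import TextRegressor

model = TextRegressor.load("resources/taggers/example-regressor/final-model.pt")

# Placeholder test data annotated with a numeric gold label of type "score".
sentence_a = Sentence("great value for the money")
sentence_a.add_label("score", "4.5")
sentence_b = Sentence("would not buy again")
sentence_b.add_label("score", "1.0")

# No main_evaluation_metric passed: the new default ("correlation", "pearson")
# makes result.main_score the Pearson correlation.
result = model.evaluate([sentence_a, sentence_b], gold_label_type="score")
print(result.main_score)
print(result.scores["spearman"])  # the full metric dict is still reported

# Explicitly selecting Spearman as the main score instead:
result = model.evaluate(
    [sentence_a, sentence_b],
    gold_label_type="score",
    main_evaluation_metric=("correlation", "spearman"),
)
print(result.main_score)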
