From 7c302c6afea086acd2aa0a29475ee81887f4a049 Mon Sep 17 00:00:00 2001
From: Matt Buchovecky
Date: Sun, 26 Jan 2025 12:20:25 -0800
Subject: [PATCH] fix: use proper eval default main eval metrics for text
 regression model

also refactor variables to avoid type conflicts
---
 flair/models/pairwise_regression_model.py |  8 +++----
 flair/models/text_regression_model.py     | 27 ++++++++++++++---------
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py
index 9a1c2704be..bc77b54dce 100644
--- a/flair/models/pairwise_regression_model.py
+++ b/flair/models/pairwise_regression_model.py
@@ -345,7 +345,7 @@ def evaluate(
                 f"spearman: {metric.spearmanr():.4f}"
             )
 
-            scores = {
+            eval_metrics = {
                 "loss": eval_loss.item(),
                 "mse": metric.mean_squared_error(),
                 "mae": metric.mean_absolute_error(),
@@ -354,12 +354,12 @@ def evaluate(
             }
 
             if main_evaluation_metric[0] in ("correlation", "other"):
-                main_score = scores[main_evaluation_metric[1]]
+                main_score = eval_metrics[main_evaluation_metric[1]]
             else:
-                main_score = scores["spearman"]
+                main_score = eval_metrics["spearman"]
 
             return Result(
                 main_score=main_score,
                 detailed_results=detailed_result,
-                scores=scores,
+                scores=eval_metrics,
             )
diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py
index d1ad98d4e0..a0a99e6402 100644
--- a/flair/models/text_regression_model.py
+++ b/flair/models/text_regression_model.py
@@ -137,7 +137,7 @@ def evaluate(
         out_path: Optional[Union[str, Path]] = None,
         embedding_storage_mode: EmbeddingStorageMode = "none",
         mini_batch_size: int = 32,
-        main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"),
+        main_evaluation_metric: tuple[str, str] = ("correlation", "pearson"),
         exclude_labels: Optional[list[str]] = None,
         gold_label_dictionary: Optional[Dictionary] = None,
         return_loss: bool = True,
@@ -195,16 +195,23 @@ def evaluate(
                 f"spearman: {metric.spearmanr():.4f}"
             )
 
-            result: Result = Result(
-                main_score=metric.pearsonr(),
+            eval_metrics = {
+                "loss": eval_loss.item(),
+                "mse": metric.mean_squared_error(),
+                "mae": metric.mean_absolute_error(),
+                "pearson": metric.pearsonr(),
+                "spearman": metric.spearmanr(),
+            }
+
+            if main_evaluation_metric[0] in ("correlation", "other"):
+                main_score = eval_metrics[main_evaluation_metric[1]]
+            else:
+                main_score = eval_metrics["spearman"]
+
+            result = Result(
+                main_score=main_score,
                 detailed_results=detailed_result,
-                scores={
-                    "loss": eval_loss.item(),
-                    "mse": metric.mean_squared_error(),
-                    "mae": metric.mean_absolute_error(),
-                    "pearson": metric.pearsonr(),
-                    "spearman": metric.spearmanr(),
-                },
+                scores=eval_metrics,
             )
 
         return result
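
Note for reviewers (illustrative, not part of the patch): a minimal sketch of how
the new default plays out, assuming a trained TextRegressor checkpoint at a
hypothetical path and a model label name of "score"; the path, sentences, and
label name below are made up for illustration.

    from flair.data import Sentence
    from flair.models import TextRegressor

    # Hypothetical checkpoint path; any trained TextRegressor works the same way.
    regressor = TextRegressor.load("resources/taggers/regressor/final-model.pt")

    # Tiny illustrative test set with numeric gold labels stored under the
    # model's label name (assumed to be "score" here).
    test_sentences = []
    for text, gold in [
        ("great movie, loved it", "4.5"),
        ("watchable but forgettable", "2.5"),
        ("terrible plot and acting", "1.0"),
    ]:
        sentence = Sentence(text)
        sentence.add_label("score", gold)
        test_sentences.append(sentence)

    # With this patch the default main_evaluation_metric is ("correlation", "pearson"),
    # so result.main_score is the Pearson correlation instead of the old
    # ("micro avg", "f1-score") default, which the regressor never reports.
    result = regressor.evaluate(test_sentences, gold_label_type="score")
    print(result.main_score)  # pearson
    print(result.scores)      # loss, mse, mae, pearson, spearman

    # Any reported metric can be promoted to main_score explicitly:
    result = regressor.evaluate(
        test_sentences,
        gold_label_type="score",
        main_evaluation_metric=("correlation", "spearman"),
    )

The pairwise regression model already had this selection logic; there the change
is only the rename of its scores dict to eval_metrics, per the commit note about
avoiding type conflicts.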