diff --git a/sml/metrics/classification/classification.py b/sml/metrics/classification/classification.py index 9bbd12ca..2d3882fc 100644 --- a/sml/metrics/classification/classification.py +++ b/sml/metrics/classification/classification.py @@ -214,7 +214,7 @@ def fun_score( y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, i) fun_result.append(fun(y_true_binary, y_pred_binary)) elif average == 'binary': - if transform is True: + if transform: y_true_binary, y_pred_binary = transform_binary(y_true, y_pred, pos_label) else: y_true_binary, y_pred_binary = y_true, y_pred