From c03eb4ba380cc817819abab78d8ee96bee340f4d Mon Sep 17 00:00:00 2001
From: govertv <42623130+govertv@users.noreply.github.com>
Date: Mon, 14 Jun 2021 16:12:33 +0200
Subject: [PATCH] Add n_queries='min' option for minimal number of iterations
 (#605)

---
 asreview/entry_points/simulate.py | 10 +++++----
 asreview/review/base.py           | 13 +++++++++---
 asreview/settings.py              |  4 +++-
 asreview/types.py                 | 34 +++++++++++++++++++++++++++++++
 docs/source/API/cli.rst           |  5 +++--
 tests/test_simulate.py            |  9 ++++++++
 6 files changed, 65 insertions(+), 10 deletions(-)
 create mode 100644 asreview/types.py

diff --git a/asreview/entry_points/simulate.py b/asreview/entry_points/simulate.py
index 09fc9c424..7a73bd92c 100644
--- a/asreview/entry_points/simulate.py
+++ b/asreview/entry_points/simulate.py
@@ -24,6 +24,7 @@
 from asreview.config import DEFAULT_QUERY_STRATEGY
 from asreview.entry_points.base import BaseEntryPoint, _base_parser
 from asreview.review import review_simulate
+from asreview.types import type_n_queries
 
 
 class SimulateEntryPoint(BaseEntryPoint):
@@ -171,11 +172,12 @@ def _simulate_parser(prog="simulate", description=DESCRIPTION_SIMULATE):
              f"Default {DEFAULT_N_INSTANCES}.")
     parser.add_argument(
         "--n_queries",
-        type=int,
+        type=type_n_queries,
         default=None,
-        help="The number of queries. By default, the program "
-             "stops after all documents are reviewed or is "
-             "interrupted by the user."
+        help="The number of queries. Alternatively, entering 'min' will stop the "
+             "simulation when all relevant records have been found. By default, "
+             "the program stops after all records are reviewed or is interrupted "
+             "by the user."
     )
     parser.add_argument(
         "-n", "--n_papers",
diff --git a/asreview/review/base.py b/asreview/review/base.py
index 33fb35d58..4a5e36447 100644
--- a/asreview/review/base.py
+++ b/asreview/review/base.py
@@ -258,9 +258,16 @@ def _stop_iter(self, query_i, n_pool):
         if self.n_papers is not None and n_train >= self.n_papers:
             stop_iter = True
 
-        # don't stop if there is no stopping criteria
-        if self.n_queries is not None and query_i >= self.n_queries:
-            stop_iter = True
+        # If n_queries is set to min, stop when all relevant papers are included
+        if self.n_queries == 'min':
+            n_included = np.count_nonzero(self.y[self.train_idx] == 1)
+            n_total_relevant = np.count_nonzero(self.y == 1)
+            if n_included == n_total_relevant:
+                stop_iter = True
+        # Otherwise, stop when reaching n_queries (if provided)
+        elif self.n_queries is not None:
+            if query_i >= self.n_queries:
+                stop_iter = True
 
         return stop_iter
 
diff --git a/asreview/settings.py b/asreview/settings.py
index cde1be5ef..aef527c7d 100644
--- a/asreview/settings.py
+++ b/asreview/settings.py
@@ -22,6 +22,8 @@
 from asreview.models.query import get_query_model
 from asreview.models.feature_extraction import get_feature_model
 from asreview.utils import pretty_format
+from asreview.types import type_n_queries
+
 
 SETTINGS_TYPE_DICT = {
     "data_name": str,
@@ -31,7 +33,7 @@
     "feature_extraction": str,
     "n_papers": int,
     "n_instances": int,
-    "n_queries": int,
+    "n_queries": type_n_queries,
     "n_prior_included": int,
     "n_prior_excluded": int,
     "mode": str,
diff --git a/asreview/types.py b/asreview/types.py
new file mode 100644
index 000000000..eb41d92a9
--- /dev/null
+++ b/asreview/types.py
@@ -0,0 +1,34 @@
+# Copyright 2019-2020 The ASReview Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def type_n_queries(value):
+    """Custom type for the --n_queries argument: accepts 'min' or an integer.
+
+    Parameters
+    ----------
+    value: str
+        The argument value for --n_queries.
+
+    Returns
+    -------
+    str or int
+        'min' unchanged, otherwise int(value); ValueError on invalid text.
+    """
+    if value == 'min':  # keyword is passed through as-is, not converted
+        return value
+    else:
+        try:
+            return int(value)
+        except ValueError:
+            raise ValueError("Value for n_queries is not 'min' or a valid integer")
diff --git a/docs/source/API/cli.rst b/docs/source/API/cli.rst
index 3b2c842bd..3faccbae6 100644
--- a/docs/source/API/cli.rst
+++ b/docs/source/API/cli.rst
@@ -181,8 +181,9 @@ Examples:
 
 .. option:: --n_queries N_QUERIES
 
-    The number of queries. By default, the program stops after all documents are reviewed
-    or is interrupted by the user.
+    The number of queries. Alternatively, entering :code:`min` will stop the simulation when all relevant
+    records have been found. By default, the program stops after all records are reviewed
+    or is interrupted by the user.
 
 .. option:: -n N_PAPERS, --n_papers N_PAPERS
 
diff --git a/tests/test_simulate.py b/tests/test_simulate.py
index 42dceb506..e6813a40b 100644
--- a/tests/test_simulate.py
+++ b/tests/test_simulate.py
@@ -330,3 +330,12 @@ def check_model(monkeypatch=None,
     if state_file is not None:
         with open_state(state_file, read_only=True) as state:
             state_checker(state)
+
+
+def test_n_queries_min(tmpdir):  # NOTE(review): tmpdir fixture is unused here
+    """Run check_model with n_queries='min' and no state file (smoke test)."""
+    check_model(model="nb",
+                state_file=None,
+                use_granular=True,
+                n_instances=1,
+                n_queries='min')