diff --git a/asreview/entry_points/simulate.py b/asreview/entry_points/simulate.py index 09fc9c424..7a73bd92c 100644 --- a/asreview/entry_points/simulate.py +++ b/asreview/entry_points/simulate.py @@ -24,6 +24,7 @@ from asreview.config import DEFAULT_QUERY_STRATEGY from asreview.entry_points.base import BaseEntryPoint, _base_parser from asreview.review import review_simulate +from asreview.types import type_n_queries class SimulateEntryPoint(BaseEntryPoint): @@ -171,11 +172,12 @@ def _simulate_parser(prog="simulate", description=DESCRIPTION_SIMULATE): f"Default {DEFAULT_N_INSTANCES}.") parser.add_argument( "--n_queries", - type=int, + type=type_n_queries, default=None, - help="The number of queries. By default, the program " - "stops after all documents are reviewed or is " - "interrupted by the user." + help="The number of queries. Alternatively, entering 'min' will stop the " + "simulation when all relevant records have been found. By default, " + "the program stops after all records are reviewed or is interrupted " + "by the user." ) parser.add_argument( "-n", "--n_papers", diff --git a/asreview/review/base.py b/asreview/review/base.py index 33fb35d58..4a5e36447 100644 --- a/asreview/review/base.py +++ b/asreview/review/base.py @@ -258,9 +258,16 @@ def _stop_iter(self, query_i, n_pool): if self.n_papers is not None and n_train >= self.n_papers: stop_iter = True - # don't stop if there is no stopping criteria - if self.n_queries is not None and query_i >= self.n_queries: - stop_iter = True + # If n_queries is set to min, stop when all relevant papers are included + if self.n_queries == 'min': + n_included = np.count_nonzero(self.y[self.train_idx] == 1) + n_total_relevant = np.count_nonzero(self.y == 1) + if n_included == n_total_relevant: + stop_iter = True + # Otherwise, stop when reaching n_queries (if provided) + elif self.n_queries is not None: + if query_i >= self.n_queries: + stop_iter = True return stop_iter diff --git a/asreview/settings.py b/asreview/settings.py index cde1be5ef..aef527c7d 100644 --- a/asreview/settings.py +++ b/asreview/settings.py @@ -22,6 +22,8 @@ from asreview.models.query import get_query_model from asreview.models.feature_extraction import get_feature_model from asreview.utils import pretty_format +from asreview.types import type_n_queries + SETTINGS_TYPE_DICT = { "data_name": str, @@ -31,7 +33,7 @@ "feature_extraction": str, "n_papers": int, "n_instances": int, - "n_queries": int, + "n_queries": type_n_queries, "n_prior_included": int, "n_prior_excluded": int, "mode": str, diff --git a/asreview/types.py b/asreview/types.py new file mode 100644 index 000000000..eb41d92a9 --- /dev/null +++ b/asreview/types.py @@ -0,0 +1,34 @@ +# Copyright 2019-2020 The ASReview Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +def type_n_queries(value): + """Custom type used for --n_queries argument. + + Parameters + ---------- + value: str + The argument value for --n_queries + + Returns + ------- + type_n_queries: + A string containing 'min' or an integer. + """ + if value == 'min': + return value + else: + try: + return int(value) + except ValueError: + raise ValueError("Value for n_queries is not 'min' or a valid integer") diff --git a/docs/source/API/cli.rst b/docs/source/API/cli.rst index 3b2c842bd..3faccbae6 100644 --- a/docs/source/API/cli.rst +++ b/docs/source/API/cli.rst @@ -181,8 +181,9 @@ Examples: .. option:: --n_queries N_QUERIES - The number of queries. By default, the program stops after all documents are reviewed - or is interrupted by the user. + The number of queries. Alternatively, entering :code:`min` will stop the simulation when all relevant + records have been found. By default, the program stops after all records are reviewed + or is interrupted by the user. .. option:: -n N_PAPERS, --n_papers N_PAPERS diff --git a/tests/test_simulate.py b/tests/test_simulate.py index 42dceb506..e6813a40b 100644 --- a/tests/test_simulate.py +++ b/tests/test_simulate.py @@ -330,3 +330,12 @@ def check_model(monkeypatch=None, if state_file is not None: with open_state(state_file, read_only=True) as state: state_checker(state) + + +def test_n_queries_min(tmpdir): + + check_model(model="nb", + state_file=None, + use_granular=True, + n_instances=1, + n_queries='min')