From cac0b57df8e6b1cc8358ed22b51c6ea88335978b Mon Sep 17 00:00:00 2001 From: Mikhail Mekhedkin Meskhi Date: Wed, 20 Jun 2018 19:49:30 +0400 Subject: [PATCH] Fixed Query Startegy Parameter --- complexity.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/complexity.py b/complexity.py index 0ac730a..eee1f07 100644 --- a/complexity.py +++ b/complexity.py @@ -47,7 +47,7 @@ def plot_ds(grid_size, loc, X, y, xx, yy, title, seeds=None, colspan=1, rowspan= Perform Active Learning QueryStrategy (Random Sampling or Uncertainty Sampling) ''' -def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5): +def active(classifiers, datasets, experiments, query_strat, quota=25, plot_every_n=5): for dataset_index, ((X_src, y_src), (X_tgt, y_tgt)) in enumerate(datasets): u_tgt = [None] * len(X_tgt) est_src = ce.ComplexityEstimator(X_src, y_src, n_windows=10, nK=1) @@ -99,8 +99,7 @@ def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5): for i in range(quota): # Loop through the number of queries - if qs == 1 : - qs_name = 'RandomSampling' + if query_strat == 'RandomSampling' : loc, y_loc = oracle.random_query() # Sample target using RandomSampling strategy u_tgt[loc] = y_loc X_known.append(X_tgt[loc]) @@ -122,8 +121,7 @@ def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5): ax.set_xlabel('Accuracy='+('%.2f' % score).lstrip('0')) w += 1 - elif qs == 2: - qs_name = 'UncertaintySampling' + elif query_strat == 'UncertaintySampling': model.fit(X_known, y_known) # Fit model on source only to predict probabilities loc, X_chosen = oracle.uncertainty_sampling(model) # Sample target using UncertaintySampling strategy X_known.append(X_tgt[loc]) @@ -145,9 +143,9 @@ def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5): ax.set_xlabel('Accuracy='+('%.2f' % score).lstrip('0')) w += 1 - figure.suptitle(experiments[dataset_index] + qs_name ) + figure.suptitle(experiments[dataset_index] + query_strat ) figure.tight_layout() - fname = './vis/' + str(experiments[dataset_index] + qs_name) + '.png' + fname = './vis/' + str(experiments[dataset_index] + query_strat ) + '.png' figure.savefig(fname) plt.tight_layout() @@ -157,6 +155,7 @@ def main(): clfs = [SVC(), GaussianNB(), DecisionTreeClassifier(), MLPClassifier(hidden_layer_sizes=(10,10,10,10,10,10), solver='lbfgs', alpha=2, random_state=1, activation='relu')] datasets = [] experiments = [] + query_strat = 'RandomSampling' # datasets.append((make_gaussian_quantiles(n_samples=500, n_features=10, n_classes=2), # make_gaussian_quantiles(n_samples=500, n_features=10, n_classes=2))) @@ -173,7 +172,7 @@ def main(): datasets.append((mnist.load_mnist(), mnist.load_mnist_rotated())) experiments.append('MNIST_vs_MNIST_Rotated') - active(classifiers=clfs, datasets=datasets, experiments=experiments, qs=1) + active(classifiers=clfs, datasets=datasets, experiments=experiments, query_strat=query_strat) if __name__ == "__main__": main()