Skip to content

Commit

Permalink
Fixed Query Startegy Parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikhail Mekhedkin Meskhi committed Jun 20, 2018
1 parent 0de3ebc commit cac0b57
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def plot_ds(grid_size, loc, X, y, xx, yy, title, seeds=None, colspan=1, rowspan=
Perform Active Learning
QueryStrategy (Random Sampling or Uncertainty Sampling)
'''
def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5):
def active(classifiers, datasets, experiments, query_strat, quota=25, plot_every_n=5):
for dataset_index, ((X_src, y_src), (X_tgt, y_tgt)) in enumerate(datasets):
u_tgt = [None] * len(X_tgt)
est_src = ce.ComplexityEstimator(X_src, y_src, n_windows=10, nK=1)
Expand Down Expand Up @@ -99,8 +99,7 @@ def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5):

for i in range(quota): # Loop through the number of queries

if qs == 1 :
qs_name = 'RandomSampling'
if query_strat == 'RandomSampling' :
loc, y_loc = oracle.random_query() # Sample target using RandomSampling strategy
u_tgt[loc] = y_loc
X_known.append(X_tgt[loc])
Expand All @@ -122,8 +121,7 @@ def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5):
ax.set_xlabel('Accuracy='+('%.2f' % score).lstrip('0'))
w += 1

elif qs == 2:
qs_name = 'UncertaintySampling'
elif query_strat == 'UncertaintySampling':
model.fit(X_known, y_known) # Fit model on source only to predict probabilities
loc, X_chosen = oracle.uncertainty_sampling(model) # Sample target using UncertaintySampling strategy
X_known.append(X_tgt[loc])
Expand All @@ -145,9 +143,9 @@ def active(classifiers, datasets, experiments, qs, quota=25, plot_every_n=5):
ax.set_xlabel('Accuracy='+('%.2f' % score).lstrip('0'))
w += 1

figure.suptitle(experiments[dataset_index] + qs_name )
figure.suptitle(experiments[dataset_index] + query_strat )
figure.tight_layout()
fname = './vis/' + str(experiments[dataset_index] + qs_name) + '.png'
fname = './vis/' + str(experiments[dataset_index] + query_strat ) + '.png'
figure.savefig(fname)

plt.tight_layout()
Expand All @@ -157,6 +155,7 @@ def main():
clfs = [SVC(), GaussianNB(), DecisionTreeClassifier(), MLPClassifier(hidden_layer_sizes=(10,10,10,10,10,10), solver='lbfgs', alpha=2, random_state=1, activation='relu')]
datasets = []
experiments = []
query_strat = 'RandomSampling'

# datasets.append((make_gaussian_quantiles(n_samples=500, n_features=10, n_classes=2),
# make_gaussian_quantiles(n_samples=500, n_features=10, n_classes=2)))
Expand All @@ -173,7 +172,7 @@ def main():
datasets.append((mnist.load_mnist(), mnist.load_mnist_rotated()))
experiments.append('MNIST_vs_MNIST_Rotated')

active(classifiers=clfs, datasets=datasets, experiments=experiments, qs=1)
active(classifiers=clfs, datasets=datasets, experiments=experiments, query_strat=query_strat)

if __name__ == "__main__":
main()

0 comments on commit cac0b57

Please sign in to comment.