From 58c950a9c20af442d2a1ed36fdc808e26b50c73b Mon Sep 17 00:00:00 2001 From: Dainis Boumber Date: Thu, 23 Aug 2018 21:26:03 -0500 Subject: [PATCH] all fixed --- complexity.py | 10 +++++----- modules/active_da.py | 29 ++++++++++++----------------- nd_boundary_plot | 2 +- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/complexity.py b/complexity.py index d92b958..f164b31 100644 --- a/complexity.py +++ b/complexity.py @@ -182,12 +182,12 @@ def main(): # experiments.append('moons') # datasets.append((u.hastie(1000), u.hastie(1000))) - # datasets.append((make_gaussian_quantiles(n_samples=500, n_features=5, n_classes=3), - # make_gaussian_quantiles(n_samples=500, n_features=5, n_classes=3))) - # experiments.append('gauus') + datasets.append((make_gaussian_quantiles(n_samples=500, n_features=5, n_classes=3), + make_gaussian_quantiles(n_samples=500, n_features=5, n_classes=3))) + experiments.append('gauus') - datasets.append((mnist.load_mnist(), mnist.load_mnist_rotated())) - experiments.append('MNIST_vs_MNIST_Rotated') + #datasets.append((mnist.load_mnist(), mnist.load_mnist_rotated())) + #experiments.append('MNIST_vs_MNIST_Rotated') #baseline_active(classifiers=clfs, datasets=datasets, experiments=experiments, query_strat=query_strat) bsda_active(datasets=datasets) diff --git a/modules/active_da.py b/modules/active_da.py index 10122a7..5f15e74 100644 --- a/modules/active_da.py +++ b/modules/active_da.py @@ -26,11 +26,10 @@ class CADA(object): 6. Go to step 3 until no more examples are left for sampling. 7. Create a model with the queried examples on target. (outside the scope of this class, can use any model you want) ''' - def __init__(self, source_X, source_y, max_entropy=0.1, f_samples=0.01, window_growth_rate=0.1, mink=1, maxk=-1, n_samples=-1): + def __init__(self, source_X, source_y, max_entropy=0.9, f_samples=0.01, window_growth_rate=0.01): assert(len(source_X)==len(source_y)) assert(f_samples <= 1.0 and max_entropy <= 1.0) - assert (n_samples < len(source_y)) - assert(mink > 0 and maxk < len(source_y)) + #all sane #how many do we actually sample in Step 3 @@ -38,20 +37,13 @@ def __init__(self, source_X, source_y, max_entropy=0.1, f_samples=0.01, window_g self.source_y = source_y #build the tree of the distribution that we do nn on self.tree = scipy.spatial.cKDTree(self.source_X, leafsize=32, compact_nodes=False, balanced_tree=False) - print('build done') - if n_samples == -1: - self.n_samples = int(len(source_y) * f_samples) - if self.n_samples == 0: - self.n_samples = 1 - else: - self.n_samples = n_samples - #init sampling - self.seeds = np.random.random_integers(0, len(source_X) - 1, self.n_samples) + + self.seeds = np.random.random_integers(0, len(source_X) - 1, len(source_y)) self.classes = set(source_y) - stepsize = int(maxk * window_growth_rate) + stepsize = int(len(source_y) * window_growth_rate) if stepsize == 0: stepsize = 1 - self.Ks = np.arange(mink, maxk, step=stepsize) # ckdTree starts counting from 1 + self.Ks = np.arange(1, len(source_y), step=stepsize) # ckdTree starts counting from 1 self.Hs = np.zeros(len(self.Ks)) print(self.Hs) self.ws = np.zeros((len(self.seeds), len(self.Ks))) @@ -63,12 +55,14 @@ def __init__(self, source_X, source_y, max_entropy=0.1, f_samples=0.01, window_g # add up entropy for each window as they grow self.Hs[i] = np.sum(self.ws[:, i]) / len(self.seeds) - + #print(self.Hs[i]) if self.Hs[i] > max_entropy: if i > 0: - self.K = self.Ks[i-1] + self.K = self.Ks[i-1] break # done with step 2 + assert(self.K > 0 and self.K < len(source_y)) + # returns indices into target_X def query(self, target_X, N): @@ -84,7 +78,7 @@ def query(self, target_X, N): for example_ix in example_indices: _, ii = self._nearest_neighbors(self.K, example_ix) # step 4 - target_banned[ii.flatten().squeeze()] = 1 # step 5 + target_banned[ii] = 1 # step 5 return queried_example_indices @@ -101,5 +95,6 @@ def _H(self, k, seed): r = len(same_c)/float(k) if r > 0: H += (r * np.log2(r)) + print(H) return -H diff --git a/nd_boundary_plot b/nd_boundary_plot index 974a0ed..12dd5fd 160000 --- a/nd_boundary_plot +++ b/nd_boundary_plot @@ -1 +1 @@ -Subproject commit 974a0ed54f0bf820e602ef026b71c2955d796a9b +Subproject commit 12dd5fdcef68344bac04b08bac14ab0db8d8bf8b