Skip to content

Commit

Permalink
try this fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dainis-boumber committed Aug 30, 2018
1 parent 8164748 commit ffdde5b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
28 changes: 18 additions & 10 deletions modules/active_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def __init__(self, source_X, source_y, max_entropy=0.9, f_samples=0.01, window_g
#all sane

#how many do we actually sample in Step 3
self.source_X = source_X
self.X = source_X
self.source_y = source_y
#build the tree of the distribution that we do nn on
self.tree = scipy.spatial.cKDTree(self.source_X, leafsize=32, compact_nodes=False, balanced_tree=False)
self.tree = scipy.spatial.cKDTree(self.X, leafsize=32, compact_nodes=False, balanced_tree=False)

self.seeds = np.random.random_integers(0, len(source_X) - 1, len(source_y))
self.classes = set(source_y)
Expand All @@ -45,31 +45,39 @@ def __init__(self, source_X, source_y, max_entropy=0.9, f_samples=0.01, window_g
stepsize = 1
self.Ks = np.arange(1, len(source_y), step=stepsize) # ckdTree starts counting from 1
self.Hs = np.zeros(len(self.Ks))
print(self.Hs)
self.ws = np.zeros((len(self.seeds), len(self.Ks)))
self.K = 0
self.fit_called = False

for i, k in enumerate(self.Ks):
for j, seed in enumerate(self.seeds):
self.ws[j, i] = self._H(k=k, seed=seed)

# add up entropy for each window as they grow
self.Hs[i] = np.sum(self.ws[:, i]) / len(self.seeds)
#print(self.Hs[i])

if self.Hs[i] > max_entropy:
if i > 0:
self.K = self.Ks[i-1]
break # done with step 2

assert(self.K > 0 and self.K < len(source_y))

# returns indices into target_X

def query(self, target_X, N):
if N > len(target_X):
# must call at least once before query is called (query can then be called on specified data)
def fit(self, target_X):
self.X = target_X
self.tree = scipy.spatial.cKDTree(self.X, leafsize=32, compact_nodes=False, balanced_tree=False)
self.fit_called = True

# returns indices into target_X, data must be specified by a prior called to fit (it is OK to call fit and then
# query multiple times or in a loop
def query(self, N):
if self.fit_called == False:
raise ValueError
if N > len(self.X):
raise AttributeError

target_banned = np.zeros(len(target_X))
target_banned = np.zeros(len(self.X))
queried_example_indices = []
while 0 in target_banned: # step 6 check
not_banned_ix = [i for i, banned in enumerate(target_banned) if banned != 1]
Expand All @@ -84,7 +92,7 @@ def query(self, target_X, N):


def _nearest_neighbors(self, k, seed):
return self.tree.query(self.source_X[seed, :], k=k, n_jobs=-1)
return self.tree.query(self.X[seed, :], k=k, n_jobs=-1)

def _H(self, k, seed):
H = 0
Expand Down
2 changes: 1 addition & 1 deletion nd_boundary_plot
Submodule nd_boundary_plot updated 1 files
+1 −1 LICENSE

0 comments on commit ffdde5b

Please sign in to comment.