Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusGitAccount committed Feb 22, 2019
1 parent 28384ef commit 6095ecc
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 22 deletions.
70 changes: 70 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File (Integrated Terminal)",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
},
{
"name": "Python: Remote Attach",
"type": "python",
"request": "attach",
"port": 5678,
"host": "localhost",
"pathMappings": [
{
"localRoot": "${workspaceFolder}",
"remoteRoot": "."
}
]
},
{
"name": "Python: Module",
"type": "python",
"request": "launch",
"module": "enter-your-module-name-here",
"console": "integratedTerminal"
},
{
"name": "Python: Django",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/manage.py",
"console": "integratedTerminal",
"args": [
"runserver",
"--noreload",
"--nothreading"
],
"django": true
},
{
"name": "Python: Flask",
"type": "python",
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "app.py"
},
"args": [
"run",
"--no-debugger",
"--no-reload"
],
"jinja": true
},
{
"name": "Python: Current File (External Terminal)",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "externalTerminal"
}
]
}
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.pythonPath": "C:\\Users\\pop_m\\Anaconda3\\python.exe"
}
Binary file added __pycache__/ball_tree.cpython-36.pyc
Binary file not shown.
Binary file added __pycache__/heap.cpython-36.pyc
Binary file not shown.
38 changes: 22 additions & 16 deletions ball_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import matplotlib.pyplot as plt

from functools import cmp_to_key
from heap import Heap

# Introselect is a hybrid algorithm, combining both quickselect
# and median of medians
Expand All @@ -18,16 +20,19 @@ class BallTree:
def __init__(self, points: [[float]], metric):
if points is None:
raise ValueError('Dataset not provided.')
self.is_leaf = True
self.center = self.radius = None
self.dimension = None
self.left = None
self.right = None
self.points = np.array(points, copy=True)

if len(points) <= 1:
if len(points) == 1:
self.center = self.points[0]
return None

mid = len(self.points) >> 1
self.is_leaf = False
# Computing the dimension of the greatest spread, i.e.
# the dimension of points from the dataset that
# spread over the largest interval
Expand All @@ -39,8 +44,14 @@ def __init__(self, points: [[float]], metric):
center_index = introselect_by_dimension(points, mid, self.dimension)
self.center = self.points[center_index]
self.radius = np.apply_along_axis(lambda point: metric(self.center, point), 1, self.points).max(0)
self.left = BallTree(self.points[:mid], metric)
self.right = BallTree(self.points[mid:], metric)

left = self.points[:mid]
right = self.points[mid:]

if len(left) != 0:
self.left = BallTree(left, metric)
if len(right) != 0:
self.right = BallTree(right, metric)

def plot(self, plt):
if len(self.points) > 1:
Expand All @@ -58,18 +69,13 @@ def traverse_tree(tree_node, plt=None):
traverse_tree(tree_node.left, plt)
traverse_tree(tree_node.right, plt)

if __name__ == '__main__':
    # Demo: scatter-plot a random 2-D dataset, then build a ball tree over it.
    # Guarded so that importing this module (e.g. `from ball_tree import
    # BallTree`) neither opens a plot window nor builds a throwaway tree.
    points = np.random.rand(100, 2) * 10000

    plt.rcParams["font.size"] = 1
    x = points[:, 0]
    y = points[:, 1]

    # A stray `np.random.randint()` call used to sit here. `low` is a
    # required argument, so the bare call raised TypeError before anything
    # was plotted; its result was unused, so it is simply dropped.
    plt.scatter(x, y)
    plt.show()

    tree = BallTree(points, euclid_metric)

def _knn_update(node, target, metric):
    """Placeholder helper for the k-nearest-neighbour search.

    Not implemented: the body is a bare ``pass``, so the arguments are
    ignored and every call returns ``None``.
    """
    pass

def _knn_prepare(node, target, k, metric):
    """Placeholder helper for the k-nearest-neighbour search.

    Not implemented: the body is a bare ``pass``, so the arguments are
    ignored and every call returns ``None``.
    """
    pass

def knn_search(node, target, k, metric, queue):
    """Collect the ``k`` points nearest to ``target`` from the tree at ``node``.

    The previous body could never work: it called ``metric()`` with no
    arguments, read ``node.pivot`` (tree nodes expose ``center``), and never
    stored a result.

    Args:
        node: A ball-tree node exposing ``is_leaf``, ``center``, ``radius``,
            ``left`` and ``right``; ``None`` is accepted and ignored.
        target: The query point.
        k: Maximum number of neighbours to keep.
        metric: Callable ``metric(a, b)`` returning a distance. Pruning
            relies on the triangle inequality, so it must be a true metric.
        queue: Bounded priority queue of ``(point, distance)`` entries with
            ``push``, ``pop``, ``len()`` and indexing.
            NOTE(review): assumes ``queue[0]`` is the entry with the largest
            distance and ``pop()`` removes it (a max-heap on ``entry[1]``)
            -- confirm against the Heap implementation.

    Returns:
        The same ``queue`` object, holding at most ``k`` entries.
    """
    if node is None or node.center is None or k <= 0:
        return queue

    center_distance = metric(target, node.center)

    if node.is_leaf:
        # A leaf holds a single point, which is stored as its center.
        if len(queue) < k or center_distance < queue[0][1]:
            queue.push((node.center, center_distance))
            if len(queue) > k:
                queue.pop()
        return queue

    # Triangle inequality: no point inside this ball can be closer to the
    # target than center_distance - radius. Once k candidates are known and
    # that bound cannot beat the current worst one, skip the whole subtree.
    if (len(queue) >= k and node.radius is not None
            and center_distance - node.radius >= queue[0][1]):
        return queue

    children = [child for child in (node.left, node.right)
                if child is not None and child.center is not None]
    # Visit the nearer child first so the bound tightens as early as possible.
    children.sort(key=lambda child: metric(target, child.center))
    for child in children:
        knn_search(child, target, k, metric, queue)
    return queue
36 changes: 30 additions & 6 deletions heap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ def __len__(self):
def __repr__(self):
return 'Heap: %s' % self.container

def __getitem__(self, index):
return self.container[index]

def __iter__(self):
return (item for item in self.container)

def _parent(self, index):
return (index - 1) >> 1

Expand Down Expand Up @@ -68,17 +74,35 @@ def pop(self):
return last
return None

def make_heap(self):
    """Re-establish heap order over the whole container, in place.

    Calls ``self._heapify`` on every index from the middle of the container
    down to the root. The earlier loop stopped at index 1 (``index > 0``),
    so the root was never heapified.
    """
    # range(..., -1, -1) makes index 0 the final index visited.
    for index in range(len(self.container) >> 1, -1, -1):
        self._heapify(index)
def is_empty(self):
    """Return True when the heap holds no items, False otherwise."""
    return not len(self)

@staticmethod
def make_heap(array, cmp):
    """Heapify ``array`` in place and return the Heap wrapping it.

    Args:
        array: Mutable sequence to reorder. It becomes the heap's container
            directly (no copy), so the caller's object is rearranged.
        cmp: Ordering predicate passed to the ``Heap`` constructor.

    Returns:
        The new ``Heap``. Previously the heap was built and then discarded
        (the function returned ``None``), leaving callers only the in-place
        side effect on ``array``.
    """
    heap = Heap(cmp)
    heap.container = array
    # Heapify from the middle of the array down to the root, inclusive.
    for index in range(len(array) >> 1, -1, -1):
        heap._heapify(index)
    return heap

if __name__ == '__main__':
    # Manual smoke test for Heap: push/pop, iteration, indexing, make_heap.

    def cmp(parent, child):
        """Ordering predicate: a parent must compare less than its child."""
        return parent < child

    # The heap is built once; an earlier duplicate construction that was
    # immediately overwritten has been removed.
    heap = Heap(cmp)
    for nbr in randint(0, 30, 5):
        heap.push(nbr)
    print(repr(heap))
    heap.pop()
    print(repr(heap))

    # Exercise __iter__ ...
    for item in heap:
        print(item, end=' ')
    print('')
    # ... and __getitem__ (the index loop is deliberate here).
    for i in range(len(heap)):
        print(heap[i], end=' ')
    print('')

    # make_heap reorders `arr` in place, so printing it before and after
    # shows the effect.
    arr = randint(0, 25, 15)
    print(arr)
    Heap.make_heap(arr, cmp)
    print(arr)
46 changes: 46 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np
import matplotlib.pyplot as plt

from functools import cmp_to_key
from heap import Heap
from ball_tree import BallTree, euclid_metric

if __name__ == '__main__':
    # Brute-force k-nearest-neighbour run: keeps the k closest points in a
    # bounded heap and plots them.

    # pyplot.title is a function. Assigning a string to it replaced the
    # function on the module and never titled the figure.
    plt.title('KNN search.')
    points = np.random.randint(100, size=(20, 2))

    plt.rcParams["font.size"] = 1
    x = points[:, 0]
    y = points[:, 1]
    plt.scatter(x, y)

    tree = BallTree(points, euclid_metric)
    point = np.random.randint(0, 100, 2)
    x_, y_ = point
    # Marker only in the format string: 'bo' also requested blue, which
    # conflicted with the explicit color keyword.
    plt.plot(x_, y_, 'o', color='red')

    distances = sorted(euclid_metric(point, candidate) for candidate in points)
    known_distances = set(distances)

    k = 10

    def cmp(a, b):
        """Keep the entry with the larger distance on top."""
        return a[1] > b[1]

    heap = Heap(cmp)
    for candidate in points:
        distance = euclid_metric(point, candidate)
        # Only admit a candidate while the heap is not full or when it beats
        # the entry at heap[0]; trim back to k entries after each push.
        if len(heap) < k or distance < heap[0][1]:
            heap.push((candidate, distance))
            if len(heap) > k:
                heap.pop()

    for candidate in heap:
        print(candidate)
        x_, y_ = candidate[0]
        plt.plot(x_, y_, 'o', color='pink')

    print(distances[:k])
    # Avoid shadowing the builtin `all`; use it for the membership check.
    all_found = all(candidate[1] in known_distances for candidate in heap)
    print('All? %s' % all_found)
    plt.show()

0 comments on commit 6095ecc

Please sign in to comment.