add kd tree class

add mesh decimation algorithm
This commit is contained in:
wmayer
2017-11-14 11:37:30 +01:00
parent 0b33f977f7
commit 04ea295280
35 changed files with 5546 additions and 27 deletions

View File

@@ -0,0 +1,19 @@
find_package (SWIG REQUIRED)
include (${SWIG_USE_FILE})
find_package (PythonLibs)
include_directories (${PYTHON_INCLUDE_PATH})
include_directories (${CMAKE_CURRENT_SOURCE_DIR})
include_directories (${CMAKE_CURRENT_SOURCE_DIR}/..)
# Build the _kdtree python module
set_source_files_properties (py-kdtree.i PROPERTIES CPLUSPLUS ON)
swig_add_module (kdtree python py-kdtree.i)
swig_link_libraries (kdtree ${PYTHON_LIBRARIES})
# Copy the test file into the build dir
install (FILES py-kdtree_test.py DESTINATION ${CMAKE_INSTALL_PREFIX}/python)
install (FILES ${CMAKE_BINARY_DIR}/python-bindings/kdtree.py DESTINATION ${CMAKE_INSTALL_PREFIX}/python)
install (FILES ${CMAKE_BINARY_DIR}/python-bindings/_kdtree.so DESTINATION ${CMAKE_INSTALL_PREFIX}/python)

View File

@@ -0,0 +1,264 @@
#!/usr/bin/python
TREE_TYPES = [(dim, "int", "unsigned long long", "i", "L") for dim in range(2,7)] + \
[(dim, "float", "unsigned long long", "f", "L") for dim in range(2,7)]
def write_swig_file(tmpl_fn_name, swig_fn_name):
TMPL_SEPARATOR_DEF="""\
////////////////////////////////////////////////////////////////////////////////
// TYPE (%s) -> %s
////////////////////////////////////////////////////////////////////////////////
"""
TMPL_SEPARATOR=[]
TMPL_RECORD_DEF="""\
#define RECORD_%i%s%s record_t<%i, %s, %s> // cf. py-kdtree.hpp
"""
TMPL_RECORD=[]
TMPL_IN_CONV_RECORD_DEF="""\
%%typemap(in) RECORD_%i%s%s (RECORD_%i%s%s temp) {
if (PyTuple_Check($input)) {
if (PyArg_ParseTuple($input,"(%s)%s", %s, &temp.data)!=0)
{
$1 = temp;
} else {
PyErr_SetString(PyExc_TypeError,"tuple must have %i elements: (%i dim %s vector, %s value)");
return NULL;
}
} else {
PyErr_SetString(PyExc_TypeError,"expected a tuple.");
return NULL;
}
}
"""
TMPL_IN_CONV_RECORD=[]
TMPL_IN_CONV_POINT_DEF="""\
%%typemap(in) RECORD_%i%s%s::point_t (RECORD_%i%s%s::point_t point) {
if (PyTuple_Check($input)) {
if (PyArg_ParseTuple($input,"%s", %s)!=0)
{
$1 = point;
} else {
PyErr_SetString(PyExc_TypeError,"tuple must contain %i ints");
return NULL;
}
} else {
PyErr_SetString(PyExc_TypeError,"expected a tuple.");
return NULL;
}
}
"""
TMPL_IN_CONV_POINT=[]
TMPL_OUT_CONV_POINT_DEF="""\
%%typemap(out) RECORD_%i%s%s * {
RECORD_%i%s%s * r = $1;
PyObject* py_result;
if (r != NULL) {
py_result = PyTuple_New(2);
if (py_result==NULL) {
PyErr_SetString(PyErr_Occurred(),"unable to create a tuple.");
return NULL;
}
if (PyTuple_SetItem(py_result, 0, Py_BuildValue("(%s)", %s))==-1) {
PyErr_SetString(PyErr_Occurred(),"(a) when setting element");
Py_DECREF(py_result);
return NULL;
}
if (PyTuple_SetItem(py_result, 1, Py_BuildValue("%s", r->data))==-1) {
PyErr_SetString(PyErr_Occurred(),"(b) when setting element");
Py_DECREF(py_result);
return NULL;
}
} else {
py_result = Py_BuildValue("");
}
$result = py_result;
}
"""
TMPL_OUT_CONV_POINT=[]
TMPL_OUT_CONV_RECORD_DEF="""\
%%typemap(out) std::vector<RECORD_%i%s%s >* {
std::vector<RECORD_%i%s%s >* v = $1;
PyObject* py_result = PyList_New(v->size());
if (py_result==NULL) {
PyErr_SetString(PyErr_Occurred(),"unable to create a list.");
return NULL;
}
std::vector<RECORD_%i%s%s >::const_iterator iter = v->begin();
for (size_t i=0; i<v->size(); i++, iter++) {
if (PyList_SetItem(py_result, i, Py_BuildValue("(%s)%s", %s, (*iter).data))==-1) {
PyErr_SetString(PyErr_Occurred(),"(c) when setting element");
Py_DECREF(py_result);
return NULL;
} else {
//std::cout << "successfully set element " << *iter << std::endl;
}
}
$result = py_result;
}
"""
TMPL_OUT_CONV_RECORD=[]
TMPL_PY_CLASS_DEF="""\
%%template () RECORD_%i%s%s;
%%template (KDTree_%i%s) PyKDTree<%i, %s, %s>;
"""
TMPL_PY_CLASS=[]
TYPE_DEFS = []
for t in TREE_TYPES:
dim, coord_t, data_t, py_coord_t, py_data_t = t
coord_t_short = "".join([_[0] for _ in coord_t.split(" ")])
data_t_short = "".join([_[0] for _ in data_t.split(" ")])
TMPL_SEPARATOR.append(TMPL_SEPARATOR_DEF%(",".join([coord_t for _ in range(dim)]), data_t))
TMPL_RECORD.append(TMPL_RECORD_DEF%(dim, coord_t_short, data_t_short, dim, coord_t, data_t))
TMPL_IN_CONV_RECORD.append(TMPL_IN_CONV_RECORD_DEF%\
(dim, coord_t_short, data_t_short,
dim, coord_t_short, data_t_short,
py_coord_t*dim, py_data_t, ",".join(["&temp.point[%i]"%i for i in range(dim)]),
dim, dim, coord_t, data_t)
)
TMPL_IN_CONV_POINT.append(TMPL_IN_CONV_POINT_DEF%\
(dim, coord_t_short, data_t_short,
dim, coord_t_short, data_t_short,
py_coord_t*dim, ",".join(["&point[%i]"%i for i in range(dim)]),
dim)
)
TMPL_OUT_CONV_RECORD.append(TMPL_OUT_CONV_RECORD_DEF%\
(dim, coord_t_short, data_t_short,
dim, coord_t_short, data_t_short,
dim, coord_t_short, data_t_short,
py_coord_t*dim, py_data_t, ",".join(["(*iter).point[%i]"%i for i in range(dim)]),
)
)
TMPL_OUT_CONV_POINT.append(TMPL_OUT_CONV_POINT_DEF%\
(dim, coord_t_short, data_t_short,
dim, coord_t_short, data_t_short,
py_coord_t*dim, ",".join(["r->point[%i]"%i for i in range(dim)]),
py_data_t)
)
TMPL_PY_CLASS.append(TMPL_PY_CLASS_DEF%\
(dim, coord_t_short, data_t_short,
dim, coord_t.capitalize(), dim, coord_t, data_t)
)
TMPL_BODY_LIST = []
for i in range(len(TREE_TYPES)):
TMPL_BODY_LIST.append(TMPL_SEPARATOR[i] + "\n" + \
TMPL_RECORD[i] + "\n" + \
TMPL_IN_CONV_POINT[i] + "\n" + \
TMPL_IN_CONV_RECORD[i] + "\n" + \
TMPL_OUT_CONV_POINT[i] + "\n" + \
TMPL_OUT_CONV_RECORD[i])
TMPL_BODY = "\n\n".join(TMPL_BODY_LIST)
# write swig file
i_content = open(tmpl_fn_name, "r").read()
i_content = i_content.replace("%%TMPL_BODY%%", TMPL_BODY).replace("%%TMPL_PY_CLASS_DEF%%", "\n".join(TMPL_PY_CLASS))
f=open(swig_fn_name, "w")
f.write(i_content)
f.close()
def write_hpp_file(tmpl_fn_name, hpp_fn_name):
TMPL_SEPARATOR_DEF="""\
////////////////////////////////////////////////////////////////////////////////
// Definition of (%s) with data type %s
////////////////////////////////////////////////////////////////////////////////
"""
TMPL_SEPARATOR=[]
TMPL_RECORD_DEF = """\
#define RECORD_%i%s%s record_t<%i, %s, %s>
#define KDTREE_TYPE_%i%s%s KDTree::KDTree<%i, RECORD_%i%s%s, std::pointer_to_binary_function<RECORD_%i%s%s,int,double> >
"""
TMPL_RECORD=[]
TMPL_OP_EQ_DEF = """\
inline bool operator==(RECORD_%i%s%s const& A, RECORD_%i%s%s const& B) {
return %s && A.data == B.data;
}
"""
TMPL_OP_EQ = []
TMPL_OP_OUT_DEF="""\
std::ostream& operator<<(std::ostream& out, RECORD_%i%s%s const& T)
{
return out << '(' << %s << '|' << T.data << ')';
}
"""
TMPL_OP_OUT = []
TYPE_DEFS = []
for t in TREE_TYPES:
dim, coord_t, data_t, py_coord_t, py_data_t = t
coord_t_short = "".join([_[0] for _ in coord_t.split(" ")])
data_t_short = "".join([_[0] for _ in data_t.split(" ")])
TMPL_SEPARATOR.append(TMPL_SEPARATOR_DEF%(",".join([coord_t for _ in range(dim)]), data_t))
TMPL_RECORD.append(TMPL_RECORD_DEF%(dim, coord_t_short, data_t_short,
dim, coord_t, data_t,
dim, coord_t_short, data_t_short,
dim,
dim, coord_t_short, data_t_short,
dim, coord_t_short, data_t_short)
)
TMPL_OP_EQ.append(TMPL_OP_EQ_DEF%(dim, coord_t_short, data_t_short,
dim, coord_t_short, data_t_short,
" && ".join(["A.point[%i] == B.point[%i]"%(i,i) for i in range(dim)])))
TMPL_OP_OUT.append(TMPL_OP_OUT_DEF%(dim, coord_t_short, data_t_short,
" << ',' << ".join(["T.point[%i]"%i for i in range(dim)])))
TMPL_BODY_LIST = []
for i in range(len(TREE_TYPES)):
TMPL_BODY_LIST.append(TMPL_SEPARATOR[i] + "\n" + TMPL_RECORD[i] + "\n" + TMPL_OP_EQ[i] + "\n" + TMPL_OP_OUT[i])
TMPL_BODY = "\n\n".join(TMPL_BODY_LIST)
# write hpp file
hpp_content = open(tmpl_fn_name, "r").read()
hpp_content = hpp_content.replace("%%TMPL_HPP_DEFS%%", TMPL_BODY)
f=open(hpp_fn_name, "w")
f.write(hpp_content)
f.close()
if __name__=="__main__":
write_swig_file("py-kdtree.i.tmpl", "py-kdtree.i")
write_hpp_file("py-kdtree.hpp.tmpl", "py-kdtree.hpp")

View File

@@ -0,0 +1,145 @@
/** \file
* Provides a Python interface for the libkdtree++.
*
* \author Willi Richert <w.richert@gmx.net>
*
*
* This defines a proxy to a (int, int) -> long long KD-Tree. The long
* long is needed to save a reference to Python's object id(). Thereby,
* you can associate Python objects with 2D integer points.
*
* If you want to customize it you can adapt the following:
*
* * Dimension of the KD-Tree point vector.
* * DIM: number of dimensions.
* * operator==() and operator<<(): adapt to the number of comparisons
* * py-kdtree.i: Add or adapt all usages of PyArg_ParseTuple() to reflect the
* number of dimensions.
* * adapt query_records in find_nearest() and count_within_range()
* * Type of points.
* * coord_t: If you want to have e.g. floats you have
* to adapt all usages of PyArg_ParseTuple(): Change "i" to "f" e.g.
* * Type of associated data.
* * data_t: currently unsigned long long, which is "L" in py-kdtree.i
* * PyArg_ParseTuple() has to be changed to reflect changes in data_t
*
*/
#ifndef _PY_KDTREE_H_
#define _PY_KDTREE_H_
#include <kdtree++/kdtree.hpp>
#include <iostream>
#include <vector>
#include <limits>
template <size_t DIM, typename COORD_T, typename DATA_T >
struct record_t {
static const size_t dim = DIM;
typedef COORD_T coord_t;
typedef DATA_T data_t;
typedef coord_t point_t[dim];
inline coord_t operator[](size_t const N) const { return point[N]; }
point_t point;
data_t data;
};
typedef double RANGE_T;
%%TMPL_HPP_DEFS%%
////////////////////////////////////////////////////////////////////////////////
// END OF TYPE SPECIFIC DEFINITIONS
////////////////////////////////////////////////////////////////////////////////
template <class RECORD_T>
inline double tac(RECORD_T r, int k) { return r[k]; }
template <size_t DIM, typename COORD_T, typename DATA_T >
class PyKDTree {
public:
typedef record_t<DIM, COORD_T, DATA_T> RECORD_T;
typedef KDTree::KDTree<DIM, RECORD_T, std::pointer_to_binary_function<RECORD_T,int,double> > TREE_T;
TREE_T tree;
PyKDTree() : tree(std::ptr_fun(tac<RECORD_T>)) { };
void add(RECORD_T T) { tree.insert(T); };
/**
Exact erase.
*/
bool remove(RECORD_T T) {
bool removed = false;
typename TREE_T::const_iterator it = tree.find_exact(T);
if (it!=tree.end()) {
tree.erase_exact(T);
removed = true;
}
return removed;
};
int size(void) { return tree.size(); }
void optimize(void) { tree.optimise(); }
RECORD_T* find_exact(RECORD_T T) {
RECORD_T* found = NULL;
typename TREE_T::const_iterator it = tree.find_exact(T);
if (it!=tree.end())
found = new RECORD_T(*it);
return found;
}
size_t count_within_range(typename RECORD_T::point_t T, RANGE_T range) {
RECORD_T query_record;
memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
return tree.count_within_range(query_record, range);
}
std::vector<RECORD_T >* find_within_range(typename RECORD_T::point_t T, RANGE_T range) {
RECORD_T query_record;
memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
std::vector<RECORD_T> *v = new std::vector<RECORD_T>;
tree.find_within_range(query_record, range, std::back_inserter(*v));
return v;
}
RECORD_T* find_nearest (typename RECORD_T::point_t T) {
RECORD_T* found = NULL;
RECORD_T query_record;
memcpy(query_record.point, T, sizeof(COORD_T)*DIM);
std::pair<typename TREE_T::const_iterator, typename TREE_T::distance_type> best =
tree.find_nearest(query_record, std::numeric_limits<typename TREE_T::distance_type>::max());
if (best.first!=tree.end()) {
found = new RECORD_T(*best.first);
}
return found;
}
std::vector<RECORD_T >* get_all() {
std::vector<RECORD_T>* v = new std::vector<RECORD_T>;
for (typename TREE_T::const_iterator iter=tree.begin(); iter!=tree.end(); ++iter) {
v->push_back(*iter);
}
return v;
}
size_t __len__() { return tree.size(); }
};
#endif //_PY_KDTREE_H_

View File

@@ -0,0 +1,27 @@
/** \file
*
* Provides a Python interface for the libkdtree++.
*
* \author Willi Richert <w.richert@gmx.net>
*
*/
%module kdtree
%{
#define SWIG_FILE_WITH_INIT
#include "py-kdtree.hpp"
%}
%ignore record_t::operator[];
%ignore operator==;
%ignore operator<<;
%ignore KDTree::KDTree::operator=;
%ignore tac;
%%TMPL_BODY%%
%include "py-kdtree.hpp"
%%TMPL_PY_CLASS_DEF%%

View File

@@ -0,0 +1,96 @@
#define KDTREE_DEFINE_OSTREAM_OPERATORS
#include <kdtree++/kdtree.hpp>
#include <iostream>
#include <vector>
#include "py-kdtree.hpp"
int main()
{
KDTree_2Int t;
RECORD_2il c0 = { {5, 4} }; t.add(c0);
RECORD_2il c1 = { {4, 2} }; t.add(c1);
RECORD_2il c2 = { {7, 6} }; t.add(c2);
RECORD_2il c3 = { {2, 2} }; t.add(c3);
RECORD_2il c4 = { {8, 0} }; t.add(c4);
RECORD_2il c5 = { {5, 7} }; t.add(c5);
RECORD_2il c6 = { {3, 3} }; t.add(c6);
RECORD_2il c7 = { {9, 7} }; t.add(c7);
RECORD_2il c8 = { {2, 2} }; t.add(c8);
RECORD_2il c9 = { {2, 0} }; t.add(c9);
std::cout << t.tree << std::endl;
t.remove(c0);
t.remove(c1);
t.remove(c3);
t.remove(c5);
t.optimize();
std::cout << std::endl << t.tree << std::endl;
int i=0;
for (KDTREE_TYPE_2il::const_iterator iter=t.tree.begin(); iter!=t.tree.end(); ++iter, ++i);
std::cout << "iterator walked through " << i << " nodes in total" << std::endl;
if (i!=6)
{
std::cerr << "Error: does not tally with the expected number of nodes (6)" << std::endl;
return 1;
}
i=0;
for (KDTREE_TYPE_2il::const_reverse_iterator iter=t.tree.rbegin(); iter!=t.tree.rend(); ++iter, ++i);
std::cout << "reverse_iterator walked through " << i << " nodes in total" << std::endl;
if (i!=6)
{
std::cerr << "Error: does not tally with the expected number of nodes (6)" << std::endl;
return 1;
}
RECORD_2il::point_t s = {5, 4};
std::vector<RECORD_2il> v;
unsigned int const RANGE = 3;
size_t count = t.count_within_range(s, RANGE);
std::cout << "counted " << count
<< " nodes within range " << RANGE << " of " << s << ".\n";
v = t.find_within_range(s, RANGE);
std::cout << "found " << v.size() << " nodes within range " << RANGE
<< " of " << s << ":\n";
std::vector<RECORD_2il>::const_iterator ci = v.begin();
for (; ci != v.end(); ++ci)
std::cout << *ci << " ";
std::cout << "\n" << std::endl;
std::cout << "Nearest to " << s << ": " <<
t.find_nearest(s) << std::endl;
RECORD_2il::point_t s2 = { 10, 10};
std::cout << "Nearest to " << s2 << ": " <<
t.find_nearest(s2) << std::endl;
std::cout << std::endl;
std::cout << t.tree << std::endl;
return 0;
}
/* COPYRIGHT --
*
* This file is part of libkdtree++, a C++ template KD-Tree sorting container.
* libkdtree++ is (c) 2004-2007 Martin F. Krafft <libkdtree@pobox.madduck.net>
* and Sylvain Bougerel <sylvain.bougerel.devel@gmail.com> distributed under the
* terms of the Artistic License 2.0. See the ./COPYING file in the source tree
* root for more information.
*
* THIS PACKAGE IS PROVIDED "AS IS" AND WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED WARRANTIES
* OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE.
*/

View File

@@ -0,0 +1,400 @@
#
# $Id: py-kdtree_test.py 2268 2008-08-20 10:08:58Z richert $
#
import unittest
from kdtree import KDTree_2Int, KDTree_4Int, KDTree_3Float, KDTree_4Float, KDTree_6Float
class KDTree_2IntTestCase(unittest.TestCase):
def test_empty(self):
nn = KDTree_2Int()
self.assertEqual(0, nn.size())
actual = nn.find_nearest((2,3))
self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
def test_get_all(self):
nn = KDTree_2Int()
o1 = object()
nn.add(((1,1), id(o1)))
o2 = object()
nn.add(((10,10), id(o2)))
o3 = object()
nn.add(((11,11), id(o3)))
self.assertEqual([((1,1), id(o1)), ((10,10), id(o2)), ((11,11), id(o3))], nn.get_all())
self.assertEqual(3, len(nn))
nn.remove(((10,10), id(o2)))
self.assertEqual(2, len(nn))
self.assertEqual([((1,1), id(o1)), ((11,11), id(o3))], nn.get_all())
def test_nearest(self):
nn = KDTree_2Int()
nn_id = {}
o1 = object()
nn.add(((1,1), id(o1)))
nn_id[id(o1)] = o1
o2 = object()
nn.add(((10,10), id(o2)))
nn_id[id(o2)] = o2
expected = o1
actual = nn.find_nearest((2,2))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
expected = o2
actual = nn.find_nearest((6, 6))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
def test_find_within_range(self):
nn = KDTree_6Float()
nn_id = {}
o1 = object()
nn.add(((1,1,0,0,0,0), id(o1)))
nn_id[id(o1)] = o1
o2 = object()
nn.add(((10,10,0,0,0,0), id(o2)))
nn_id[id(o2)] = o2
o3 = object()
nn.add(((4.1, 4.1,0,0,0,0), id(o3)))
nn_id[id(o3)] = o3
expected = set([long(id(o1)), long(id(o3))])
actual = set([ident
for _coord, ident
in nn.find_within_range((2.1,2.1,0,0,0,0), 3.9)])
self.assertTrue(expected==actual, "%s != %s"%(str(expected), str(actual)))
def test_remove(self):
class C:
def __init__(self, i):
self.i = i
self.next = None
nn = KDTree_2Int()
k1, o1 = (1,1), C(7)
self.assertFalse(nn.remove((k1, id(o1))), "This cannot be removed!")
nn.add((k1, id(o1)))
k2, o2 = (1,1), C(7)
nn.add((k2, id(o2)))
self.assertEqual(2, nn.size())
self.assertTrue(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
self.assertFalse(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
nearest = nn.find_nearest(k1)
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
def test_count_within_range(self):
nn = KDTree_2Int()
for p in [(0,0), (1,0), (0,1), (1,1)]:
nn.add((p, id(p)))
res = nn.count_within_range((0,0), 1.0)
self.assertEqual(3, res, "Counted %i points instead of %i"%(res, 3))
res = nn.count_within_range((0,0), 1.9)
self.assertEqual(4, res, "Counted %i points instead of %i"%(res, 4))
class KDTree_4IntTestCase(unittest.TestCase):
def test_empty(self):
nn = KDTree_4Int()
self.assertEqual(0, nn.size())
actual = nn.find_nearest((0,0,2,3))
self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
def test_get_all(self):
nn = KDTree_4Int()
o1 = object()
nn.add(((0,0,1,1), id(o1)))
o2 = object()
nn.add(((0,0,10,10), id(o2)))
o3 = object()
nn.add(((0,0,11,11), id(o3)))
self.assertEqual([((0,0,1,1), id(o1)), ((0,0,10,10), id(o2)), ((0,0,11,11), id(o3))], nn.get_all())
self.assertEqual(3, len(nn))
nn.remove(((0,0,10,10), id(o2)))
self.assertEqual(2, len(nn))
self.assertEqual([((0,0,1,1), id(o1)), ((0,0,11,11), id(o3))], nn.get_all())
def test_nearest(self):
nn = KDTree_4Int()
nn_id = {}
o1 = object()
nn.add(((0,0,1,1), id(o1)))
nn_id[id(o1)] = o1
o2 = object()
nn.add(((0,0,10,10), id(o2)))
nn_id[id(o2)] = o2
expected = o1
actual = nn.find_nearest((0,0,2,2))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
expected = o2
actual = nn.find_nearest((0,0,6,6))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
def test_remove(self):
class C:
def __init__(self, i):
self.i = i
self.next = None
nn = KDTree_4Int()
k1, o1 = (0,0,1,1), C(7)
self.assertFalse(nn.remove((k1, id(o1))), "This cannot be removed!")
nn.add((k1, id(o1)))
k2, o2 = (0,0,1,1), C(7)
nn.add((k2, id(o2)))
self.assertEqual(2, nn.size())
self.assertTrue(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
self.assertFalse(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
nearest = nn.find_nearest(k1)
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
class KDTree_4FloatTestCase(unittest.TestCase):
def test_empty(self):
nn = KDTree_4Float()
self.assertEqual(0, nn.size())
actual = nn.find_nearest((0,0,2,3))
self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
def test_get_all(self):
nn = KDTree_4Int()
o1 = object()
nn.add(((0,0,1,1), id(o1)))
o2 = object()
nn.add(((0,0,10,10), id(o2)))
o3 = object()
nn.add(((0,0,11,11), id(o3)))
self.assertEqual([((0,0,1,1), id(o1)), ((0,0,10,10), id(o2)), ((0,0,11,11), id(o3))], nn.get_all())
self.assertEqual(3, len(nn))
nn.remove(((0,0,10,10), id(o2)))
self.assertEqual(2, len(nn))
self.assertEqual([((0,0,1,1), id(o1)), ((0,0,11,11), id(o3))], nn.get_all())
def test_nearest(self):
nn = KDTree_4Int()
nn_id = {}
o1 = object()
nn.add(((0,0,1,1), id(o1)))
nn_id[id(o1)] = o1
o2 = object()
nn.add(((0,0,10,10), id(o2)))
nn_id[id(o2)] = o2
expected = o1
actual = nn.find_nearest((0,0,2,2))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
expected = o2
actual = nn.find_nearest((0,0,6,6))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
def test_remove(self):
class C:
def __init__(self, i):
self.i = i
self.next = None
nn = KDTree_4Int()
k1, o1 = (0,0,1,1), C(7)
self.assertFalse(nn.remove((k1, id(o1))), "This cannot be removed!")
nn.add((k1, id(o1)))
k2, o2 = (0,0,1,1), C(7)
nn.add((k2, id(o2)))
self.assertEqual(2, nn.size())
self.assertTrue(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
self.assertFalse(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
nearest = nn.find_nearest(k1)
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
class KDTree_3FloatTestCase(unittest.TestCase):
def test_empty(self):
nn = KDTree_3Float()
self.assertEqual(0, nn.size())
actual = nn.find_nearest((2,3,0))
self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
def test_get_all(self):
nn = KDTree_3Float()
o1 = object()
nn.add(((1,1,0), id(o1)))
o2 = object()
nn.add(((10,10,0), id(o2)))
o3 = object()
nn.add(((11,11,0), id(o3)))
self.assertEqual([((1,1,0), id(o1)), ((10,10,0), id(o2)), ((11,11,0), id(o3))], nn.get_all())
self.assertEqual(3, len(nn))
nn.remove(((10,10,0), id(o2)))
self.assertEqual(2, len(nn))
self.assertEqual([((1,1,0), id(o1)), ((11,11,0), id(o3))], nn.get_all())
def test_nearest(self):
nn = KDTree_3Float()
nn_id = {}
o1 = object()
nn.add(((1,1,0), id(o1)))
nn_id[id(o1)] = o1
o2 = object()
nn.add(((10,10,0), id(o2)))
nn_id[id(o2)] = o2
o3 = object()
nn.add(((4.1, 4.1,0), id(o3)))
nn_id[id(o3)] = o3
expected = o3
actual = nn.find_nearest((2.9,2.9,0))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
expected = o3
actual = nn.find_nearest((6, 6,0))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
def test_remove(self):
class C:
def __init__(self, i):
self.i = i
self.next = None
nn = KDTree_3Float()
k1, o1 = (1.1,1.1,0), C(7)
self.assertFalse(nn.remove((k1, id(o1))), "This cannot be removed!")
nn.add((k1, id(o1)))
k2, o2 = (1.1,1.1,0), C(7)
nn.add((k2, id(o2)))
self.assertEqual(2, nn.size())
self.assertTrue(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
self.assertFalse(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
nearest = nn.find_nearest(k1)
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
class KDTree_6FloatTestCase(unittest.TestCase):
def test_empty(self):
nn = KDTree_6Float()
self.assertEqual(0, nn.size())
actual = nn.find_nearest((2,3,0,0,0,0))
self.assertTrue(None==actual, "%s != %s"%(str(None), str(actual)))
def test_get_all(self):
nn = KDTree_6Float()
o1 = object()
nn.add(((1,1,0,0,0,0), id(o1)))
o2 = object()
nn.add(((10,10,0,0,0,0), id(o2)))
o3 = object()
nn.add(((11,11,0,0,0,0), id(o3)))
self.assertEqual([((1,1,0,0,0,0), id(o1)), ((10,10,0,0,0,0), id(o2)), ((11,11,0,0,0,0 ), id(o3))], nn.get_all())
self.assertEqual(3, len(nn))
nn.remove(((10,10,0,0,0,0), id(o2)))
self.assertEqual(2, len(nn))
self.assertEqual([((1,1,0,0,0,0), id(o1)), ((11,11,0,0,0,0), id(o3))], nn.get_all())
def test_nearest(self):
nn = KDTree_6Float()
nn_id = {}
o1 = object()
nn.add(((1,1,0,0,0,0), id(o1)))
nn_id[id(o1)] = o1
o2 = object()
nn.add(((10,10,0,0,0,0), id(o2)))
nn_id[id(o2)] = o2
o3 = object()
nn.add(((4.1, 4.1,0,0,0,0), id(o3)))
nn_id[id(o3)] = o3
expected = o3
actual = nn.find_nearest((2.9,2.9,0,0,0,0))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
expected = o3
actual = nn.find_nearest((6, 6,0,0,0,0))[1]
self.assertTrue(expected==nn_id[actual], "%s != %s"%(str(expected), str(nn_id[actual])))
def test_remove(self):
class C:
def __init__(self, i):
self.i = i
self.next = None
nn = KDTree_6Float()
k1, o1 = (1.1,1.1,0,0,0,0), C(7)
self.assertFalse(nn.remove((k1, id(o1))), "This cannot be removed!")
nn.add((k1, id(o1)))
k2, o2 = (1.1,1.1,0,0,0,0), C(7)
nn.add((k2, id(o2)))
self.assertEqual(2, nn.size())
self.assertTrue(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
self.assertFalse(nn.remove((k2, id(o2))))
self.assertEqual(1, nn.size())
nearest = nn.find_nearest(k1)
self.assertTrue(nearest[1] == id(o1), "%s != %s"%(nearest[1], o1))
#self.assertTrue(nearest[1] is o1, "%s,%s is not %s"%(str(nearest[0]), str(nearest[1]), str((k1,id(o1)))))
def suite():
return unittest.defaultTestLoader.loadTestsFromModule(sys.modules.get(__name__))
if __name__ == '__main__':
unittest.main()