本文共 6823 字,大约阅读时间需要 22 分钟。
本文这套代码是开源的，来自黎嘉信老师的 GitHub。
import random
import math
from typing import Optional

import numpy as np

from result_set import KNNResultSet, RadiusNNResultSet


class Node:
    """One node of a binary search tree ordered by ``key``."""

    def __init__(self, key, value=-1):
        self.left = None
        self.right = None
        self.key = key
        # Payload stored next to the key, e.g. the point's original index.
        self.value = value

    def __str__(self):
        return "key: %s, value: %s" % (str(self.key), str(self.value))


def insert(root: Optional[Node], key, value=-1) -> Node:
    """Insert ``key`` (with payload ``value``) into the tree rooted at ``root``.

    Smaller keys go to the left subtree, larger keys to the right. A key that
    is already present is ignored, so the first payload stored for it wins.

    Returns the root of the tree (a new node when ``root`` is ``None``).
    """
    if root is None:
        return Node(key, value)

    # Walk down iteratively instead of recursing once per level.
    node = root
    while True:
        if key < node.key:
            if node.left is None:
                node.left = Node(key, value)
                break
            node = node.left
        elif key > node.key:
            if node.right is None:
                node.right = Node(key, value)
                break
            node = node.right
        else:
            # Key already exists in the tree: do not insert it again.
            break
    return root


# The three classic depth-first traversals; each one prints every node.

def inorder(root: Optional[Node]) -> None:
    """Print the nodes in in-order (left, root, right)."""
    if root is not None:
        inorder(root.left)
        print(root)
        inorder(root.right)


def preorder(root: Optional[Node]) -> None:
    """Print the nodes in pre-order (root, left, right)."""
    if root is not None:
        print(root)
        preorder(root.left)
        preorder(root.right)


def postorder(root: Optional[Node]) -> None:
    """Print the nodes in post-order (left, right, root)."""
    if root is not None:
        postorder(root.left)
        postorder(root.right)
        print(root)


def knn_search(root: Optional[Node], result_set: KNNResultSet, key) -> bool:
    """Collect the nearest neighbours of ``key`` into ``result_set``.

    Every visited node is offered to ``result_set.add_point`` with its
    absolute key distance and its ``value``. The subtree on the far side of a
    node is only visited while that node's distance is still smaller than
    ``result_set.worstDist()``.

    Returns ``True`` when the search stopped early because
    ``result_set.worstDist()`` dropped to 0, otherwise ``False``.
    """
    if root is None:
        return False

    # Offer the current node itself as a candidate.
    split_dist = math.fabs(root.key - key)
    result_set.add_point(split_dist, root.value)
    if result_set.worstDist() == 0:
        return True

    # Descend first into the side of the tree where ``key`` would live.
    if root.key >= key:
        near, far = root.left, root.right
    else:
        near, far = root.right, root.left

    if knn_search(near, result_set, key):
        return True
    # Every key on the far side is farther away than this node, so the far
    # side is only worth visiting if this node beats the current worst.
    if split_dist < result_set.worstDist():
        return knn_search(far, result_set, key)
    return False


def radius_search(root: Optional[Node], result_set: RadiusNNResultSet, key) -> bool:
    """Collect the neighbours of ``key`` within the search radius.

    Every visited node is offered to ``result_set.add_point`` with its
    absolute key distance and its ``value``. The subtree on the far side of a
    node is only visited while that node's distance is smaller than
    ``result_set.worstDist()``.

    Returns the result of the last recursive call; ``False`` once a branch is
    exhausted or pruned.
    """
    if root is None:
        return False

    # Offer the current node itself as a candidate.
    split_dist = math.fabs(root.key - key)
    result_set.add_point(split_dist, root.value)

    # Descend first into the side of the tree where ``key`` would live.
    if root.key >= key:
        near, far = root.left, root.right
    else:
        near, far = root.right, root.left

    if radius_search(near, result_set, key):
        return True
    if split_dist < result_set.worstDist():
        return radius_search(far, result_set, key)
    return False


def search_recursively(root: Optional[Node], key) -> Optional[Node]:
    """Exact-key lookup, recursive version.

    Returns the node whose key equals ``key``, or ``None`` if there is none.
    """
    if root is None or root.key == key:
        return root
    if key < root.key:
        return search_recursively(root.left, key)
    if key > root.key:
        return search_recursively(root.right, key)
    # Key compares neither smaller, larger nor equal: nothing to find.
    return None


def search_iterative(root: Optional[Node], key) -> Optional[Node]:
    """Exact-key lookup, loop version.

    Returns the node whose key equals ``key``, or ``None`` if there is none.
    """
    current_node = root
    while current_node is not None:
        if current_node.key == key:
            return current_node
        elif key < current_node.key:
            current_node = current_node.left
        elif key > current_node.key:
            current_node = current_node.right
    return current_node


def main():
    """Build a tree from shuffled integers and run a kNN and a radius query."""
    # Data generation
    db_size = 100
    k = 5  # number of neighbours to look for
    radius = 2.0

    # A random permutation of 0..db_size-1, so all keys are distinct.
    data = np.random.permutation(db_size).tolist()

    root = None
    for i, point in enumerate(data):
        # The key is the number itself; the value is its position in ``data``.
        root = insert(root, point, i)

    query_key = 6

    result_set = KNNResultSet(capacity=k)
    knn_search(root, result_set, query_key)
    print('kNN Search:')
    print('index - distance')
    print(result_set)

    result_set = RadiusNNResultSet(radius=radius)
    radius_search(root, result_set, query_key)
    print('Radius NN Search:')
    print('index - distance')
    print(result_set)


if __name__ == '__main__':
    main()
结果:
import copy  # file-level import kept as-is; the code below no longer needs it


class DistIndex:
    """A (distance, index) pair that orders by distance."""

    def __init__(self, distance, index):
        self.distance = distance
        self.index = index

    def __lt__(self, other):
        return self.distance < other.distance


class KNNResultSet:
    """Fixed-capacity container keeping the ``capacity`` smallest distances.

    ``dist_index_list`` always holds ``capacity`` slots sorted by ascending
    distance. Slots that have not been filled yet carry the initial
    ``worst_dist`` (1e10) and index 0; only the first ``count`` slots are
    real results.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.count = 0
        self.worst_dist = 1e10
        # Pre-fill every slot with a placeholder at the initial worst distance.
        self.dist_index_list = [
            DistIndex(self.worst_dist, 0) for _ in range(capacity)
        ]
        self.comparison_counter = 0

    def size(self):
        """Return the number of real results stored so far."""
        return self.count

    def full(self):
        """Return True once ``capacity`` results have been stored."""
        return self.count == self.capacity

    def worstDist(self):
        """Return the distance held in the last slot.

        This is the initial 1e10 until the set is full.
        """
        return self.worst_dist

    def add_point(self, dist, index):
        """Offer a candidate with distance ``dist`` and identifier ``index``.

        The candidate is dropped when ``dist`` is greater than the current
        worst distance. Otherwise it is inserted at its sorted position
        (after entries with an equal distance); when the set is already full
        the last entry is pushed out.
        """
        self.comparison_counter += 1
        # A set with no slots cannot store anything; without this guard the
        # slot bookkeeping below would index into an empty list.
        if self.capacity <= 0 or dist > self.worst_dist:
            return

        if self.count < self.capacity:
            self.count += 1

        # Start at the last used slot and move larger entries one slot to
        # the right until the insertion position is found.
        i = self.count - 1
        while i > 0 and self.dist_index_list[i - 1].distance > dist:
            self.dist_index_list[i] = self.dist_index_list[i - 1]
            i -= 1

        # Store a fresh object, so the shifted neighbour that still sits in
        # slot i is never modified through a shared reference.
        self.dist_index_list[i] = DistIndex(dist, index)
        self.worst_dist = self.dist_index_list[self.capacity - 1].distance

    def __str__(self):
        # Only the first ``count`` slots are real results; the remaining
        # slots are placeholders and are not printed.
        rows = [
            '%d - %.2f\n' % (entry.index, entry.distance)
            for entry in self.dist_index_list[:self.count]
        ]
        rows.append('In total %d comparison operations.' % self.comparison_counter)
        return ''.join(rows)


class RadiusNNResultSet:
    """Container collecting every candidate whose distance is within ``radius``."""

    def __init__(self, radius):
        self.radius = radius
        self.count = 0
        self.worst_dist = radius
        self.dist_index_list = []
        self.comparison_counter = 0

    def size(self):
        """Return the number of stored neighbours."""
        return self.count

    def worstDist(self):
        """Return the search radius; it never shrinks."""
        return self.radius

    def add_point(self, dist, index):
        """Store the candidate unless ``dist`` is greater than the radius."""
        self.comparison_counter += 1
        if dist > self.radius:
            return

        self.count += 1
        self.dist_index_list.append(DistIndex(dist, index))

    def __str__(self):
        # NOTE: sorts the stored list in place (ascending distance), exactly
        # as before, so the list stays sorted after printing.
        self.dist_index_list.sort()
        rows = [
            '%d - %.2f\n' % (entry.index, entry.distance)
            for entry in self.dist_index_list
        ]
        rows.append(
            'In total %d neighbors within %f.\nThere are %d comparison operations.'
            % (self.count, self.radius, self.comparison_counter)
        )
        return ''.join(rows)
转载地址:http://acexi.baihongyu.com/