# KD Tree

## 1. Concept

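1. A kd-tree is a `binary tree structure` that stores instances (points in k-dimensional space) so that they can be `retrieved quickly`.

2. Building a kd-tree amounts to repeatedly splitting the `k-dimensional space` with `hyperplanes` perpendicular to the coordinate axes, producing `a series of k-dimensional hyperrectangles`. Each node corresponds to one such hyperrectangle.

3. Every `non-leaf node` can be viewed as a hyperplane that divides the space into `two half-spaces`: the node's `left subtree` holds the points `on the left of the hyperplane`, and its `right subtree` holds the points `on the right of the hyperplane`.

4. For example, if a node `splits on the x axis`, every point `whose x value is less than the split value` goes `into the left subtree`, and every point `whose x value is greater than the split value` goes `into the right subtree`.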

## 2. Choosing the split axis

Take a two-dimensional dataset T = {(7,2), (5,4), (2,3), (4,7), (9,6), (8,1)}.

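`Splitting strategies`:

`1) Cycle through the axes in turn`: in two dimensions, alternate between x and y.

`2) Split on the axis with the largest variance`, which separates the points more strongly: compute the variance of the x values (7, 5, 2, 4, 9, 8) and of the y values (2, 4, 3, 7, 6, 1), then split on the axis whose variance is larger.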
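The second strategy can be sketched as follows. This is a minimal illustration that assumes the population variance is used; the helper names `variance` and `choose_axis_by_variance` are hypothetical and do not come from any library.

```python
# Example dataset from the text
T = [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]


def variance(values):
    """Population variance of a sequence of numbers."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def choose_axis_by_variance(points):
    """Return (index of the axis with the largest variance, all variances)."""
    k = len(points[0])
    variances = [variance([p[axis] for p in points]) for axis in range(k)]
    best_axis = max(range(k), key=lambda axis: variances[axis])
    return best_axis, variances


axis, variances = choose_axis_by_variance(T)
print(axis, variances)
```

For this dataset the x variance (about 5.81) is larger than the y variance (about 4.47), so the first split would be on x (axis 0).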

`Split value`:

Once the axis has been chosen, the `median` of the values on that axis is usually taken as the `split value`.

## 3. Construction

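Once the axis is chosen, the point holding the median value on that axis becomes the split point. On the x axis the values are (7, 5, 2, 4, 9, 8), which sort to (2, 4, 5, 7, 8, 9); with an even number of points the upper of the two middle values, 7, is taken, so (7,2) is the split point. The points with x less than 7 form the dataset of the left subtree, and the remaining points with x greater than or equal to 7 form the dataset of the right subtree.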

The figure below shows the space after the first split:

![img](https://img-blog.csdnimg.cn/20190213232928393.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ppYW5nNDI1Nzc2MDI0,size_16,color_FFFFFF,t_70)
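The first split described above can be reproduced with a few lines of Python. This is only a sketch: `split_on_axis` is a hypothetical helper, and it assumes the upper median (index `len // 2` after sorting) is used as the split point.

```python
# Example dataset from the text
T = [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]


def split_on_axis(points, axis):
    """Sort by one axis and return (split point, left subset, right subset)."""
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2  # upper median when the count is even
    return ordered[mid], ordered[:mid], ordered[mid + 1:]


pivot, left, right = split_on_axis(T, 0)
print(pivot)  # (7, 2)
print(left)   # [(2, 3), (4, 7), (5, 4)]
print(right)  # [(8, 1), (9, 6)]
```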

Repeating the split on each subset, one axis at a time, eventually yields the kd-tree:

![img](https://img-blog.csdnimg.cn/20190213233029595.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L2ppYW5nNDI1Nzc2MDI0,size_16,color_FFFFFF,t_70)
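To see the resulting tree as text, the recursive construction can be sketched with plain dictionaries. The `build` and `show` helpers are hypothetical, and the sketch assumes the round-robin axis strategy with the upper median as split point.

```python
# Example dataset from the text
T = [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]


def build(points, depth=0):
    """Recursively build a kd-tree as nested dicts, cycling through the axes."""
    if not points:
        return None
    axis = depth % len(points[0])
    ordered = sorted(points, key=lambda p: p[axis])
    mid = len(ordered) // 2
    return {
        "point": ordered[mid],
        "axis": axis,
        "left": build(ordered[:mid], depth + 1),
        "right": build(ordered[mid + 1:], depth + 1),
    }


def show(node, indent=0):
    """Print one node per line, indenting children under their parent."""
    if node is None:
        return
    print(" " * indent + f"{node['point']} (axis {node['axis']})")
    show(node["left"], indent + 2)
    show(node["right"], indent + 2)


tree = build(T)
show(tree)
```

With this dataset the root is (7, 2); its left child is (5, 4) with children (2, 3) and (4, 7), and its right child is (9, 6) with a single left child (8, 1).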

## 4. Applications of KD trees

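A kd-tree is typically used in algorithms such as `KNN` to find the `k nearest neighbours` of a data point. Once the tree is built, the k points closest to a query can be `found quickly`, usually without comparing the query against every point.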
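For comparison, the brute-force search that a kd-tree is meant to speed up can be sketched as below. It computes the distance to every point, so it is a useful baseline for checking a tree-based result; `brute_force_knn` is a hypothetical helper, and `math.dist` requires Python 3.8 or later.

```python
from math import dist

# Example dataset from the text
T = [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]


def brute_force_knn(points, query, k):
    """Return the k points closest to query by checking every point."""
    return sorted(points, key=lambda p: dist(p, query))[:k]


neighbours = brute_force_knn(T, (3, 4.5), 3)
print(neighbours)  # [(2, 3), (5, 4), (4, 7)]
```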

1) Building one in Python

```python
from math import sqrt


class KDNode(object):
    def __init__(self, value, split, left, right):
        # value = [x, y]: the point stored at this node
        self.value = value
        # split: index of the axis this node splits on
        self.split = split
        self.left = left
        self.right = right


class KDTree(object):
    def __init__(self, data):
        # data = [[x1, y1], [x2, y2], ...]
        # k is the number of dimensions
        k = len(data[0])

        def create_node(split, data_set):
            if not data_set:
                return None
            # Sort a copy so the caller's list is not reordered
            data_set = sorted(data_set, key=lambda p: p[split])
            # Integer division picks the upper median
            split_pos = len(data_set) // 2
            median = data_set[split_pos]
            # Cycle through the axes in turn
            split_next = (split + 1) % k

            return KDNode(median, split,
                          create_node(split_next, data_set[:split_pos]),
                          create_node(split_next, data_set[split_pos + 1:]))

        self.root = create_node(0, data)

    def search(self, root, x, count=1):
        # Up to `count` [distance, point] pairs, kept sorted by distance
        nearest = []

        def recurse(node):
            if node is None:
                return
            axis = node.split
            daxis = x[axis] - node.value[axis]
            # Descend first into the side of the splitting plane that contains x
            if daxis < 0:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left
            recurse(near)
            dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.value)))
            # Keep this point if there is room or it beats the worst candidate
            if len(nearest) < count or dist < nearest[-1][0]:
                nearest.append([dist, node.value])
                nearest.sort(key=lambda item: item[0])
                del nearest[count:]
            # Visit the far side if the list is not full yet, or if the
            # splitting plane is closer than the current worst candidate
            if len(nearest) < count or abs(daxis) < nearest[-1][0]:
                recurse(far)

        recurse(root)
        return nearest


data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
kd = KDTree(data)

# The 3 points nearest to [3, 4.5]
n = kd.search(kd.root, [3, 4.5], 3)
print(n)

# [[1.8027756377319946, [2, 3]], [2.0615528128088303, [5, 4]], [2.692582403567252, [4, 7]]]
```

2) Using sklearn

<https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html>

```python
import numpy as np
from sklearn.neighbors import KDTree

X = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])

tree = KDTree(X, leaf_size=2)
query_point = np.array([[2.1, 3.1]])

print('Query point:', query_point)

dist, ind = tree.query(query_point, k=3)

print(dist)  # distances to the 3 nearest points
print(ind)  # indices of the 3 nearest points in X
print(X[ind])  # the 3 nearest points themselves

'''
Query point: [[2.1 3.1]]
[[0.14142136 3.03644529 4.33820239]]
[[0 1 3]]
[[[2 3]
  [5 4]
  [4 7]]]
'''
```

