How do you optimize this code for nn prediction?


Question

How do you optimize this code? At the moment it runs too slowly for the amount of data that goes through this loop. The code implements 1-nearest neighbor: it predicts the label of training_element from the points in p_data_set and their labels in p_label_set.

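import numpy as np
from scipy.spatial import distance

#               [x] ,           [[x1],[x2],[x3]],    [l1, l2, l3]
def prediction(training_element, p_data_set, p_label_set):
    # Distance from training_element to every point, computed one at a time
    temp = np.array([], dtype=float)
    for p in p_data_set:
        temp = np.append(temp, distance.euclidean(training_element, p))

    # Label of the closest point
    minIndex = np.argmin(temp)
    return p_label_set[minIndex]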

Recommended answer

Use a k-D tree for fast nearest-neighbour lookups, e.g. scipy.spatial.cKDTree:

import numpy as np
from scipy.spatial import cKDTree

# I assume that p_data_set is (nsamples, ndims)
tree = cKDTree(p_data_set)

# training_elements is also assumed to be 2-D: (nqueries, ndims)
dist, idx = tree.query(training_elements, k=1)

# np.asarray lets the index array idx work even if p_label_set is a plain list
predicted_labels = np.asarray(p_label_set)[idx]
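
As a rough, self-contained sketch of the approach above, the following compares the k-D tree lookup with a vectorized brute-force search on synthetic data. The helper names predict_kdtree and predict_bruteforce, the array sizes, and the random data are illustrative assumptions, not part of the original question. The brute-force version relies on scipy.spatial.distance.cdist, which computes all pairwise distances in one call instead of growing an array with np.append inside a Python loop.

```python
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist


def predict_kdtree(queries, p_data_set, p_label_set):
    """Return the 1-NN label for each row of `queries` using a k-D tree."""
    tree = cKDTree(p_data_set)           # build the tree once
    _, idx = tree.query(queries, k=1)    # index of the nearest point per query
    return np.asarray(p_label_set)[idx]


def predict_bruteforce(queries, p_data_set, p_label_set):
    """Same prediction via a full (nqueries, nsamples) distance matrix."""
    dists = cdist(queries, p_data_set)   # Euclidean distance by default
    return np.asarray(p_label_set)[np.argmin(dists, axis=1)]


# Synthetic data, assumed purely for illustration
rng = np.random.default_rng(0)
p_data_set = rng.random((1000, 3))            # (nsamples, ndims)
p_label_set = rng.integers(0, 5, size=1000)   # one label per sample
queries = rng.random((20, 3))                 # (nqueries, ndims)

labels_tree = predict_kdtree(queries, p_data_set, p_label_set)
labels_brute = predict_bruteforce(queries, p_data_set, p_label_set)
```

Both functions should return the same labels here. The tree tends to pay off when there are many queries and few dimensions; for high-dimensional data the vectorized brute-force version may be just as fast, so it is worth timing both on the real data.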
