我如何才能加速我编写的Python代码:使用空间搜索的球体接触检测(碰撞) [英] How could I speed up my written python code: spheres contact detection (collision) using spatial searching

查看:0
本文介绍了我如何才能加速我编写的Python代码:使用空间搜索的球体接触检测(碰撞)的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在处理一个球体的空间搜索案例,我想在其中找到连接的球体。为此,我在每个球体周围搜索中心与搜索球体中心的距离为(最大球体直径)的球体。一开始,我尝试使用Scipy相关方法,但与等价NumPy方法相比,Scipy方法耗时更长。对于Scipy算法,首先确定K-近邻球体的个数,然后再用cKDTree.query查找,这样会耗费较多的时间。然而,即使通过省略带有常量值的第一步(在这种情况下省略第一步是不好的),它也比NumPy方法慢。这与我对快速空间搜索速度的期望相反。所以,我尝试使用Numbaprange来加速一些列表循环而不是一些麻木的行。Numba可以更快地运行代码,但我相信可以通过向量化、使用其他替代NumPy模块或以其他方式使用Numba来优化代码以获得更好的性能。我在所有球体上使用了迭代,以防止可能的内存泄漏和…,其中球体数量较多。

import numpy as np
import numba as nb
from scipy.spatial import cKDTree, distance

# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')     # shape: (n-spheres, )     must be loaded by np.load('a.npy') or np.loadtxt('radii_large.csv')
poss = np.load('b.npy')      # shape: (n-spheres, 3)    must be loaded by np.load('b.npy') or np.loadtxt('pos_large.csv', delimiter=',')
"""

rnd = np.random.RandomState(70)
data_volume = 200000

radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()

x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------

# @nb.jit('float64[:,::1](float64[:,::1], float64[::1])', forceobj=True, parallel=True)
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = np.array([], dtype=np.float64)
    ends_ind = np.empty([1, 2], dtype=np.int64)
    """ using list looping """
    # particle_corsp_overlaps = []
    # ends_ind = []

    # for particle_idx in nb.prange(len(poss)):  # by list looping
    for particle_idx in range(len(poss)):
        unshared_idx = np.delete(np.arange(len(poss)), particle_idx)                                                    # <--- relatively high time consumer
        poss_without = poss[unshared_idx]

        """ # SCIPY method ---------------------------------------------------------------------------------------------
        nears_i_ind = cKDTree(poss_without).query_ball_point(poss[particle_idx], r=dia_max)         # <--- high time consumer
        if len(nears_i_ind) > 0:
            dist_i, dist_i_ind = cKDTree(poss_without[nears_i_ind]).query(poss[particle_idx], k=len(nears_i_ind))       # <--- high time consumer
            if not isinstance(dist_i, float):
                dist_i[dist_i_ind] = dist_i.copy()
        """  # NUMPY method --------------------------------------------------------------------------------------------
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dia_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dia_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dia_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dia_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dia_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dia_max

        nears_i_ind = np.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        if len(nears_i_ind) > 0:
            dist_i = distance.cdist(poss_without[nears_i_ind], poss[particle_idx][None, :]).squeeze()                   # <--- relatively high time consumer
        # """  # -------------------------------------------------------------------------------------------------------
            contact_check = dist_i - (radii[unshared_idx][nears_i_ind] + radii[particle_idx])
            connected = contact_check[contact_check <= 0]

            particle_corsp_overlaps = np.concatenate((particle_corsp_overlaps, connected))
            """ using list looping """
            # if len(connected) > 0:
            #    for value_ in connected:
            #        particle_corsp_overlaps.append(value_)

            contacts_ind = np.where([contact_check <= 0])[1]
            contacts_sec_ind = np.array(nears_i_ind)[contacts_ind]
            sphere_olps_ind = np.where((poss[:, None] == poss_without[contacts_sec_ind][None, :]).all(axis=2))[0]       # <--- high time consumer

            ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
            if particle_idx > 0:
                ends_ind = np.concatenate((ends_ind, ends_ind_mod_temp))
            else:
                ends_ind[0, 0], ends_ind[0, 1] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]
            """ using list looping """
            # for contacted_idx in sphere_olps_ind:
            #    ends_ind.append([particle_idx, contacted_idx])

    # ends_ind_org = np.array(ends_ind)  # using lists
    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)                                # <--- relatively high time consumer
    gap = np.array(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

在我对23000个球体进行的一次测试中,Scipy、NumPy和Numba辅助的方法使用Colab TPU在大约400秒、200秒和180秒内完成了循环;对于500.000个球体,则需要3.5h。对于我的项目来说,这些执行时间根本不令人满意,在中等数据量的情况下,球体的数量可能高达1.000.000。我将在我的主代码中多次调用此代码,并寻找能够在毫秒(尽可能快)内执行此代码的方法。有可能吗?? 如果有人能在需要时加快代码的速度,我将不胜感激。

备注:

  • 此代码必须可以在CPU和GPU上使用Python3.7+执行。
  • 此代码必须适用于数据大小,至少为300.000个球体。
  • 都麻木、僵硬、…相同的模块,而不是我编写的模块,使我的代码显著更快,将被提升。

如有任何建议或解释,我将不胜感激:

  1. 在此主题中,哪种方法更快?
  2. 在这种情况下,Scipy为什么不比其他方法更快?它在哪些方面可能对此主题有所帮助?
  3. 选择迭代器方法和矩阵形式方法对我来说是一件令人困惑的事情。迭代方法使用较少的内存,并且可以由numba和…使用和调整但是,我认为它们并不有用,也不能与NumPy和…等矩阵方法(取决于内存限制)相媲美对于巨大的球体数。对于这种情况,也许我可以省略NumPy的迭代,但我强烈地认为,由于巨大的矩阵大小操作和内存泄漏,它无法处理。

准备好的样本测试数据:

POSS数据:23000500000
半径数据:23000500000
线速测试日志:两个测试用例scipy方法和numpy时间消耗。

推荐答案

第一步:更好的算法

首先,构建一棵k-d树需要O(n log n)时间,执行查询需要O(log n)时间,其中n是点数。因此,乍一看,使用k-d树似乎是个好主意。但是,您的代码为每个点构建了一棵k-d树,从而导致O(n² log n)时间。这就是为什么Scipy解决方案比其他解决方案慢的原因。问题是,Scipy不提供更新k-d树的方法。原来updating efficiently a k-d tree appears not to be possible。希望这对您来说不是问题:您只需构建一棵包含所有点的k-d树,然后丢弃您不希望出现在每次查询结果中的当前点

此外,sphere_olps_ind的计算运行在O(n² m)时间内,其中n是点的总数,m是邻居(即.从K-D树查询中检索到的最近点)。假设没有重复点,那么sphere_olps_ind就等于np.sort(contacts_sec_ind)。后者在O(m log m)中运行,这要好得多。

此外,在循环中使用np.concatenate在Numpy数组中附加值很慢,因为它会为每次迭代创建一个新的更大的数组。使用列表是个好主意,但将Numpy数组直接追加到列表中,然后调用np.concatenate一次要快得多

以下是生成的代码:

def ends_gap(poss, dia_max):
    particle_corsp_overlaps = []
    ends_ind = [np.empty([1, 2], dtype=np.int64)]

    kdtree = cKDTree(poss)

    for particle_idx in range(len(poss)):
        # Find the nearest point including the current one and
        # then remove the current point from the output.
        # The distances can be computed directly without a new query.
        cur_point = poss[particle_idx]
        nears_i_ind = np.array(kdtree.query_ball_point(cur_point, r=dia_max), dtype=np.int64)
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = distance.cdist(poss[nears_i_ind], cur_point[None, :]).squeeze()

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]

        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where([contact_check <= 0])[1]
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.array([np.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:
            ends_ind.append(ends_ind_mod_temp)
        else:
            ends_ind[0][:] = ends_ind_mod_temp[0, 0], ends_ind_mod_temp[0, 1]

    ends_ind_org = np.concatenate(ends_ind)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)                                # <--- relatively high time consumer
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

第2步:优化

首先,通过向Scipy方法提供poss并指定参数workers=-1query_ball_point调用可以在并行中的所有点上一次完成。但是,请注意,这需要更多内存。

此外,Numba可用于显著提高计算速度。可以改进的主要部分是计算距离和创建许多不必要的临时数组,以及使用块数组直接索引而不是列表的追加(因为输出数组的有界大小可以在query_ball_point调用后得知)。

以下是使用Numba优化代码的简单示例:

@nb.jit('(float64[:, ::1], int64[::1], int64[::1], float64)')
def compute(poss, all_neighbours, all_neighbours_sizes, dia_max):
    particle_corsp_overlaps = []
    ends_ind_lst = [np.empty((1, 2), dtype=np.int64)]
    an_offset = 0

    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        cur_len = all_neighbours_sizes[particle_idx]
        nears_i_ind = all_neighbours[an_offset:an_offset+cur_len]
        an_offset += cur_len
        assert len(nears_i_ind) > 0

        if len(nears_i_ind) <= 1:
            continue

        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = np.empty(len(nears_i_ind), dtype=np.float64)

        # Compute the distances
        x1, y1, z1 = poss[particle_idx]
        for i in range(len(nears_i_ind)):
            x2, y2, z2 = poss[nears_i_ind[i]]
            dist_i[i] = np.sqrt((x2-x1)**2 + (y2-y1)**2 + (z2-z1)**2)

        contact_check = dist_i - (radii[nears_i_ind] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]

        particle_corsp_overlaps.append(connected)

        contacts_ind = np.where(contact_check <= 0)
        contacts_sec_ind = nears_i_ind[contacts_ind]
        sphere_olps_ind = np.sort(contacts_sec_ind)

        ends_ind_mod_temp = np.empty((len(sphere_olps_ind), 2), dtype=np.int64)
        for i in range(len(sphere_olps_ind)):
            ends_ind_mod_temp[i, 0] = particle_idx
            ends_ind_mod_temp[i, 1] = sphere_olps_ind[i]

        if particle_idx > 0:
            ends_ind_lst.append(ends_ind_mod_temp)
        else:
            tmp = ends_ind_lst[0]
            tmp[:] = ends_ind_mod_temp[0, :]

    return particle_corsp_overlaps, ends_ind_lst

def ends_gap(poss, dia_max):
    kdtree = cKDTree(poss)
    tmp = kdtree.query_ball_point(poss, r=dia_max, workers=-1)
    all_neighbours = np.concatenate(tmp, dtype=np.int64)
    all_neighbours_sizes = np.array([len(e) for e in tmp], dtype=np.int64)
    particle_corsp_overlaps, ends_ind_lst = compute(poss, all_neighbours, all_neighbours_sizes, dia_max)
    ends_ind_org = np.concatenate(ends_ind_lst)
    ends_ind, ends_ind_idx = np.unique(np.sort(ends_ind_org), axis=0, return_index=True)
    gap = np.concatenate(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org

ends_gap(poss, dia_max)

性能分析

以下是我的6核计算机(i5-9600KF处理器)在小型数据集上的性能结果:

Initial code with Scipy:             259 s
Initial default code with Numpy:     112 s
Optimized algorithm:                   1.37 s
Final optimized code:                  0.22 s

以下是大数据集上的性能结果:

Initial code with Scipy:          100000 s     (estimation)
Initial default code with Numpy:    6700 s     (estimation)
Optimized algorithm:                   6.36 s
Final optimized code:                  1.28 s

因此,采用高效算法的Numba实现比初始的Numpy实现快了~5230倍,比初始的Scipy实现快了约78000倍。

可以进一步优化Numba代码,但请注意,在我的机器上调用Numbacompute花费的时间不到25%。np.unique调用是最昂贵的,但要使其更快并不容易。很大一部分时间花在从Scipy到Numba的数据转换上,但只要使用Scipy,此代码就是必需的。因此,代码可以稍加改进(例如,当然快2倍),但如果您需要更快的代码,那么您需要使用像C++这样的本地语言和高度优化的并行k-d树实现。我预计经过高度优化的本机代码会快上一个数量级,但也不会比这个快多少。我几乎不相信在我的机器上,无论实施什么,大数据集都能在10毫秒内计算出来。


备注

注意,gap与提供的函数不同(其他值保持不变)。然而,同样的事情发生在初始Scipy方法和Numpy方法之间。这似乎来自Scipy未定义的nears_i_inddist_i等变量的排序,并以非平凡的方式更改gap结果(不仅仅是gap的顺序)。我不确定这是最初实施的问题。正因为如此,比较不同实现的正确性要困难得多。

forceobj不应在生产中使用,因为文档说明这仅用于测试目的。

这篇关于我如何才能加速我编写的Python代码:使用空间搜索的球体接触检测(碰撞)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆