Python: calculate cosine similarity of two dicts faster


Question

I have two dicts:

d1 = {1234: 4, 125: 7, ...}
d2 = {1234: 8, 1288: 5, ...}

The lengths of the dicts vary from 10 to 40000. To calculate the cosine similarity I use this function:

from scipy.linalg import norm

def simple_cosine_sim(a, b):
    # Iterate over the smaller dict so the loop is as short as possible.
    if len(b) < len(a):
        a, b = b, a

    res = 0
    for key, a_value in a.items():
        res += a_value * b.get(key, 0)
    if res == 0:
        return 0

    try:
        res = res / norm(list(a.values())) / norm(list(b.values()))
    except ZeroDivisionError:
        res = 0
    return res

Is it possible to calculate the similarity faster?

UPD: rewrote the code in Cython, ~15% faster. Thanks to @Davidmh.

# Cython (.pyx file)
from scipy.linalg import norm

def fast_cosine_sim(a, b):
    if len(b) < len(a):
        a, b = b, a

    cdef long up, key
    cdef int a_value, b_value

    up = 0
    for key, a_value in a.items():
        b_value = b.get(key, 0)
        up += a_value * b_value
    if up == 0:
        return 0
    return up / norm(list(a.values())) / norm(list(b.values()))

Answer

If the indexes are not too high, you could convert each dictionary into an array; if they are very big, you could use a sparse array. The cosine similarity is then just a product of the two vectors divided by their norms. This method performs best if you have to reuse the same dictionary for several calculations.
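A minimal sketch of that idea, with assumed details: each dict becomes a scipy.sparse CSR row vector of shape (1, size), where size is one more than the largest key. The helper names (dict_to_sparse, sparse_cosine_sim) are illustrative, not from the original post.

```python
import numpy as np
from scipy.sparse import csr_matrix

def dict_to_sparse(d, size):
    # keys() and values() iterate in the same order, so the pairs stay aligned.
    keys = np.fromiter(d.keys(), dtype=np.intp)
    vals = np.fromiter(d.values(), dtype=np.float64)
    rows = np.zeros_like(keys)
    return csr_matrix((vals, (rows, keys)), shape=(1, size))

def sparse_cosine_sim(v1, v2):
    # Elementwise product summed = dot product; norms via the same trick.
    dot = v1.multiply(v2).sum()
    if dot == 0:
        return 0.0
    n1 = np.sqrt(v1.multiply(v1).sum())
    n2 = np.sqrt(v2.multiply(v2).sum())
    return dot / (n1 * n2)

d1 = {1234: 4, 125: 7}
d2 = {1234: 8, 1288: 5}
size = 1 + max(max(d1), max(d2))
sim = sparse_cosine_sim(dict_to_sparse(d1, size), dict_to_sparse(d2, size))
```

The conversion cost is paid once per dict, so this pays off when each vector takes part in many comparisons.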

If this is not an option, Cython should be pretty fast, as long as you annotate a_value and b_value.

Looking at your Cython rewrite, I see a few improvements. The first thing is to run cython -a to generate an HTML report of the compilation and see which parts have been accelerated and which haven't. First of all, you define "up" as long, but you are summing integers. Also, in your example the keys are integers, but you are declaring them as double. Another easy win is to type the inputs as dicts.

Also, inspecting the generated C code, there is some None checking, which you can disable with @cython.nonecheck(False).

Actually, the implementation of dictionaries is quite efficient, so in the general case you will probably not do much better than that. If you need to squeeze the most out of your code, it may be worth replacing some calls with the C API: http://docs.python.org/2/c-api/dict.html

cpython.PyDict_GetItem(a, key)

But then you would be responsible for the reference counting and for casting from PyObject * to int, for a dubious performance gain.

Anyway, the beginning of the code would look like this:

cimport cython

@cython.nonecheck(False)
@cython.cdivision(True)
def fast_cosine_sim(dict a, dict b):
    if len(b) < len(a):
        a, b = b, a

    cdef int up, key
    cdef int a_value, b_value

Yet another concern: are your dictionaries big? Because if they are not, computing the norms could actually be a significant overhead.
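If the same dicts take part in many comparisons, one hypothetical way around that overhead is to compute each dict's norm once and pass it in. The function names below are illustrative, not from the original post.

```python
from math import sqrt

def precompute_norm(d):
    # Euclidean norm of the dict's values, computed once per dict.
    return sqrt(sum(v * v for v in d.values()))

def cosine_sim_cached(a, b, norm_a, norm_b):
    # Iterate over the smaller dict, as in the original function.
    if len(b) < len(a):
        a, b = b, a
        # The product norm_a * norm_b is symmetric, so no need to swap norms.
    up = sum(a_value * b.get(key, 0) for key, a_value in a.items())
    if up == 0:
        return 0.0
    return up / (norm_a * norm_b)

d1 = {1234: 4, 125: 7}
d2 = {1234: 8, 1288: 5}
n1, n2 = precompute_norm(d1), precompute_norm(d2)
sim = cosine_sim_cached(d1, d2, n1, n2)
```

With N dicts compared pairwise, this turns O(N^2) norm computations into O(N).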

EDIT 2: Another possible approach is to only look at the keys that are necessary. Say:

# Cython (.pyx file)
from scipy.linalg import norm
cimport cython

@cython.nonecheck(False)
@cython.cdivision(True)
def fast_cosine_sim(dict a, dict b):
    cdef int up, key
    cdef int a_value, b_value

    up = 0
    # Only visit keys present in both dicts; everything else contributes 0.
    for key in set(a.keys()).intersection(b.keys()):
        a_value = a[key]
        b_value = b[key]
        up += a_value * b_value
    if up == 0:
        return 0
    return up / norm(list(a.values())) / norm(list(b.values()))

This is very efficient in Cython. The actual performance will probably depend on how much overlap there is between the keys.
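For reference, the same key-intersection idea can be checked in plain Python (no Cython compilation needed), using the set operations that dict key views support directly:

```python
from math import sqrt

def cosine_sim_intersect(a, b):
    # dict key views support set intersection directly in Python 3.
    common = a.keys() & b.keys()
    up = sum(a[k] * b[k] for k in common)
    if up == 0:
        return 0.0
    na = sqrt(sum(v * v for v in a.values()))
    nb = sqrt(sum(v * v for v in b.values()))
    return up / (na * nb)

d1 = {1234: 4, 125: 7}
d2 = {1234: 8, 1288: 5}
sim = cosine_sim_intersect(d1, d2)
```

This is handy as a correctness reference for the Cython version and as a baseline when measuring how much the compilation actually buys.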
