numpy.amax 中的一个键 [英] A key in numpy.amax

查看:80
本文介绍了numpy.amax 中的一个键的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在 Python 的标准 max 函数中我可以传入一个 key 参数:

In the Python's standard max function I can pass in a key parameter:

s = numpy.array(['one','two','three'])
max(s) # 'two' (lexicographically last)
max(s, key=len) # 'three' (longest string)

对于更大的(多维)数组,我们不能再使用 max,但我们可以使用 numpy.amax...不幸的是没有提供key参数.

With a larger (multi-dimensional) array, we can not longer use max, but we can use numpy.amax... which unfortunately offers no key parameter.

t = numpy.array([['one','two','three'],
                 ['four','five','six']], 
                dtype='object')
numpy.amax(t) # 'two` (max of the flat array)
numpy.amax(t, axis=1) # array([two, six], dtype=object) (max of first row, followed by max of second row)

我想做的是:

amax2(t, key=len) # 'three'
amax2(t, key=len, axis=1) # array([three, four], dtype=object)

是否有内置方法可以做到这一点?

Is there a built-in method to do this?

注意:在第一次尝试写这个问题时,我无法得到 amax 在这个玩具示例中工作

Note: In trying to write this question the first time I couldn't get amax working in this toy example!

推荐答案

这是一种非内置方式(它缺少 out 和 keepdim 参数href="http://docs.scipy.org/doc/numpy/reference/generated/numpy.amax.html#numpy.amax" rel="nofollow noreferrer">amax 的特点 当使用 key) 时,它看起来很长:

This is a non built-in way (it's missing the out and keepdim parameters of features of amax when using key), it seems rather long:

def amax2(x, *args, **kwargs):
    if 'key' not in kwargs:
        return numpy.amax(x,*args,**kwargs)
    else:
        key = kwargs.pop('key') # e.g. len, pop so no TypeError: unexpected keyword
        x_key = numpy.vectorize(key)(x) # apply key to x element-wise
        axis = kwargs.get('axis') # either None or axis is set in kwargs
        if len(args)>=2: # axis is set in args
            axis = args[1]

        # The following is kept verbose, but could be made more efficient/shorter    
        if axis is None: # max of flattened
            max_flat_index = numpy.argmax(x_key, axis=axis)
            max_tuple_index = numpy.unravel_index(max_flat_index, x.shape)
            return x[max_tuple_index]
        elif axis == 0: # max in each column
            max_indices = numpy.argmax(x_key, axis=axis)
            return numpy.array(
                 [ x[max_i, i] # reorder for col
                     for i, max_i in enumerate(max_indices) ], 
                 dtype=x.dtype)
        elif axis == 1: # max in each row
            max_indices = numpy.argmax(x_key, axis=axis)
            return numpy.array(
                 [ x[i, max_i]
                     for i, max_i in enumerate(max_indices) ],
                 dtype=x.dtype)

这个函数的想法是从 @PeterSobot 的回答的第二部分扩展到我之前的问题.

The idea for this function is extended from the second part of @PeterSobot's answer to my previous question.

这篇关于numpy.amax 中的一个键的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆