尝试使用argmax预测标签时发生TypeError [英] TypeError when trying to predict labels with argmax

查看:59
本文介绍了尝试使用argmax预测标签时发生TypeError的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经成功地将转移学习教程用两个类来做我自己的分类器,即印象派"和现代主义".

I have successfully followed this transfer learning tutorial to make my own classifier with two classes, "impressionism" and "modernism".

现在尝试获取我的测试图像的标签,并应用> 线程:

Now trying to get a label for my test image, applying advice from this thread:

y_prob = model.predict(new_image)
y_prob

(gives this output) array([[3.1922062e-04, 9.9968076e-01]], dtype=float32)

y_classes = y_prob.argmax(axis=-1)
y_classes
(gives this output) array([1])

# create a list containing the class labels
labels = ['modernism', 'impressionism']
predicted_label = sorted(labels)[y_classes]

结果错误:

"TypeError                                 Traceback (most recent call last)
<ipython-input-35-571175bcfc65> in <module>()
      1 # create a list containing the class labels
      2 labels = ['modernism', 'impressionism']
----> 3 predicted_label = sorted(labels)[y_classes]

TypeError: only integer scalar arrays can be converted to a scalar index"

我做错了什么?访问测试图像的文本标签(及其概率)的正确方法是什么?如果我了解数组预测,则可以从我的图像文件夹中识别出两个类别.

What am I doing wrong and what would be the right way to access the text labels (and their probabilities) for my test image? If I understand the array prediction, it has recognized from my image folders that there are two classes.

非常感谢您有时间提供帮助!

Many thanks if you have time to help!

推荐答案

这里发生的是y_prob.argmax(axis=-1)返回数组值[1].只有numpy数组才能使用列表进行索引/拼接.

What's happening here is that y_prob.argmax(axis=-1) is returning an array value of [1]. Only numpy arrays can index/splice with a list.

该问题是由于sorted方法而发生的,我在测试中并未对此进行说明.即使输入数组的类型为np.ndarray,输出也将成为列表.

The issue occurs due to the sorted method, I was not accounting for that in my testing. Even though the input array is type np.ndarray, the output becomes a list.

所以:

labels = ['modernism', 'impressionism']
predicted_label = numpy.array(sorted(labels))[y_classes]

labels = numpy.array(['modernism', 'impressionism'])
labels.sort()
predicted_label = labels[y_classes]

这篇关于尝试使用argmax预测标签时发生TypeError的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆