TensorFlow对象检测API中的获取类和概率 [英] Get Class and Probability in Tensorflow Object Detection API
本文介绍了TensorFlow对象检测API中的获取类和概率的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
在TensorFlow对象检测API中获取类和检测到对象的概率时遇到问题。我想将这两个值与每个图像一起打印。
代码如下:
for image_path in TEST_IMAGE_PATHS:
image = Image.open(image_path)
# the array based representation of the image will be used later in order to prepare the
# result image with boxes and labels on it.
image_np = load_image_into_numpy_array(image)
# Expand dimensions since the model expects images to have shape: [1, None, None, 3]
image_np_expanded = np.expand_dims(image_np, axis=0)
# Actual detection.
output_dict = run_inference_for_single_image(image_np, detection_graph)
# Visualization of the results of a detection.
vis_util.visualize_boxes_and_labels_on_image_array(
image_np,
output_dict['detection_boxes'],
output_dict['detection_classes'],
output_dict['detection_scores'],
category_index,
instance_masks=output_dict.get('detection_masks'),
use_normalized_coordinates=True,
line_thickness=2)
plt.figure(figsize=IMAGE_SIZE)
plt.imshow(image_np)
推荐答案
以下代码提供了提取类ID和得分高于50%的所有实体的可能性。
#Create indexes list of element with a score > 0.5
indexes = [k for k,v in enumerate(output_dict['detection_scores']) if (v > 0.5)]
#Number of entities
num_entities = len(indexes)
#Extract the class id
class_id = itemgetter(*indexes)(output_dict['detection_classes'])
scores = itemgetter(*indexes)(output_dict['detection_scores'])
#Convert the class id in their name
class_names = []
if num_entities == 1:
class_names.append(category_index[class_id]['name'])
class_name = str(class_names)
else:
for i in range(0, len(indexes)):
class_names.append(category_index[class_id[i]]['name'])
如果只检测到一个元素,则IF是必需的。
然后您可以打印class_names[i]
和str(scores[i])
这篇关于TensorFlow对象检测API中的获取类和概率的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文