Keras model fails to predict if called in a thread

Problem description

I am trying to run predictions with Keras and the pretrained VGG16 model in a threaded application. If I call the prediction in the main thread, everything works fine. But if I predict inside a threaded function (whether I use `threading`, `multiprocessing`, ...), it just stalls during prediction:

Here is a minimal example:

```python
########################################
# Alter this variable
# (despite its name, True runs the prediction in a child process
# started with multiprocessing.Process)
USE_THREADING = True
########################################

import platform
from multiprocessing import Process

import cv2
import keras
import numpy as np
import tensorflow as tf


def inference_handler(model_hash, frame_resized):
    print("multiprocessing: before prediction call")
    model_hash.predict(np.expand_dims(frame_resized, axis=0), batch_size=1)
    print("multiprocessing: after prediction call")


if __name__ == "__main__":
    print("keras version:", keras.__version__)
    print("tf vresion: ", tf.__version__)
    print("python version:", platform.python_version())
    # The model is built once, in the parent process
    model_hash = keras.applications.VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    # Perform the demo
    cap = cv2.VideoCapture(0)
    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()
        if not ret:
            # Stop if the camera returned no frame
            break
        # Process the keys
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            print("quit")
            break
        # Get the proper image for the network
        frame_resized = cv2.resize(frame, (224, 224))
        # Show the images
        cv2.imshow('frame', frame)
        cv2.imshow('frame_resized', frame_resized)

        # Predict
        if USE_THREADING:
            # The parent's model is passed to the child as an argument;
            # the child's predict() call is where the program stalls
            p = Process(target=inference_handler, args=(model_hash, frame_resized))
            p.start()
            p.join()
        else:
            print("main thread: before prediction call")
            model_hash.predict(np.expand_dims(frame_resized, axis=0), batch_size=1)
            print("main thread: after prediction call")

    # When everything is done, release the capture
    cap.release()
    cv2.destroyAllWindows()
```

With `USE_THREADING = False` I get:

```
Using TensorFlow backend.
keras version: 2.2.0
tf vresion:  1.8.0
python version: 3.5.2
2019-02-25 20:47:32.926696: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
main thread: before prediction call
main thread: after prediction call
main thread: before prediction call
main thread: after prediction call
main thread: before prediction call
...
```

With `USE_THREADING = True` (fails) I get:

```
Using TensorFlow backend.
keras version: 2.2.0
tf vresion:  1.8.0
python version: 3.5.2
2019-02-25 20:50:34.922696: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
multiprocessing: before prediction call
```

Recommended answer

Unfortunately, Keras with the TensorFlow backend hangs during prediction if the model is passed to the subprocess as an argument. If the model is created directly inside the subprocess, however, everything works fine. A likely reason is that TensorFlow's runtime state (its session and internal thread pools) does not survive `fork()`, which `multiprocessing` uses by default on Linux, so the child inherits a model it cannot execute. The solution is therefore to build the model in the subprocess and send the frames to it through a queue. Here is a working solution:

```python
import platform
import queue
from multiprocessing import Process, Queue

import cv2
import keras
import numpy as np
import tensorflow as tf


def inference_handler(frame_queue):
    # Create the model inside the child process; a model built in the
    # parent and passed in as an argument makes predict() hang.
    model_hash = keras.applications.VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
    # Build the predict function before the first predict() call
    # (a private Keras 2.x method).
    model_hash._make_predict_function()
    while True:
        print("multiprocessing: before queue")
        frame_resized = frame_queue.get(block=True, timeout=None)
        print("multiprocessing: before prediction call")
        model_hash.predict(np.expand_dims(frame_resized, axis=0), batch_size=1)
        print("multiprocessing: after prediction call")


if __name__ == "__main__":
    print("keras version:", keras.__version__)
    print("tf version: ", tf.__version__)
    print("python version:", platform.python_version())
    # maxsize=1: at most one frame waits for the worker at any time
    frame_queue = Queue(maxsize=1)
    p = Process(target=inference_handler, args=(frame_queue,))
    p.start()
    cap = cv2.VideoCapture(0)
    while True:
        # Capture frame-by-frame
        ret, frame = cap.read()
        if not ret:
            # Stop if the camera returned no frame
            break
        # Process the keys
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            print("quit")
            break
        # Get the proper image for the network
        frame_resized = cv2.resize(frame, (224, 224))
        # Show the images
        cv2.imshow('frame', frame)
        cv2.imshow('frame_resized', frame_resized)

        # Advertise the frame; skip it if the worker has not taken the previous one yet
        if frame_queue.empty():
            try:
                frame_queue.put_nowait(frame_resized)
                print("Put frame into the queue")
            except queue.Full:
                pass  # empty() is only approximate for multiprocessing queues

    # When everything is done, stop the worker and release the capture
    p.terminate()
    cap.release()
    cv2.destroyAllWindows()
```

This gives me:

```
keras version: 2.2.0
tf version:  1.8.0
python version: 3.5.2
Put frame into the queue
multiprocessing: before queue
multiprocessing: before prediction call
Put frame into the queue
multiprocessing: after prediction call
multiprocessing: before queue
multiprocessing: before prediction call
Put frame into the queue
...
```
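
The worker in the answer throws its predictions away and is stopped with `p.terminate()`. Below is a minimal, standard-library-only sketch of the same "build the model in the child, feed it through a queue" pattern that also sends results back to the parent and shuts the worker down with a `None` sentinel. It is an illustration under stated assumptions, not the answer's code: `DummyModel`, `inference_worker` and `run_inference` are hypothetical names, `DummyModel` stands in for the Keras model, and frames are plain lists instead of NumPy arrays. It explicitly selects the `fork` start method (the default on Linux in the Python versions from the question). On Windows, where only `spawn` is available, the process-starting code must sit under `if __name__ == "__main__":`.

```python
import multiprocessing as mp
import queue


class DummyModel:
    """Hypothetical stand-in for the Keras model (e.g. VGG16); not a Keras API."""

    def predict(self, batch):
        # Fake "prediction": the sum of each input in the batch.
        return [sum(x) for x in batch]


def inference_worker(frame_queue, result_queue):
    # The model is created inside the child process, as the answer recommends.
    # With Keras, VGG16(...) and _make_predict_function() would go here.
    model = DummyModel()
    while True:
        item = frame_queue.get()        # block until a frame or the sentinel arrives
        if item is None:                # sentinel: leave the loop instead of being terminated
            break
        frame_id, frame = item
        result_queue.put((frame_id, model.predict([frame])[0]))


def run_inference(frames, drop_if_busy=False):
    # "fork" is safe here because the parent never builds the model,
    # so the child inherits no already-initialized model state.
    ctx = mp.get_context("fork")
    frame_queue = ctx.Queue(maxsize=1)  # at most one pending frame, like the answer
    result_queue = ctx.Queue()
    worker = ctx.Process(target=inference_worker, args=(frame_queue, result_queue))
    worker.start()

    submitted = 0
    for frame_id, frame in enumerate(frames):
        if drop_if_busy:
            try:
                frame_queue.put_nowait((frame_id, frame))  # skip the frame if the slot is taken
            except queue.Full:
                continue
        else:
            frame_queue.put((frame_id, frame))             # wait until the worker takes it
        submitted += 1

    frame_queue.put(None)               # the shutdown sentinel is never dropped
    # Drain results before join(): a child with unflushed queue data may block on exit.
    results = dict(result_queue.get() for _ in range(submitted))
    worker.join()
    for q in (frame_queue, result_queue):
        q.close()
        q.join_thread()
    return results


if __name__ == "__main__":
    print(run_inference([[1, 2], [3, 4], [5]]))  # {0: 3, 1: 7, 2: 5}
```

With `drop_if_busy=True` the loop behaves like the webcam loop in the answer: a frame that arrives while the worker still has one pending is skipped. Catching `queue.Full` around `put_nowait()` replaces the `frame_queue.empty()` check, whose result the `multiprocessing` documentation describes as unreliable.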
