张量不是该图的元素;部署Keras模型 [英] Tensor is not an element of this graph; deploying Keras model

查看:59
本文介绍了张量不是该图的元素;部署Keras模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在部署一个keras模型,并通过flask API将测试数据发送到该模型.我有两个文件:

Im deploying a keras model and sending the test data to the model via a flask api. I have two files:

首先:我的Flask应用:

First: My Flask App:

# Let's startup the Flask application
app = Flask(__name__)

# Model reload from jSON:
print('Load model...')
json_file = open('models/model_temp.json', 'r')
loaded_model_json = json_file.read()
json_file.close()
keras_model_loaded = model_from_json(loaded_model_json)
print('Model loaded...')

# Weights reloaded from .h5 inside the model
print('Load weights...')
keras_model_loaded.load_weights("models/Model_temp.h5")
print('Weights loaded...')

# URL that we'll use to make predictions using get and post
@app.route('/predict',methods=['GET','POST'])
def predict():
    data = request.get_json(force=True)
    predict_request = [data["month"],data["day"],data["hour"]] 
    predict_request = np.array(predict_request)
    predict_request = predict_request.reshape(1,-1)
    y_hat = keras_model_loaded.predict(predict_request, batch_size=1, verbose=1)
    return jsonify({'prediction': str(y_hat)}) 

if __name__ == "__main__":
    # Choose the port
    port = int(os.environ.get('PORT', 9000))
    # Run locally
    app.run(host='127.0.0.1', port=port)

第二:文件Im用于将json数据发送到api端点:

Second: The file Im using to send the json data sending to the api endpoint:

response = rq.get('api url has been removed')
data=response.json()
currentDT = datetime.datetime.now()
Month = currentDT.month
Day = currentDT.day
Hour = currentDT.hour

url= "http://127.0.0.1:9000/predict"
post_data = json.dumps({'month': month, 'day': day, 'hour': hour,})
r = rq.post(url,post_data)

我从Flask得到有关Tensorflow的回复:

Im getting this response from Flask regarding Tensorflow:

ValueError:Tensor Tensor("dense_6/BiasAdd:0",shape =(?, 1),dtype = float32)不是此图的元素.

ValueError: Tensor Tensor("dense_6/BiasAdd:0", shape=(?, 1), dtype=float32) is not an element of this graph.

我的keras模型是一个简单的6密层模型,并且训练没有错误.

My keras model is a simple 6 dense layer model and trains with no errors.

有什么想法吗?

推荐答案

Flask使用多个线程.您遇到的问题是因为tensorflow模型未在同一线程中加载和使用.一种解决方法是强制tensorflow使用gloabl默认图.

Flask uses multiple threads. The problem you are running into is because the tensorflow model is not loaded and used in the same thread. One workaround is to force tensorflow to use the gloabl default graph .

在加载模型后添加它

global graph
graph = tf.get_default_graph() 

在您的预测之内

with graph.as_default():
    y_hat = keras_model_loaded.predict(predict_request, batch_size=1, verbose=1)

这篇关于张量不是该图的元素;部署Keras模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆