将Keras模型整合到TensorFlow中 [英] Integrating Keras model into TensorFlow

查看:112
本文介绍了将Keras模型整合到TensorFlow中的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我试图在TensorFlow代码中使用预先训练的Keras模型,如

I am trying to use a pre-trained Keras model within TensorFlow code, as described in this Keras blog post under section II: Using Keras models with TensorFlow.

我想使用Keras中可用的预先训练的VGG16网络从图像中提取卷积特征图,并在其上添加我自己的TensorFlow代码.所以我做到了:

I want to use the pre-trained VGG16 network available in Keras to extract convolutional feature maps from images, and add my own TensorFlow code over that. So I've done this:

import tensorflow as tf
from tensorflow.python.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.python.keras import backend as K

# images = a NumPy array containing 8 images

model = VGG16(include_top=False, weights='imagenet')
inputs = tf.placeholder(shape=images.shape, dtype=tf.float32)
inputs = preprocess_input(inputs)
features = model(inputs)

with tf.Session() as sess:
    K.set_session(sess)
    output = sess.run(features, feed_dict={inputs: images})
    print(output.shape)

但是,这给了我一个错误:

However, this gives me an error:

FailedPreconditionError: Attempting to use uninitialized value block1_conv1_2/kernel
     [[Node: block1_conv1_2/kernel/read = Identity[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](block1_conv1_2/kernel)]]
     [[Node: vgg16_1/block5_pool/MaxPool/_3 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_132_vgg16_1/block5_pool/MaxPool", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

相反,如果我在运行网络之前运行初始化程序op:

Instead, if I run an initializer op before running the network:

with tf.Session() as sess:
    K.set_session(sess)
    tf.global_variables_initializer().run()
    output = sess.run(features, feed_dict={inputs: images})
    print(output.shape)

然后我得到了预期的输出:

Then I get the expected output:

(8, 11, 38, 512)

我的问题是,在运行tf.global_variables_initializer()时,变量是随机初始化还是使用ImageNet权重初始化?我之所以这么问,是因为上面提到的博客文章中没有提到使用预训练的Keras模型时需要运行初始化程序,这确实让我感到不安.

My question is, upon running tf.global_variables_initializer(), have the variables been initialized randomly or with the ImageNet weights? I ask this because the blog post referenced above does not mention that an initializer needs to be run when using pre-trained Keras models, and indeed it makes me feel a bit uneasy.

我怀疑它确实使用了ImageNet权重,并且仅因为TensorFlow需要显式初始化所有变量,才需要运行初始化程序.但这只是一个猜测.

I suspect that it does use the ImageNet weights, and that one needs to run the initializer only because TensorFlow requires all variables to be explicitly initialized. But this is just a guess.

推荐答案

TLDR

使用Keras时,

TLDR

When using Keras,

  1. 尽可能避免使用Session(本着不可知的Keras精神)
  2. 否则,请使用Keras处理的Sessiontf.keras.backend.get_session.
  3. 在程序的早期阶段将Keras的set_session用于高级用途(例如,当需要进行概要分析或放置设备时),这与通常的做法和在纯" Tensorflow中的良好用法背道而驰.
  1. Avoid using Session if you can (in the spirit of agnostic Keras)
  2. Use Keras-handled Session through tf.keras.backend.get_session otherwise.
  3. Use Keras' set_session for advanced uses (e.g. when you need profiling or device placement) and very early in your program — contrary to common practice and good usage in "pure" Tensorflow.

关于此的更多信息

必须先初始化变量,然后才能使用它们.实际上,它比这要微妙得多:必须在会话中使用变量对其进行初始化.让我们看一下这个例子:

More about that

Variables must be initialized before they can be used. Actually, it's a bit more subtle than that: Variables must be initialized in the session they are used. Let's look at this example:

import tensorflow as tf

x = tf.Variable(0.)

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # x is initialized -- no issue here
    x.eval()

with tf.Session() as sess:
    x.eval()
    # Error -- x was never initialized in this session, even though
    # it has been initialized before in another session

因此,model中的变量未初始化也就不足为奇了,因为您是在sess之前创建模型的.

So it shouldn't come as a surprise that variables from your model are not initialized, because you create your model before sess.

但是,VGG16不仅为模型变量(用tf.global_variables_initializer调用的变量)创建初始化操作,而且实际上确实调用了它们.问题是,Session在哪个范围内?

However, VGG16 not only creates initializer operations for the model variables (the ones you are calling with tf.global_variables_initializer), but actually does call them. Question is, within which Session?

好吧,由于构建模型时不存在,因此Keras为您创建了一个默认模型,您可以使用tf.keras.backend.get_session()恢复它.现在可以按预期使用此会话,因为在此会话中初始化了变量:

Well, since none existed at the time you built your model, Keras created a default one for you, that you can recover using tf.keras.backend.get_session(). Using this session now works as expected because variables are initialized in this session:

with tf.keras.backend.get_session() as sess:
    K.set_session(sess)
    output = sess.run(features, feed_dict={inputs: images})
    print(output.shape)

请注意,您还可以创建自己的Session并通过keras.backend.set_session将其提供给Keras-这正是您所做的.但是,如本例所示,Keras和TensorFlow具有不同的心态.

Note that you could also create your own Session and provide it to Keras, through keras.backend.set_session — and this is exactly what you have done. But, as this example shows, Keras and TensorFlow have different mindsets.

一个TensorFlow用户通常会先构造一个图,然后实例化一个Session,也许是在冻结图之后.

A TensorFlow user would typically first construct a graph, then instantiate a Session, perhaps after freezing the graph.

Keras与框架无关,并且在构造阶段之间没有这种固有的区别-特别是,我们在这里了解到Keras可能很好地实例化了会话图构造期间.

Keras is framework-agnostic and does not have this built-in distinction between construction phases — in particular, we learned here that Keras may very well instantiate a Session during graph construction.

因此,在使用Keras时,如果您需要处理需要tf.Session的TensorFlow特定代码,我建议您不要自己管理tf.Session,而应该依靠tf.keras.backend.get_session.

For this reason, when using Keras, I would advise against managing a tf.Session yourself and instead rely on tf.keras.backend.get_session if you need to handle TensorFlow specific code that requires a tf.Session.

这篇关于将Keras模型整合到TensorFlow中的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆