Tensorflow 中的 8 位量化错误 [英] Error with 8-bit Quantization in Tensorflow

查看:47
本文介绍了Tensorflow 中的 8 位量化错误的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我一直在试验新的 TensorFlow 中可用的 8 位量化功能.我可以在没有任何问题的情况下运行博客文章中给出的示例(googlenet 的量化),它对我来说很好用!!!

I have been experimenting with the new 8-bit quantization feature available in TensorFlow. I could run the example given in the blog post (quantization of googlenet) without any issue and it works fine for me !!!

现在,我想将相同的内容应用于更简单的网络.所以我使用了 CIFAR-10 的预训练网络(在 Caffe 上训练),提取其参数,在 tensorflow 中创建相应的图,用这个预训练的权重初始化权重,最后将其保存为 GraphDef 对象.请参阅此 IPython Notebook 了解完整过程.

Now, I would like to apply the same for a simpler network. So I used a pre-trained network for CIFAR-10 (which is trained on Caffe), extracted its parameters, created corresponding graph in tensorflow, initialized the weights with this pre-trained weights and finally saved it as a GraphDef object. See this IPython Notebook for full procedure.

现在我使用 tensorflow 脚本应用了 8 位量化,如 Pete Warden 的博客中所述:

Now I applied the 8-bit quantization with the tensorflow script as mentioned in the Pete Warden's blog:

bazel-bin/tensorflow/contrib/quantization/tools/quantize_graph --input=cifar.pb  --output=qcifar.pb --mode=eightbit --bitdepth=8 --output_node_names="ArgMax"

现在我想在这个量化网络上运行分类.所以我将新的 qcifar.pb 加载到 tensorflow 会话并传递图像(与我将其传递给原始版本的方式相同).完整代码可在此 IPython Notebook 中找到.

Now I wanted to run the classification on this quantized network. So I loaded the new qcifar.pb to a tensorflow session and passed the image (the same way I passed it to original version). Full code can be found in this IPython Notebook.

但正如你在最后看到的,我收到以下错误:

But as you can see at the end, I am getting following error:

NotFoundError:操作类型未注册QuantizeV2"

有人可以建议我在这里缺少什么吗?

Can anybody suggest what am I missing here?

推荐答案

由于量化操作和内核在 contrib 中,您需要在 Python 脚本中显式加载它们.在 quantize_graph 中有一个例子.py 脚本本身:

Because the quantized ops and kernels are in contrib, you'll need to explicitly load them in your python script. There's an example of that in the quantize_graph.py script itself:

from tensorflow.contrib.quantization import load_quantized_ops_so从 tensorflow.contrib.quantization.kernels 导入 load_quantized_kernels_so

这是我们应该更新文档以提及的内容!

This is something that we should update the documentation to mention!

这篇关于Tensorflow 中的 8 位量化错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆