如何使用 c++ api 在 tflite 中获取权重? [英] How to get weights in tflite using c++ api?

查看:36
本文介绍了如何使用 c++ api 在 tflite 中获取权重?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在设备上使用 .tflite 模型.最后一层是 ConditionalRandomField 层,我需要该层的权重来进行预测.如何使用 c++ api 获取权重?

I am using a .tflite model on device. The last layer is ConditionalRandomField layer, and I need weights of the layer to do prediction. How do I get weights with c++ api?

相关:如何查看 .tflite 中的权重文件?

Netron 或 flatc 不能满足我的需求.设备太重.

Netron or flatc doesn't meet my needs. too heavy on device.

似乎 TfLiteNode 将权重存储在 void* user_data 或 void* builtin_data 中.我如何阅读它们?

It seems TfLiteNode stores weights in void* user_data or void* builtin_data. How do I read them?

更新:

结论:.tflite 在 .h5 剂量时不存储 CRF 权重.(也许是因为它们不影响输出.)

Conclusion: .tflite doesn't store CRF weights while .h5 dose. (Maybe because they do not affect output.)

我做什么:

// obtain from model.
Interpreter *interpreter;
// get the last index of nodes.
// I'm not sure if the index sequence of nodes is the direction which tensors or layers flows.
const TfLiteNode *node = &((interpreter->node_and_registration(interpreter->nodes_size()-1))->first);

// then follow the answer of @yyoon

推荐答案

在 TFLite 节点中,权重应该存储在 inputs 数组中,该数组包含相应 TfLiteTensor 的索引*.

In a TFLite node, the weights should be stored in the inputs array, which contains the index of the corresponding TfLiteTensor*.

所以,如果你已经获得了最后一层的TfLiteNode*,你可以做这样的事情来读取权重值.

So, if you have already obtained the TfLiteNode* of the last layer, you could do something like this to read the weight values.

TfLiteContext* context; // You would usually have access to this already.
TfLiteNode* node;       // <obtain this from the graph>;

for (int i = 0; i < node->inputs->size; ++i) {
  TfLiteTensor* input_tensor = GetInput(context, node, i);

  // Determine if this is a weight tensor.
  // Usually the weights will be memory-mapped read-only tensor
  // directly baked in the TFLite model (flatbuffer).
  if (input_tensor->allocation_type == kTfLiteMmapRo) {
    // Read the values from input_tensor, based on its type.
    // For example, if you have float weights,
    const float* weights = GetTensorData<float>(input_tensor);

    // <read the weight values...>
  }
}

这篇关于如何使用 c++ api 在 tflite 中获取权重?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆