Cannot convert generator to list from estimator.predict


Problem description


While using TensorFlow Ranking, we also use another model (lattice). I am using a single query in my test data, but I get this error when I try to convert the generator returned by the estimator.predict function into a list. I feed 25 examples (documents) for that one query.

My code looks like this:

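# load_libsvm_data and get_pred_inputs are user-defined helpers (not shown)
features_test, labels_test = load_libsvm_data(FLAGS.dataset_base_path + '/test-sample-' + FLAGS.locale + '.txt',
                                              FLAGS.num_total_results)

predict_fn, predict_hook = get_pred_inputs(features_test)
# estimator.predict returns a generator; list() iterates it to collect every prediction
generator_ = estimator.predict(input_fn=predict_fn, hooks=[predict_hook])
predictions_list = list(generator_)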
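Note that `estimator.predict` returns a Python generator, so calling `predict` does not run the model function yet; the graph is only built, and any error in it only raised, once the generator is consumed by `list(generator_)` or `next(generator_)`. This matches the traceback below, where the failure surfaces inside `predict` while `next(generator_)` is running. The pure-Python sketch uses a hypothetical `fake_predict` stand-in (not a TensorFlow API) to show this lazy behaviour.

```python
import types


def fake_predict():
    """Hypothetical stand-in for a generator-based predict().

    Because the body contains `yield`, calling this function only creates a
    generator object; none of the body runs until the generator is iterated.
    """
    # Simulates the model function failing while the graph is being built.
    raise ValueError("Cannot reshape a tensor with 25 elements to shape [1,1]")
    yield  # unreachable, but its presence makes this a generator function


gen = fake_predict()                          # no exception here
print(isinstance(gen, types.GeneratorType))   # True

try:
    predictions = list(gen)                   # the body runs now and raises
except ValueError as err:
    print("raised while consuming:", err)
```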


FLAGS.num_total_results is 25, but I don't know how to get the results.


The error I get is:

File "ranking/web/python/tf_modeling/skipbinary_calibration2.py", line 567, in train_and_eval
    print(next(generator_))
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 613, in predict
    self.config)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1170, in _call_model_fn
    model_fn_results = self._model_fn(features=features, **kwargs)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/model.py", line 446, in _model_fn
    config)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/model.py", line 119, in compute_logits
    features, mode, params)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/model.py", line 87, in _call_transform_fn
    return self._transform_fn(features, mode=mode)
  File "ranking/web/python/tf_modeling/skipbinary_calibration2.py", line 490, in _transform_fn
    scope="transform_layer"))
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/feature.py", line 233, in encode_pointwise_features
    features, example_feature_columns.values(), mode=mode, scope=scope)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/feature.py", line 98, in encode_features
    dense_layer(features, cols_to_output_tensors=cols_to_tensors)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 778, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 262, in wrapper
    return converted_call(f, args, kwargs, options=options)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 418, in converted_call
    return _call_unconverted(f, args, kwargs, options, False)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 346, in _call_unconverted
    return f(*args, **kwargs)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/feature_column/dense_features.py", line 146, in call
    processed_tensors = self._process_dense_tensor(column, tensor)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column_v2.py", line 445, in _process_dense_tensor
    return array_ops.reshape(tensor, shape=target_shape)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py", line 193, in reshape
    result = gen_array_ops.reshape(tensor, shape, name)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 8087, in reshape
    "Reshape", tensor=tensor, shape=shape, name=name)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 744, in _apply_op_helper
    attrs=attr_protos, op_def=op_def)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3327, in _create_op_internal
    op_def=op_def)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1817, in __init__
    control_input_ops, op_def)
  File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1657, in _create_c_op
    raise ValueError(str(e))
ValueError: Cannot reshape a tensor with 25 elements to shape [1,1] (1 elements) for '{{node transform/encoding_layer/1/Reshape}} = Reshape[T=DT_FLOAT, Tshape=DT_INT32](IteratorGetNext, transform/encoding_layer/1/Reshape/shape)' with input shapes: [1,25,1], [2] and with input tensors computed as partial shapes: input[1] = [1,1].
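The message is a shape check: the feature tensor has shape [1, 25, 1] (one query, 25 documents, one value each), i.e. 25 elements, but the feature-column layer tries to reshape it to [1, 1], which holds only one element. One possible reading, offered as an assumption rather than a confirmed diagnosis, is that listwise features of shape [batch, list_size, ...] reached an encoder that expects one row per example (the traceback passes through `encode_pointwise_features`); merging the first two dimensions into [batch * list_size, ...] gives a shape with a matching element count. The sketch reproduces only the shape arithmetic in plain Python; `can_reshape` and `flatten_list_dim` are hypothetical helpers, not TensorFlow functions.

```python
from math import prod


def can_reshape(src_shape, dst_shape):
    """Reshape only reorganises elements, so both shapes must hold the same count."""
    return prod(src_shape) == prod(dst_shape)


def flatten_list_dim(shape):
    """Hypothetical helper: merge [batch, list_size, ...] into [batch * list_size, ...]."""
    batch, list_size, *rest = shape
    return [batch * list_size, *rest]


listwise = [1, 25, 1]  # input shape from the error: 1 query x 25 docs x 1 value
target = [1, 1]        # target shape from the error

print(prod(listwise), prod(target))       # 25 1
print(can_reshape(listwise, target))      # False

flat = flatten_list_dim(listwise)
print(flat, can_reshape(listwise, flat))  # [25, 1] True
```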

Recommended answer

predictions_list = [x for x in generator_]
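For reference, a list comprehension drains a generator the same way `list()` does: both keep calling `next()` until the generator is exhausted, so an exception raised inside the generator body propagates from either form. The plain-Python sketch below, using a hypothetical `failing_gen`, shows the two consumers behaving identically.

```python
def failing_gen():
    """Hypothetical generator that yields two items and then fails."""
    yield 1
    yield 2
    raise ValueError("error raised inside the generator body")


def collect(consume):
    """Run a consumer over a fresh generator; return the exception name, if any."""
    try:
        consume(failing_gen())
    except ValueError as err:
        return type(err).__name__
    return None


print(collect(list))                      # ValueError
print(collect(lambda g: [x for x in g]))  # ValueError

# For a generator that does not fail, both forms produce the same list.
squares = lambda: (i * i for i in range(3))
print(list(squares()) == [x for x in squares()])  # True
```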
