Cannot convert generator to list from estimator.predict
Problem description
I am using TensorFlow Ranking together with another model (lattice). My test data contains a single query, for which I feed 25 examples (documents). When I take the generator returned by estimator.predict and try to convert it to a list, I get the error shown further down.
My code looks like this:
features_test, labels_test = load_libsvm_data(
    FLAGS.dataset_base_path + '/test-sample-' + FLAGS.locale + '.txt',
    FLAGS.num_total_results)
predict_fn, predict_hook = get_pred_inputs(features_test)
# estimator.predict returns a generator; nothing is computed until it is iterated.
generator_ = estimator.predict(input_fn=predict_fn, hooks=[predict_hook])
predictions_list = list(generator_)
FLAGS.num_total_results is 25, but I don't know how to get the results.
The error I get is this:
File "ranking/web/python/tf_modeling/skipbinary_calibration2.py", line 567, in train_and_eval
print(next(generator_))
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 613, in predict
self.config)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/estimator.py", line 1170, in _call_model_fn
model_fn_results = self._model_fn(features=features, **kwargs)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/model.py", line 446, in _model_fn
config)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/model.py", line 119, in compute_logits
features, mode, params)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/model.py", line 87, in _call_transform_fn
return self._transform_fn(features, mode=mode)
File "ranking/web/python/tf_modeling/skipbinary_calibration2.py", line 490, in _transform_fn
scope="transform_layer"))
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/feature.py", line 233, in encode_pointwise_features
features, example_feature_columns.values(), mode=mode, scope=scope)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow_ranking/python/feature.py", line 98, in encode_features
dense_layer(features, cols_to_output_tensors=cols_to_tensors)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 778, in __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 262, in wrapper
return converted_call(f, args, kwargs, options=options)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 418, in converted_call
return _call_unconverted(f, args, kwargs, options, False)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py", line 346, in _call_unconverted
return f(*args, **kwargs)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/feature_column/dense_features.py", line 146, in call
processed_tensors = self._process_dense_tensor(column, tensor)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/feature_column/feature_column_v2.py", line 445, in _process_dense_tensor
return array_ops.reshape(tensor, shape=target_shape)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/ops/array_ops.py", line 193, in reshape
result = gen_array_ops.reshape(tensor, shape, name)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 8087, in reshape
"Reshape", tensor=tensor, shape=shape, name=name)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 744, in _apply_op_helper
attrs=attr_protos, op_def=op_def)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3327, in _create_op_internal
op_def=op_def)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1817, in __init__
control_input_ops, op_def)
File "/Users/vivekkaul/.pyenv/versions/tf2_lattice/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1657, in _create_c_op
raise ValueError(str(e))
ValueError: Cannot reshape a tensor with 25 elements to shape [1,1] (1 elements) for '{{node transform/encoding_layer/1/Reshape}} = Reshape[T=DT_FLOAT, Tshape=DT_INT32](IteratorGetNext, transform/encoding_layer/1/Reshape/shape)' with input shapes: [1,25,1], [2] and with input tensors computed as partial shapes: input[1] = [1,1].
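A note on reading the final ValueError: the message says that a tensor with input shape [1,25,1] (25 elements, which lines up with the 25 documents fed for the single query) is being reshaped to [1,1] (1 element). A reshape can only succeed when both shapes hold the same number of elements. The sketch below checks that arithmetic in plain Python. The helper names `num_elements` and `can_reshape` are hypothetical and only illustrate the rule; they are not part of TensorFlow. Whether the right fix is to change how the documents are batched or to change the transform function depends on model code the question does not show.

```python
def num_elements(shape):
    """Multiply the dimensions of a shape to get its element count."""
    count = 1
    for dim in shape:
        count *= dim
    return count


def can_reshape(from_shape, to_shape):
    """A reshape is only valid when both shapes hold the same number of elements."""
    return num_elements(from_shape) == num_elements(to_shape)


# Shapes copied from the error message.
input_shape = [1, 25, 1]   # presumably: 1 query, 25 documents, 1 value each
target_shape = [1, 1]      # the shape the feature column tried to produce

print(num_elements(input_shape), num_elements(target_shape))
print(can_reshape(input_shape, target_shape))   # the failing case from the traceback
print(can_reshape(input_shape, [25, 1]))        # a shape with a matching element count
```

The check only restates what the error reports: the feature column expects one value per row, while the input carries 25 values in a single row.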
Recommended answer
predictions_list = [x for x in generator_]
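For context, a list comprehension and list() consume a generator in the same way, and a generator does no work until its first item is requested. That is consistent with the traceback in the question, where the ValueError surfaces at next(generator_) rather than at the estimator.predict call. A minimal sketch in plain Python, where `fake_predict` is a hypothetical stand-in for estimator.predict and not the TensorFlow API:

```python
def fake_predict(fail):
    """Hypothetical stand-in for estimator.predict: a lazy generator."""
    if fail:
        # Mimics an error raised while the model graph is being built.
        raise ValueError("Cannot reshape a tensor with 25 elements to shape [1,1]")
    for score in (0.3, 0.1, 0.6):
        yield {"score": score}


# Both forms drain the generator and produce the same list.
as_list = list(fake_predict(fail=False))
as_comprehension = [x for x in fake_predict(fail=False)]
print(as_list == as_comprehension)

# Creating the generator does not raise; the error only appears on iteration.
generator_ = fake_predict(fail=True)
caught = None
try:
    predictions_list = [x for x in generator_]
except ValueError as err:
    caught = str(err)
    print("raised during iteration:", caught)
```

If the same error appears with either form, the shape mismatch in the input features is the more likely place to look.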