How to convert PerReplica to a tensor?

Question

When training on multiple GPUs in TensorFlow 2.0, per-replica values can be reduced with the following code:

strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
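For reference, here is a minimal runnable sketch of what that SUM reduction does. It assumes a TensorFlow 2.x release that has strategy.run (the TF 2.2+ name for experimental_run_v2) and tf.config.set_logical_device_configuration, and it splits the CPU into two logical devices so MirroredStrategy has two replicas without any GPUs. The per-replica "loss" values are made up for illustration.

```python
import tensorflow as tf

# Split the physical CPU into two logical CPUs so MirroredStrategy gets two
# replicas on a GPU-less machine. This must run before TensorFlow initializes.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu,
    [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()],
)

strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])


def replica_loss():
    # Hypothetical per-replica loss: replica 0 returns 1.0, replica 1 returns 2.0.
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    return tf.cast(replica_id, tf.float32) + 1.0


@tf.function
def summed_loss():
    per_replica_losses = strategy.run(replica_loss)
    # Collapse the PerReplica value into one tensor by summing across replicas.
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)


print(summed_loss().numpy())  # 3.0
```

Swapping ReduceOp.SUM for ReduceOp.MEAN would give 1.5 here. Either way the per-replica values are combined, which is exactly what the question below wants to avoid.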

However, what if I just want to collect all GPUs' predictions into a tensor, without a "sum" or "mean" reduction?

per_replica_losses, per_replica_predicitions = strategy.experimental_run_v2(train_step, args=(dataset_inputs,))
# How do I convert per_replica_predicitions to a tensor?

Answer

In short, you can convert a PerReplica result into a tuple of tensors like this:

tensors_tuple = per_replica_predicitions.values

The returned tensors_tuple will be a tuple of predictions, one per replica/device:

(prediction_tensor_from_dev0, prediction_tensor_from_dev1, ...)
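If you need a single tensor rather than a tuple, one option is to concatenate the per-device tensors along the batch axis. This is a minimal sketch assuming each replica's predictions share every dimension except the first (batch) one; the constant tensors are hypothetical stand-ins for the elements of per_replica_predicitions.values.

```python
import tensorflow as tf

# Hypothetical stand-ins for the tuple from per_replica_predicitions.values:
# one (batch_per_replica, num_classes) tensor per device.
prediction_from_dev0 = tf.constant([[0.1, 0.9], [0.8, 0.2]])
prediction_from_dev1 = tf.constant([[0.3, 0.7], [0.6, 0.4]])
tensors_tuple = (prediction_from_dev0, prediction_from_dev1)

# Join the per-replica batches back into one global batch along axis 0.
all_predictions = tf.concat(tensors_tuple, axis=0)
print(all_predictions.shape)  # (4, 2)
```

tf.concat keeps the device order of the tuple, so rows from device 0 come first. If the per-replica shapes could differ beyond the batch dimension, concatenation would fail and you would have to keep the tuple instead.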

The number of elements in this tuple is determined by the devices available to the distribution strategy. In particular, if the strategy runs on a single replica/device, the return value of strategy.experimental_run_v2 is the same as calling train_step directly (a tensor or a list of tensors, depending on what your train_step returns), so there is no PerReplica wrapper to unpack. (In TensorFlow 2.2 and later, experimental_run_v2 is renamed strategy.run.) So you might want to write the code like this:

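per_replica_losses, per_replica_predicitions = strategy.experimental_run_v2(train_step, args=(dataset_inputs,))

if strategy.num_replicas_in_sync > 1:
    # Several replicas: unwrap the PerReplica object into a tuple of per-device tensors.
    prediction_tensors = per_replica_predicitions.values
else:
    # Single replica: experimental_run_v2 already returned train_step's plain output.
    prediction_tensors = per_replica_predicitions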
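An alternative that avoids the branch is strategy.experimental_local_results, which returns a tuple with one tensor per local replica, and a 1-tuple when the value is not distributed. The sketch below assumes a TensorFlow 2.x release with strategy.run and tf.config.set_logical_device_configuration, uses two logical CPUs to get two replicas, and uses a hypothetical predict_step that just fills a tensor with the replica id.

```python
import tensorflow as tf

# Two logical CPUs give MirroredStrategy two replicas without GPUs.
# This must run before TensorFlow initializes its devices.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu,
    [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()],
)
strategy = tf.distribute.MirroredStrategy(["/cpu:0", "/cpu:1"])


def predict_step():
    # Hypothetical per-replica predictions: a length-2 tensor filled with the replica id.
    replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
    return tf.fill([2], tf.cast(replica_id, tf.float32))


@tf.function
def gathered_predictions():
    per_replica_predictions = strategy.run(predict_step)
    # Tuple with one tensor per local replica (a 1-tuple if there is only one replica).
    local_results = strategy.experimental_local_results(per_replica_predictions)
    # Concatenate along the batch axis to get a single tensor.
    return tf.concat(local_results, axis=0)


print(gathered_predictions().numpy())  # [0. 0. 1. 1.]
```

Because experimental_local_results already handles the single-replica case, the same code path works whether num_replicas_in_sync is 1 or greater.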

PerReplica is a class that wraps the results of a distributed run. You can find its definition here; it has more properties and methods for working with PerReplica objects.