Numba 可以与 Tensorflow 一起使用吗? [英] Can Numba be used with Tensorflow?

查看:122
本文介绍了Numba 可以与 Tensorflow 一起使用吗?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

Numba 可以用来编译与 Tensorflow 接口的 Python 代码吗?IE.Tensorflow 宇宙之外的计算将使用 Numba 运行以提高速度.我还没有找到有关如何执行此操作的任何资源.

解决方案

您可以使用 tf.numpy_functiontf.py_func 来包装 python函数并将其用作 TensorFlow 操作.这是我使用的示例:

@jitdef dice_coeff_nb(y_true, y_pred):计算骰子系数"平滑 = np.float32(1)y_true_f = np.reshape(y_true, [-1])y_pred_f = np.reshape(y_pred, [-1])交集 = np.sum(y_true_f * y_pred_f)分数 = (2. * 交点 + 平滑)/(np.sum(y_true_f) +np.sum(y_pred_f) + 平滑)返回分数@jitdef dice_loss_nb(y_true, y_pred):计算骰子损失"损失 = 1 - dice_coeff_nb(y_true, y_pred)回波损耗def bce_dice_loss_nb(y_true, y_pred):将 dice_loss 添加到 categorical_crossentropy"损失 = tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64) + \tf.keras.losses.categorical_crossentropy(y_true, y_pred)回波损耗

然后我在训练一个 tf.keras 模型时使用了这个损失函数:

<预><代码>...模型 = tf.keras.models.Model(输入=输入,输出=输出)模型编译(优化器=亚当",损失=bce_dice_loss_nb)

Can Numba be used to compile Python code which interfaces with Tensorflow? I.e. computations outside of the Tensorflow universe would run with Numba for speed. I have not found any resources on how to do this.

解决方案

You can use tf.numpy_function, or tf.py_func to wrap a python function and use it as a TensorFlow op. Here is an example which I used:

@jit
def dice_coeff_nb(y_true, y_pred):
    "Calculates dice coefficient"
    smooth = np.float32(1)
    y_true_f = np.reshape(y_true, [-1])
    y_pred_f = np.reshape(y_pred, [-1])
    intersection = np.sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (np.sum(y_true_f) +
                                            np.sum(y_pred_f) + smooth)
    return score

@jit
def dice_loss_nb(y_true, y_pred):
    "Calculates dice loss"
    loss = 1 - dice_coeff_nb(y_true, y_pred)
    return loss

def bce_dice_loss_nb(y_true, y_pred):
    "Adds dice_loss to categorical_crossentropy"
    loss =  tf.numpy_function(dice_loss_nb, [y_true, y_pred], tf.float64) + \
            tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return loss

Then I used this loss function in training a tf.keras model:

...
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss=bce_dice_loss_nb)

这篇关于Numba 可以与 Tensorflow 一起使用吗?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆