张量流keras模型的动态历元数 [英] Dynamic number of epochs with a tensorflow keras model

查看:77
本文介绍了张量流keras模型的动态历元数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我希望有一个神经网络进行训练,直到达到一定程度的准确性为止.是否有内置功能可以使用,而不是单独运行每个时期直到达到精度?

I want to have a neural net that trains until a certain level of accuracy has been reached. Is there a built in function to use instead of running each epoch individually until the accuracy has been reached?

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(128, activation=tf.nn.relu),
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

model.compile(optimizer=tf.train.AdamOptimizer(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

epochs = 0
train_acc = 0
while 1-train_acc > .01:
    model.fit(train_images, train_labels,  initial_epoch=epochs, epochs=epochs+1,verbose=0)
    epochs += 1
    train_loss, train_acc = model.evaluate(train_images,train_labels)

推荐答案

不,没有任何内置函数可以执行此操作.但是,您可以轻松定义一个自定义回调,一旦训练精度达到一定阈值,该回调便停止训练:

No, there isn't any built in function to do this. However, you can easily define a custom callback that stops training once the training accuracy reaches a certain threshold:

import keras


class AccuracyStopping(keras.callbacks.Callback):
    def __init__(self, acc_threshold):
        super(AccuracyStopping, self).__init__()
        self._acc_threshold = acc_threshold

    def on_epoch_end(self, batch, logs={}):
        train_acc = logs.get('acc')
        self.model.stop_training = 1 - train_acc <= self._acc_threshold

这是一个简单的示例,展示了如何使用它:

Here's a simple example showing how to use it:

import numpy as np
from keras.layers import Dense
from keras.models import Sequential

x = np.random.normal(size=(100,))
y = x > 0

model = Sequential()
model.add(Dense(1, input_dim=1, activation='sigmoid'))
model.compile('sgd', 'binary_crossentropy', metrics=['accuracy'])

acc_callback = AccuracyStopping(0.05)
model.fit(x, y, batch_size=8, epochs=1000, callbacks=[acc_callback])

这篇关于张量流keras模型的动态历元数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆