Forecast_generator和类标签 [英] predict_generator and class labels

查看:88
本文介绍了Forecast_generator和类标签的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用ImageDataGenerator生成新的增强图像并从预先训练的模型中提取瓶颈特征,但是我在keras上看到的大多数教程 样本数量与目录中的图像数量相同.

I am using ImageDataGenerator to generate new augmented images and extract bottleneck features from pretrained model but most of the tutorial I see on keras samples same no of training samples as number of images in directory.

 train_generator  = train_datagen.flow_from_directory(
                train_path,
                target_size=image_size,
                shuffle = "false",
                class_mode='categorical',
                batch_size=1)

bottleneck_features_train = model.predict_generator(
    train_generator, 2* nb_train_samples // batch_size)

假设我想从上述代码中获得2倍的图像,如何获得从瓶颈层提取的要素所需的类标签,这些标签存储在元组 train_generator 中.

Suppose I want 2 times more images from the above code, how I can get the desired class labels for the features extracted from bottleneck layer which are stored in tuple train_generator.

不应在 training_generator.py 在422行

x, _ = generator_output 

做这样的事情

 => x, y = generator_output

并从predict_generator返回元组[np.concatenate(out) for out in all_outs],y

and return tuple [np.concatenate(out) for out in all_outs],y from predict_generator

即返回相应的类标签以及预测的特征 all_outs ,因为如果不运行两次生成器就无法获得相应的标签.

i.e return the corresponding class labels along with the predicted features all_outs since there is no way to get the corresponding labels without running generator twice.

推荐答案

如果您使用的是预测,通常您根本就不需要Y,因为Y将是预测的结果. (您不需要训练,因此不需要真实的标签)

If you're using predict, normally you simply don't want Y, because Y will be the result of the prediction. (You're not training, so you don't need the true labels)

但是您可以自己做:

bottleneck = []
labels = []
for i in range(2 * nb_train_samples // batch_size):
    x, y = next(train_generator)

    bottleneck.append(model.predict(x))
    labels.append(y) 

bottleneck = np.concatenate(bottleneck)
labels = np.concatenate(labels)

如果您希望通过索引编制索引(如果您的生成器支持的话):

If you want it with indexing (if your generator supports that):

#...
for epoch in range(2):
    for i in range(nb_train_samples // batch_size):
        x,y = train_generator[i]

        #...

这篇关于Forecast_generator和类标签的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆