从2个自动编码器中提取特征并将其馈入MLP [英] Extract features from 2 auto-encoders and feed them into an MLP

查看:85
本文介绍了从2个自动编码器中提取特征并将其馈入MLP的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我了解从自动编码器中提取的功能可以输入到mlp中以进行分类或回归.这是我之前所做的.
但是,如果我有2个自动编码器怎么办?是否可以从2个自动编码器的瓶颈层中提取特征并将其馈入基于这些特征执行分类的mlp中?如果是,那怎么办?我不确定如何串联这两个功能集.我尝试使用numpy.hstack()给出了无法散列的切片"错误,而使用tf.concat()则给出了错误模型的输入张量必须是Keras张量".两个自动编码器的瓶颈层各自的尺寸为(None,100).因此,从本质上讲,如果我水平堆叠它们,我应该得到一个(无,200). mlp的隐藏层可能包含一些(num_hidden = 100)神经元.谁能帮忙吗?

I understand that the features extracted from an auto-encoder can be fed into an mlp for classification or regression purpose. This is something that I did earlier.
But what if I have 2 auto-encoders? Can I extract the features from the bottleneck layers of 2 auto-encoders and feed them into an mlp which performs classification based on these features? If yes, then how? I am not sure how to concatenate these two feature sets. I tried with numpy.hstack() which gives me 'unhashable slice' error, whereas, using tf.concat() gives me the error 'Input tensors to a Model must be Keras tensors.' the bottleneck layers of the two auto-encoders are of dimension (None,100) each. So, essentially, if I stack them horizontally, I should be getting a (None, 200). The hidden layer of the mlp may contain some (num_hidden=100) neurons. Could anyone please help?

x1 = autoencoder1.get_layer('encoder2').output
x2 = autoencoder2.get_layer('encoder2').output

#inp = np.hstack((x1, x2))
inp = tf.concat([x1, x2], 1)
x = tf.concat([x1, x2], 1)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
y = Dense(1, activation='sigmoid', name='prediction')(h)
mymlp = Model(inputs=inp, outputs=y)

# Compile model
mymlp.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train model
mymlp.fit(x_train, y_train, epochs=20, batch_size=8)

根据@twolffpiggott的建议进行了更新:

updated as per @twolffpiggott's suggestion:

from keras.layers import Input, Dense, Dropout
from keras import layers
from keras.models import Model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import numpy as np

x1 = Data1
x2 = Data2
y = Data3

num_neurons1 = x1.shape[1]
num_neurons2 = x2.shape[1]

# Train-test split
x1_train, x1_test, x2_train, x2_test, y_train, y_test = train_test_split(x1, x2, y, test_size=0.2)

# scale data within [0-1] range
scalar = MinMaxScaler()
x1_train = scalar.fit_transform(x1_train)
x1_test = scalar.transform(x1_test)

x2_train = scalar.fit_transform(x2_train)
x2_test = scalar.transform(x2_test)

x_train = np.concatenate([x1_train, x2_train], axis =-1)
x_test = np.concatenate([x1_test, x2_test], axis =-1)

# Auto-encoder1

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons1,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded1 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded1)
decoded = Dense(num_neurons1, activation='sigmoid', name='decoder2')(decoded)

# this model maps an input to its reconstruction
autoencoder1 = Model(inputs=input_data, outputs=decoded)
autoencoder1.compile(optimizer='sgd', loss='mse')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    

# training
autoencoder1.fit(x1_train, x1_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x1_test, x1_test))

# Auto-encoder2

encoding_dim1 = 500
encoding_dim2 = 100

input_data = Input(shape=(num_neurons2,))
encoded = Dense(encoding_dim1, activation='relu', name='encoder1')(input_data)
encoded2 = Dense(encoding_dim2, activation='relu', name='encoder2')(encoded)
decoded = Dense(encoding_dim2, activation='relu', name='decoder1')(encoded2)
decoded = Dense(num_neurons2, activation='sigmoid', name='decoder2')(decoded)


# this model maps an input to its reconstruction
autoencoder2 = Model(inputs=input_data, outputs=decoded)
autoencoder2.compile(optimizer='sgd', loss='mse')

# training
autoencoder2.fit(x2_train, x2_train,
                    epochs=100,
                    batch_size=8,
                    shuffle=True,
                    validation_data=(x2_test, x2_test))

# MLP

num_hidden = 100

encoded1.trainable = False
encoded2.trainable = False

encoded1 = autoencoder1(autoencoder1.inputs)
encoded2 = autoencoder2(autoencoder2.inputs)

concatenated = layers.concatenate([encoded1, encoded2], axis=-1)
x = Dropout(0.2)(concatenated)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

# Compile model
myMLP.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Training
myMLP.fit(x_train, y_train, epochs=200, batch_size=8)

# Testing
myMLP.predict(x_test)

给我一​​个错误:不可散列的类型:行中的列表": myMLP = Model(inputs = [autoencoder1.inputs,autoencoder2.inputs],outputs = y)

giving me an error: unhashable type: 'list' from the line: myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

推荐答案

问题是您正在将numpy数组与keras张量混合.不能走

The problem is that you're mixing numpy arrays with keras tensors. This can't go.

有两种方法.

  • 从每个自动编码器中预测numpy数组,并合并数组,然后将它们发送给第三个模型
  • 连接所有模型,可能会使自动编码器变得不易训练,并且每个自动编码器只有一个输入.

我个人是第一个. (假设自动编码器已经过训练,不需要更改.)

Personally, I'd go for the first. (Assuming the autoencoders are already trained and don't need change).

numpyOutputFromAuto1 = autoencoder1.predict(numpyInputs1)    
numpyOutputFromAuto2 = autoencoder2.predict(numpyInputs2)

inputDataForThird = np.concatenate([numpyOutputFromAuto1,numpyOutputFromAuto2],axis=-1)

inputTensorForMlp = Input(inputsForThird.shape[1:])
h = Dense(num_hidden, activation='relu', name='hidden')(inputTensorForMlp)
y = Dense(1, activation='sigmoid', name='prediction')(h)

mymlp = Model(inputs=inputTensorForMlp, outputs=y)

....
mymlp.fit(inputDataForThird ,someY)

第二种方法

这有点复杂,起初我没有太多理由这样做. (但是当然,在某些情况下,这是一个不错的选择)

Second Approach

This is a little more complicated, and at first I don't see much reason to do this. (But of course there may be cases where it's a good choice)

现在,我们完全忘记了numpy并使用keras张量.

Now we're totally forgetting numpy and working with keras tensors.

自行创建mlp(如果以后在没有自动编码器的情况下使用它,则很好):

Creating the mlp on its own (good if you will use it later without the autoencoders):

inputTensorForMlp = Input(input_shape_compatible_with_concatenated_encoder_outputs)
x = Dropout(0.2)(inputTensorForMlp)
h = Dense(num_hidden, activation='relu', name='hidden')(x)
h = Dropout(0.5)(h)
y = Dense(1, activation='sigmoid', name='prediction')(h)
myMLP = Model(inputs=[autoencoder1.inputs, autoencoder2.inputs], outputs=y)

我们可能想要自动编码器的瓶颈功能,对吗?如果碰巧使用以下方式正确创建自动编码器:编码器模型,解码器模型,将两者结合在一起,则仅使用编码器模型会更容易.其他:

We probably want the bottleneck features of the autoencoders, right? If you happened to create the autoencoders properly with: encoder model, decoder model, join both, then it's easier to use just the encoder model. Else:

encodedOutput1 = autoencoder1.layers[bottleneckLayer].outputs #or encoder1.outputs
encodedOutput2 = autoencoder1.layers[bottleneckLayer].outputs #or encoder2.outputs

创建一个联合模型.串联必须使用keras层(我们正在使用keras张量):

Creating a joined model. The concatenation must use a keras layer (we're working with keras tensors):

concatenated = Concatenate()([encodedOutput1,encodedOutput2])
output = myMLP(concatenated)

joinedModel = Model([autoencoder1.input,autoencoder2.input],output)

这篇关于从2个自动编码器中提取特征并将其馈入MLP的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆