How to implement network using Bert as a paragraph encoder in long text classification, in keras?
Problem Description
I am doing a long text classification task on documents of more than 10,000 words. I plan to use BERT as a paragraph encoder, then feed the paragraph embeddings into a BiLSTM step by step. The network is as follows:

Input: (batch_size, max_paragraph_len, max_tokens_per_para, embedding_size)

BERT layer: (max_paragraph_len, paragraph_embedding_size)

LSTM layer: ???

Output layer: (batch_size, classification_size)
How can I implement this with Keras? I am using keras-bert's load_trained_model_from_checkpoint to load the BERT model:
bert_model = load_trained_model_from_checkpoint(
    config_path,
    model_path,
    training=False,
    use_adapter=True,
    trainable=['Encoder-{}-MultiHeadSelfAttention-Adapter'.format(i + 1) for i in range(layer_num)] +
              ['Encoder-{}-FeedForward-Adapter'.format(i + 1) for i in range(layer_num)] +
              ['Encoder-{}-MultiHeadSelfAttention-Norm'.format(i + 1) for i in range(layer_num)] +
              ['Encoder-{}-FeedForward-Norm'.format(i + 1) for i in range(layer_num)],
)
Recommended Answer
I believe you can check the following article. The author shows how to load a pre-trained BERT model, embed it in a Keras layer, and use it in a customized deep neural network. First, install bert-for-tf2, a TensorFlow 2.0 Keras implementation of google-research/bert:
pip install bert-for-tf2
Then run:
import bert
import os

def createBertLayer():
    global bert_layer
    bertDir = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")
    # Build a Keras layer from the pre-trained BERT config
    bert_params = bert.params_from_pretrained_ckpt(bertDir)
    bert_layer = bert.BertModelLayer.from_params(bert_params, name="bert")
    # Freeze all weights except the adapter layers
    bert_layer.apply_adapter_freeze()

def loadBertCheckpoint():
    modelsFolder = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")
    checkpointName = os.path.join(modelsFolder, "bert_model.ckpt")
    # Load the pre-trained checkpoint weights into the layer
    bert.load_stock_weights(bert_layer, checkpointName)
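Once a paragraph encoder is loaded, the hierarchical network described in the question can be assembled by wrapping the encoder in TimeDistributed (so it runs once per paragraph) and feeding the resulting paragraph embeddings into a BiLSTM. Here is a minimal sketch; the dimensions are assumptions, and the small Sequential encoder is only a stand-in you would replace with the actual bert_layer:

```python
import tensorflow as tf
from tensorflow.keras import layers, models

# Assumed dimensions -- adjust to your data and BERT config
max_paragraph_len = 16        # paragraphs per document
max_tokens_per_para = 128     # tokens per paragraph
paragraph_embedding_size = 768
num_classes = 5

# Placeholder paragraph encoder: maps token ids to a single vector.
# Substitute the real BERT layer here.
paragraph_encoder = models.Sequential([
    layers.Input(shape=(max_tokens_per_para,), dtype="int32"),
    layers.Embedding(30522, 64),
    layers.GlobalAveragePooling1D(),
    layers.Dense(paragraph_embedding_size),
])

# Document input: a sequence of paragraphs, each a sequence of token ids
doc_input = layers.Input(shape=(max_paragraph_len, max_tokens_per_para),
                         dtype="int32")

# Encode each paragraph independently -> (batch, max_paragraph_len, 768)
para_embeddings = layers.TimeDistributed(paragraph_encoder)(doc_input)

# BiLSTM over the sequence of paragraph embeddings
x = layers.Bidirectional(layers.LSTM(128))(para_embeddings)

# Document-level classification head
output = layers.Dense(num_classes, activation="softmax")(x)

model = models.Model(doc_input, output)
model.summary()
```

Note that a real BERT layer typically also needs segment ids and attention masks per paragraph; those extra inputs would be passed alongside the token ids, but the TimeDistributed-then-BiLSTM structure stays the same.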