How to implement a network using BERT as a paragraph encoder for long text classification, in Keras?


Question

I am doing a long text classification task where documents have more than 10,000 words. I plan to use BERT as a paragraph encoder, then feed the paragraph embeddings to a BiLSTM step by step. The network is as below:

Input: (batch_size, max_paragraph_len, max_tokens_per_para, embedding_size)

BERT layer: (max_paragraph_len, paragraph_embedding_size)

LSTM layer: ???

Output layer: (batch_size, classification_size)
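
For the LSTM layer marked ??? above, here is a minimal sketch of the BiLSTM head, assuming the paragraph embeddings (e.g. each paragraph's [CLS] vector from BERT) have already been computed; max_paragraph_len and paragraph_embedding_size follow the shapes above, while num_classes and the LSTM width are hypothetical placeholders:

from tensorflow.keras import layers, Model

max_paragraph_len = 64          # paragraphs per document (placeholder)
paragraph_embedding_size = 768  # BERT-Base hidden size
num_classes = 5                 # hypothetical number of classes

# One pre-computed BERT embedding per paragraph.
para_embeddings = layers.Input(shape=(max_paragraph_len, paragraph_embedding_size))

# BiLSTM over the paragraph sequence; keeps only the final state,
# giving a single document vector of size 2 * 128.
doc_vector = layers.Bidirectional(layers.LSTM(128))(para_embeddings)
outputs = layers.Dense(num_classes, activation="softmax")(doc_vector)

model = Model(para_embeddings, outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])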

How can I implement this with Keras? I am using keras-bert's load_trained_model_from_checkpoint to load the BERT model:

from keras_bert import load_trained_model_from_checkpoint

# layer_num is the number of transformer encoder layers (12 for BERT-Base).
bert_model = load_trained_model_from_checkpoint(
    config_path,
    model_path,
    training=False,
    use_adapter=True,
    trainable=['Encoder-{}-MultiHeadSelfAttention-Adapter'.format(i + 1) for i in range(layer_num)] +
        ['Encoder-{}-FeedForward-Adapter'.format(i + 1) for i in range(layer_num)] +
        ['Encoder-{}-MultiHeadSelfAttention-Norm'.format(i + 1) for i in range(layer_num)] +
        ['Encoder-{}-FeedForward-Norm'.format(i + 1) for i in range(layer_num)],
)
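
One way to connect this loaded bert_model to the BiLSTM is to fold the paragraph axis into the batch axis before BERT and restore it afterwards. This is an untested sketch: it assumes TF_KERAS=1 was set before importing keras_bert (so it builds on tf.keras), a BERT-Base hidden size of 768, and placeholder values for max_paragraph_len, max_tokens_per_para, and num_classes:

import tensorflow as tf
from tensorflow.keras import layers, Model

max_paragraph_len = 32     # paragraphs per document (placeholder)
max_tokens_per_para = 128  # tokens per paragraph (placeholder)
hidden_size = 768          # BERT-Base hidden size
num_classes = 5            # hypothetical number of classes

token_ids = layers.Input(shape=(max_paragraph_len, max_tokens_per_para), dtype="int32")
segment_ids = layers.Input(shape=(max_paragraph_len, max_tokens_per_para), dtype="int32")

# Fold the paragraph axis into the batch axis so BERT encodes one
# paragraph per row: (batch, P, T) -> (batch * P, T).
fold = layers.Lambda(lambda t: tf.reshape(t, (-1, max_tokens_per_para)))
seq_output = bert_model([fold(token_ids), fold(segment_ids)])  # (batch * P, T, 768)

# Take the [CLS] vector as the paragraph embedding and restore the
# paragraph axis: (batch * P, 768) -> (batch, P, 768).
unfold = layers.Lambda(
    lambda t: tf.reshape(t[:, 0, :], (-1, max_paragraph_len, hidden_size)))
para_embeddings = unfold(seq_output)

# BiLSTM over the paragraph sequence, then classify the document.
doc_vector = layers.Bidirectional(layers.LSTM(128))(para_embeddings)
outputs = layers.Dense(num_classes, activation="softmax")(doc_vector)

model = Model([token_ids, segment_ids], outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])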

Answer

I believe you can check the following article. The author shows how to load a pre-trained BERT model, embed it in a Keras layer, and use it in a customized deep neural network. First, install the TensorFlow 2.0 Keras implementation of google-research/bert:

pip install bert-for-tf2

Then run:

import bert
import os

def createBertLayer():
    global bert_layer

    # modelBertDir should point to the folder containing the downloaded
    # pre-trained checkpoint.
    bertDir = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")

    # Read the model hyperparameters from the checkpoint directory.
    bert_params = bert.params_from_pretrained_ckpt(bertDir)

    bert_layer = bert.BertModelLayer.from_params(bert_params, name="bert")

    # Freeze everything except the adapter and layer-norm weights.
    bert_layer.apply_adapter_freeze()

def loadBertCheckpoint():
    modelsFolder = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")
    checkpointName = os.path.join(modelsFolder, "bert_model.ckpt")

    # Load the pre-trained weights into the (already built) layer.
    bert.load_stock_weights(bert_layer, checkpointName)
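
A hedged usage sketch follows (modelBertDir and max_seq_len are assumed values): load_stock_weights can only run after the layer has been built, so construct a model around bert_layer first, then call loadBertCheckpoint. The resulting per-paragraph encoder can then be plugged into the BiLSTM head shown earlier:

import tensorflow as tf
from tensorflow import keras

modelBertDir = "models"   # assumed path to the downloaded checkpoints
max_seq_len = 128         # assumed number of tokens per paragraph

createBertLayer()

# Building a model around the layer creates its weights.
input_ids = keras.layers.Input(shape=(max_seq_len,), dtype="int32")
sequence_output = bert_layer(input_ids)                      # (batch, seq, 768)
paragraph_vec = keras.layers.Lambda(lambda t: t[:, 0, :])(sequence_output)  # [CLS]
encoder = keras.Model(input_ids, paragraph_vec)

# Now the pre-trained weights can be loaded.
loadBertCheckpoint()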
