Restricting output classes in multi-class classification in TensorFlow
Problem description
I am building a bidirectional LSTM to do multi-class sentence classification. I have 13 classes in total. I multiply the output of my LSTM network by a matrix of shape [2*num_hidden_unit, num_classes] and then apply softmax to get the probability that the sentence falls into each of the 13 classes.
So if we consider output[-1] as the network output:
W_output = tf.Variable(tf.truncated_normal([2*num_hidden_unit, num_classes]))
result = tf.matmul(output[-1], W_output) + bias
This gives me a [1, 13] matrix (assuming I am not working with batches for the moment).
Now, I also know for sure that a given sentence does not fall into certain classes, and I want to restrict the classes considered for that sentence. For instance, if I know that a given sentence can only fall into 6 of the classes, the output should really be a matrix of shape [1, 6].
One option I was thinking of is to put a mask over the result matrix, multiplying the entries for the classes I want to keep by 1 and the entries for the classes I want to discard by 0. But in this way I would just lose some of the information instead of redirecting it.
Does anyone have a clue what to do in this case?
Recommended answer
I think your best bet is, as you seem to have described, to use a weighted cross-entropy loss function in which the weight for each "impossible" class is 0 and the weight for every other class is 1. TensorFlow has a weighted cross-entropy loss function.
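To make the 0/1 class mask concrete, here is a minimal sketch in plain Python (not TensorFlow, so the arithmetic stays visible). Note that it swaps in a related technique rather than the weighted loss itself: the mask is applied to the logits before the softmax, which is my assumption and not something stated above. Impossible classes get a logit of -inf, so their probability mass is redistributed over the allowed classes instead of being thrown away. The helper names masked_softmax and cross_entropy, the logit values, and the mask are all hypothetical.

```python
import math

def masked_softmax(logits, mask):
    """Softmax over the allowed classes only; masked-out classes get probability 0.

    Assumes at least one entry of `mask` is truthy.
    """
    # Replace logits of impossible classes with -inf so that exp() yields 0.
    masked = [z if m else -math.inf for z, m in zip(logits, mask)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in masked]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, true_class):
    """Negative log-likelihood of the true class."""
    return -math.log(probs[true_class])

# Hypothetical example: 13 classes, only 6 of which are possible for this sentence.
logits = [0.5, -1.2, 2.0, 0.1, 0.0, 1.5, -0.3, 0.7, 0.2, -2.0, 1.1, 0.4, -0.6]
mask = [1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0]

probs = masked_softmax(logits, mask)
loss = cross_entropy(probs, true_class=2)
```

The probabilities of the 6 allowed classes sum to 1 and the other 7 are exactly 0. In TensorFlow the same effect is commonly obtained by adding a large negative number to the masked logits before calling the softmax, but treat that as a pointer rather than a tested recipe.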
Another interesting but probably less effective method is to feed whatever information you have about which classes the sentence can or cannot fall into to the network itself at some point (probably towards the end).
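One possible reading of "towards the end" is to append the 0/1 class-possibility vector to the LSTM output just before the final projection. The sketch below illustrates this in plain Python with made-up sizes and random values; the matvec helper, the value of num_hidden_unit, and the mask are assumptions for illustration only.

```python
import random

num_hidden_unit = 4   # illustrative size, not taken from the question
num_classes = 13

def matvec(x, W, b):
    """Compute x @ W + b for a vector x, a matrix W (one row per input) and a bias b."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]

random.seed(0)
# Stand-in for output[-1]: the concatenated forward/backward LSTM state.
lstm_out = [random.gauss(0, 1) for _ in range(2 * num_hidden_unit)]
# 1 = class is possible for this sentence, 0 = class is known to be impossible.
mask = [1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0]

# Append the mask to the features; the output layer now has
# 2*num_hidden_unit + num_classes inputs instead of 2*num_hidden_unit.
features = lstm_out + [float(m) for m in mask]
W_output = [[random.gauss(0, 0.1) for _ in range(num_classes)]
            for _ in range(len(features))]
bias = [0.0] * num_classes

result = matvec(features, W_output, bias)  # one logit per class
```

With this design the network can only learn to use the extra inputs; nothing forces the impossible classes to receive zero probability, which may be why this route is described as probably less effective.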