Visualizing attention activation in Tensorflow


Question


Is there a way to visualize the attention weights on some input, like the figure in the link above (from Bahdanau et al., 2014), in TensorFlow's seq2seq models? I have found TensorFlow's GitHub issue regarding this, but I couldn't find out how to fetch the attention mask during the session.

Answer


I also want to visualize the attention weights of the TensorFlow seq2seq ops for my text summarization task. I think the temporary solution is to use session.run() to evaluate the attention mask tensor, as mentioned above. Interestingly, the original seq2seq.py ops are considered legacy and aren't easy to find on GitHub, so I just used the seq2seq.py file from the 0.12.0 wheel distribution and modified it. To draw the heatmap, I used the Matplotlib package, which is very convenient.


The final output of the attention visualization for the news-headline textsum task looks like this:


I modified the code as below: https://github.com/rockingdingo/deepnlp/tree/master/deepnlp/textsum#attention-visualization

seq2seq_attn.py

# Find the attention mask tensor in function attention_decoder() -> attention().
# Add the attention mask tensor to the 'return' statement of every function that calls attention_decoder(),
# all the way up to the model_with_buckets() function, which is the final function I use for bucket training.

def attention(query):
  """Put attention masks on hidden using hidden_features and query."""
  ds = []  # Results of attention reads will be stored here.

  # some code

  for a in xrange(num_heads):
    with variable_scope.variable_scope("Attention_%d" % a):
      # some code

      s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                              [2, 3])
      # This is the attention mask tensor we want to extract
      a = nn_ops.softmax(s)

      # some code

  # add 'a' to the return statement (note: with multiple heads, only the
  # mask from the last head is returned this way)
  return ds, a
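As a sanity check, the score-and-softmax step inside attention() can be reproduced in plain NumPy. This is a toy example with random values; the names hidden_features, y, and v mirror the snippet above, but the shapes are made up for illustration:

```python
import numpy as np

# Toy shapes: attn_length = 4 source positions, attn_size = 5.
# hidden_features plays the role of W1 * h_s (encoder features),
# y the role of W2 * d_t (decoder query), v the learned score vector.
attn_length, attn_size = 4, 5
rng = np.random.default_rng(0)
hidden_features = rng.normal(size=(1, attn_length, 1, attn_size))
y = rng.normal(size=(1, 1, 1, attn_size))
v = rng.normal(size=(attn_size,))

# s = reduce_sum(v * tanh(hidden_features + y), axes [2, 3])
s = np.sum(v * np.tanh(hidden_features + y), axis=(2, 3))  # shape [1, attn_length]

# a = softmax(s): the attention mask, one weight per source position
a = np.exp(s) / np.exp(s).sum(axis=1, keepdims=True)
```

Each row of `a` sums to 1, which is exactly why it can be read as a distribution over source positions and plotted as one row of the heatmap.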

seq2seq_model_attn.py

# modified model.step() function and return masks tensor
self.outputs, self.losses, self.attn_masks = seq2seq_attn.model_with_buckets(…)

# use session.run() to evaluate attn masks
attn_out = session.run(self.attn_masks[bucket_id], input_feed)
attn_matrix = ...
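The elided attn_matrix step might look like the following sketch. build_attention_matrix is a hypothetical helper (not from the repository), and it assumes session.run() returns one mask per decoder step, each of shape [batch_size, attention_length]; the exact shapes depend on your bucket and batch setup:

```python
import numpy as np

def build_attention_matrix(attn_out, ty_cut, tx_cut):
    """Stack per-decoder-step attention masks into a 2D [ty, tx] array.

    attn_out is assumed to be a list with one entry per decoder step,
    each an array of shape [batch_size, attention_length]; we keep
    batch item 0, matching attn_matrix[0:ty_cut, 0:tx_cut] below.
    """
    rows = [step_mask[0] for step_mask in attn_out]   # batch item 0 per step
    matrix = np.stack(rows, axis=0)                   # [num_steps, attn_length]
    return matrix[0:ty_cut, 0:tx_cut]

# Toy example: 5 decoder steps, batch of 2, 7 source positions
fake_attn_out = [np.random.rand(2, 7) for _ in range(5)]
attn_matrix = build_attention_matrix(fake_attn_out, ty_cut=3, tx_cut=4)
```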

predict_attn.py and eval.py

# Use the plot_attention function in eval.py to visualize the 2D ndarray during prediction.

eval.plot_attention(attn_matrix[0:ty_cut, 0:tx_cut], X_label = X_label, Y_label = Y_label)
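The repository's eval.plot_attention is not reproduced here, but a minimal Matplotlib heatmap in the same spirit might look like this. The function name, signature, and label handling are assumptions based on the call above:

```python
import matplotlib
matplotlib.use("Agg")  # render to file; no display needed
import matplotlib.pyplot as plt
import numpy as np

def plot_attention(attn_matrix, X_label=None, Y_label=None,
                   out_file="attention.png"):
    """Render a 2D attention matrix as a heatmap.

    Rows = decoder (output) steps, columns = encoder (input) positions.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    im = ax.imshow(attn_matrix, cmap="viridis", aspect="auto")
    fig.colorbar(im, ax=ax)
    if X_label is not None:
        ax.set_xticks(range(len(X_label)))
        ax.set_xticklabels(X_label, rotation=45, ha="right")
    if Y_label is not None:
        ax.set_yticks(range(len(Y_label)))
        ax.set_yticklabels(Y_label)
    fig.savefig(out_file, bbox_inches="tight")
    plt.close(fig)

# Usage with random data standing in for a real attention matrix
plot_attention(np.random.rand(3, 4),
               X_label=["the", "cat", "sat", "."],
               Y_label=["le", "chat", "."])
```

Each cell's brightness shows how strongly one output step attended to one input token, which reproduces the alignment-style figure from Bahdanau et al. (2014).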


And probably in the future TensorFlow will have a better way to extract and visualize the attention weight map. Any thoughts?

