Visualizing attention activation in Tensorflow

Problem description


Is there a way to visualize the attention weights on some input, like the figure in the link above (from Bahdanau et al., 2014), in TensorFlow's seq2seq models? I have found TensorFlow's GitHub issue regarding this, but I couldn't find out how to fetch the attention mask during the session.

Solution

I also want to visualize the attention weights of the TensorFlow seq2seq ops for my text summarization task. And I think the temporary solution is to use session.run() to evaluate the attention mask tensor, as mentioned above. Interestingly, the original seq2seq.py ops are considered a legacy version and can't easily be found on GitHub, so I just used the seq2seq.py file from the 0.12.0 wheel distribution and modified it. To draw the heatmap, I used the Matplotlib package, which is very convenient.

The final output of the attention visualization for news headline textsum looks like the heatmap shown in the repository linked below.

I modified the code as below: https://github.com/rockingdingo/deepnlp/tree/master/deepnlp/textsum#attention-visualization

seq2seq_attn.py

# Find the attention mask tensor in the function attention_decoder() -> attention()
# Add the attention mask tensor to the 'return' statement of every function that calls attention_decoder(),
# all the way up to the model_with_buckets() function, which is the final function I use for bucket training.

def attention(query):
  """Put attention masks on hidden using hidden_features and query."""
  ds = []  # Results of attention reads will be stored here.

  # some code

  for a in xrange(num_heads):
    with variable_scope.variable_scope("Attention_%d" % a):
      # some code

      s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y),
                              [2, 3])
      # This is the attention mask tensor we want to extract
      a = nn_ops.softmax(s)

      # some code

  # add 'a' to the returned values
  return ds, a
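
The mask has to be forwarded by every function on the call chain. Below is a minimal, hypothetical sketch of that pattern with plain Python stand-ins, not the real TF code; in the legacy seq2seq.py the chain typically also passes through embedding_attention_decoder() and embedding_attention_seq2seq().

def attention_decoder():
    # stand-in for the real decoder; attn_mask plays the role of 'a' above
    outputs, attn_mask = ["decoder_output"], [[0.7, 0.2, 0.1]]
    return outputs, attn_mask

def embedding_attention_seq2seq():
    # each intermediate caller simply forwards the extra return value
    outputs, attn_mask = attention_decoder()
    return outputs, attn_mask

def model_with_buckets():
    # the outermost function finally exposes the mask to the model code
    outputs, attn_mask = embedding_attention_seq2seq()
    losses = [0.0]
    return outputs, losses, attn_mask

outputs, losses, attn_masks = model_with_buckets()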

seq2seq_model_attn.py

# Modify the model.step() function so that it also returns the mask tensors
self.outputs, self.losses, self.attn_masks = seq2seq_attn.model_with_buckets(…)

# use session.run() to evaluate attn masks
attn_out = session.run(self.attn_masks[bucket_id], input_feed)
attn_matrix = ...
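
How attn_matrix is assembled depends on the shape in which the masks come back. A minimal sketch, assuming attn_out is a list with one [batch_size, encoder_length] mask per decoder step and we want the matrix for the first example in the batch (an illustration, not the repository's exact code):

import numpy as np

# stack the per-step masks into a (decoder_steps, encoder_steps) matrix
attn_matrix = np.stack([step_mask[0] for step_mask in attn_out], axis=0)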

predict_attn.py and eval.py

# Use the plot_attention function in eval.py to visualize the 2D ndarray during prediction.

eval.plot_attention(attn_matrix[0:ty_cut, 0:tx_cut], X_label=X_label, Y_label=Y_label)
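
The actual plot_attention lives in the repository's eval.py; as a rough sketch of what such a Matplotlib heatmap helper can look like (the body below is an illustration under that assumption, not the repository's implementation):

import matplotlib.pyplot as plt
import numpy as np

def plot_attention(attn_matrix, X_label=None, Y_label=None):
    """Plot a 2D attention matrix (output_len x input_len) as a heatmap."""
    fig, ax = plt.subplots()
    ax.imshow(attn_matrix, cmap=plt.cm.Blues, aspect='auto')
    if X_label is not None:  # input (encoder) tokens along the x-axis
        ax.set_xticks(np.arange(len(X_label)))
        ax.set_xticklabels(X_label, rotation=45, ha='right')
    if Y_label is not None:  # output (decoder) tokens along the y-axis
        ax.set_yticks(np.arange(len(Y_label)))
        ax.set_yticklabels(Y_label)
    fig.tight_layout()
    plt.show()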

And probably in the future TensorFlow will have a better way to extract and visualize the attention weight map. Any thoughts?
