如何在最新版本的 Tensorflow 中使用 MultiVariateNormal 分布 [英] How to use a MultiVariateNormal distribution in the latest version of Tensorflow

查看:62
本文介绍了如何在最新版本的 Tensorflow 中使用 MultiVariateNormal 分布的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我需要使用 tf.contrib.distributions.MultivariateNormal 中的 MultiVariateNormal 分发但是在最新版本的 Tensorflow 中,上述分发不可用,这导致错误

有人可以指出哪个可用分布会采用均值和西格玛并给出多元正态分布.

解决方案

I need to use the MultiVariateNormal distribution from the tf.contrib.distributions.MultivariateNormal However in the latest version of Tensorflow the above distribution is not available which leads to an error

Can someone please point out which of the available distribution would take in mean and sigma and give the MultivariateNormal distribution.

解决方案

tf.contrib.distributions.MultivariateNormalFullCovariancedefines the Multivariate Normal distribution that is parameterized by themean vector (mu)and thecovariance matrix`.

An Example,

# Let mean vector and co-variance be:
mu = [1., 2] 
cov = [[ 1,  3/5],[ 3/5,  2]]

#Multivariate Normal distribution
gaussian = tf.contrib.distributions.MultivariateNormalFullCovariance(
           loc=mu,
           covariance_matrix=cov)

# Generate a mesh grid to plot the distributions
X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
idx = tf.concat([tf.reshape(X, [-1, 1]), tf.reshape(Y,[-1,1])], axis =1)
prob = tf.reshape(gaussian.prob(idx), tf.shape(X))

with tf.Session() as sess:
   p = sess.run(prob)
   m, c = sess.run([gaussian.mean(), gaussian.covariance()])
   # m is [1., 2.]
   # c is [[1, 0.6], [0.6, 2]]

这篇关于如何在最新版本的 Tensorflow 中使用 MultiVariateNormal 分布的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆