如何在最新版本的 Tensorflow 中使用 MultiVariateNormal 分布 [英] How to use a MultiVariateNormal distribution in the latest version of Tensorflow
问题描述
我需要使用 tf.contrib.distributions.MultivariateNormal
中的 MultiVariateNormal
分发但是在最新版本的 Tensorflow 中,上述分发不可用,这导致错误
有人可以指出哪个可用分布会采用均值和西格玛并给出多元正态分布.
I need to use the MultiVariateNormal
distribution from the tf.contrib.distributions.MultivariateNormal
However in the latest version of Tensorflow the above distribution is not available which leads to an error
Can someone please point out which of the available distribution would take in mean and sigma and give the MultivariateNormal distribution.
tf.contrib.distributions.MultivariateNormalFullCovariance
defines the Multivariate Normal distribution that is parameterized by the
mean vector (mu)and the
covariance matrix`.
An Example,
# Let mean vector and co-variance be:
mu = [1., 2]
cov = [[ 1, 3/5],[ 3/5, 2]]
#Multivariate Normal distribution
gaussian = tf.contrib.distributions.MultivariateNormalFullCovariance(
loc=mu,
covariance_matrix=cov)
# Generate a mesh grid to plot the distributions
X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
idx = tf.concat([tf.reshape(X, [-1, 1]), tf.reshape(Y,[-1,1])], axis =1)
prob = tf.reshape(gaussian.prob(idx), tf.shape(X))
with tf.Session() as sess:
p = sess.run(prob)
m, c = sess.run([gaussian.mean(), gaussian.covariance()])
# m is [1., 2.]
# c is [[1, 0.6], [0.6, 2]]
这篇关于如何在最新版本的 Tensorflow 中使用 MultiVariateNormal 分布的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!