简单线性回归上的Tensorflow [英] Tensorflow on simple linear regression
问题描述
我是机器学习和张量流的初学者.在尝试张量流的第一步中,我尝试了一个简单的多元线性回归.但是,该模型似乎停留在局部最小值.这是我的代码.
I am a beginner in machine learning and tensorflow. In the first step trying the tensorflow, I tried a simple multivariate linear regression. However, it seems the model stuck at a local minimum. Here is my code.
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=1)
return tf.Variable(initial)
# dataset
xx = np.random.randint(0,1000,[1000,3])/1000.
yy = xx[:,0] * 2 + xx[:,1] * 1.4 + xx[:,2] * 3
# model
x = tf.placeholder(tf.float32, shape=[None, 3])
y_ = tf.placeholder(tf.float32, shape=[None])
W1 = weight_variable([3, 1])
y = tf.matmul(x, W1)
# training and cost function
cost_function = tf.reduce_mean(tf.square(y - y_))
train_function = tf.train.AdamOptimizer(1e-2).minimize(cost_function)
# create a session
sess = tf.Session()
# train
sess.run(tf.initialize_all_variables())
for i in range(10000):
sess.run(train_function, feed_dict={x:xx, y_:yy})
if i % 1000 == 0:
print(sess.run(cost_function, feed_dict={x:xx, y_:yy}))
输出为:
14.8449
2.20154
2.18375
2.18366
2.18366
2.18366
2.18366
2.18366
2.18366
输出值(yy)的范围是0到6,因此,在知道没有噪声添加到数据集的情况下,均方误差2.18相当大. 我也尝试了GradientDescentOptimizer,其学习速率为0.1和1e-2,但是并没有太大改善结果.
The output value (yy) is ranging from 0 to 6, so having mean square error 2.18 is considerably large, knowing that there is no noise added to the dataset. I also tried GradientDescentOptimizer with learning rate 0.1 and 1e-2, but it does not improve the results much.
我的实现有什么问题吗?
Is there anything wrong with my implementation?
推荐答案
这是因为y
与y_
的形状不同. y
的形状为(1000,1),而y_
的形状为(1000).因此,当您减去它们时,就无意间创建了二维矩阵.
This is because y
is not the same shape as y_
. y
is of shape (1000, 1) and y_
is of shape (1000). So when you subtract them, you're inadvertently creating a 2-D matrix.
要解决此问题,请将您的费用函数更改为:
To fix it change your cost function to:
cost_function = tf.reduce_mean(tf.square(tf.squeeze(y) - y_))
这篇关于简单线性回归上的Tensorflow的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!