Polar Bear's Dev Diary

Deep Learning for Everyone, Day 2 - Lab (Linear Regression)





posted by purplebeen on Fri Dec 21 2018 21:56:53 GMT+0900 (KST) in AI


Implementation using the basic approach

Original data

x y
1 1
2 2
3 3

Build the graph using TensorFlow operations

import tensorflow as tf
# X and Y data
x_train = [1, 2, 3]
y_train = [1, 2, 3]

W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')

# Our hypothesis: H(x) = Wx + b
hypothesis = x_train * W + b

Cost / loss function

# cost / loss function
cost = tf.reduce_mean(tf.square(hypothesis - y_train))
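The line above is the mean squared error over the m training examples, which is what `tf.reduce_mean(tf.square(...))` computes:

```latex
\mathrm{cost}(W, b) = \frac{1}{m} \sum_{i=1}^{m} \left( H\!\left(x^{(i)}\right) - y^{(i)} \right)^2,
\qquad H(x) = Wx + b
```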

Gradient descent

# Minimize the cost with gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)
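Each run of `train` applies one gradient-descent step to every trainable variable, with learning rate α = 0.01:

```latex
W \leftarrow W - \alpha \,\frac{\partial}{\partial W}\,\mathrm{cost}(W, b),
\qquad
b \leftarrow b - \alpha \,\frac{\partial}{\partial b}\,\mathrm{cost}(W, b)
```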

Run / update graph and get results

# Launch the graph in a session
sess = tf.Session()

# Initialize global variables in the graph (required before using W and b)
sess.run(tf.global_variables_initializer())

# Fit the line
for step in range(2001):
    sess.run(train)
    if step % 20 == 0:
        print(step, sess.run(cost), sess.run(W), sess.run(b))
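For reference, the same training loop can be sketched in plain NumPy. This is a hand-rolled gradient-descent sketch, not the TensorFlow API; the hard-coded gradients assume the mean-squared-error cost above, and it starts from zero instead of `tf.random_normal`:

```python
import numpy as np

x_train = np.array([1.0, 2.0, 3.0])
y_train = np.array([1.0, 2.0, 3.0])

W, b = 0.0, 0.0          # start from zero instead of a random init
learning_rate = 0.01

for step in range(2001):
    error = W * x_train + b - y_train      # H(x) - y
    cost = np.mean(error ** 2)             # mean squared error
    # Gradients of the MSE cost with respect to W and b
    W -= learning_rate * 2 * np.mean(error * x_train)
    b -= learning_rate * 2 * np.mean(error)
    if step % 20 == 0:
        print(step, cost, W, b)
```

Just like the TensorFlow run, W converges toward 1 and b toward 0, since the data lie exactly on y = x.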

Implementation using placeholders

Original data

x y
1 2.1
2 3.1
3 4.1
4 5.1
5 6.1

import tensorflow as tf
W = tf.Variable(tf.random_normal([1]), name='weight')
b = tf.Variable(tf.random_normal([1]), name='bias')
X = tf.placeholder(tf.float32, shape=[None])
Y = tf.placeholder(tf.float32, shape=[None])

# Our hypothesis WX + b
hypothesis = X * W + b

# cost / Loss function
cost = tf.reduce_mean(tf.square(hypothesis - Y))

# Minimize
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train = optimizer.minimize(cost)

# Launch the graph in a session
sess = tf.Session()
# Initializes global variables in the graph.
sess.run(tf.global_variables_initializer())

# Fit the line with new training data
for step in range(2001):
    cost_val, W_val, b_val, _ = sess.run([cost, W, b, train], 
                                        feed_dict = {X : [1,2,3,4,5],
                                                    Y : [2.1, 3.1, 4.1, 5.1, 6.1]})
    if step % 20 == 0:
        print(step, cost_val, W_val, b_val)
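The payoff of placeholders is that the trained graph can be fed inputs it never saw during training, e.g. `sess.run(hypothesis, feed_dict={X: [6]})` in TF 1.x. The idea in plain NumPy (again a hand-rolled sketch of the same gradient-descent loop, not the TensorFlow API):

```python
import numpy as np

x_data = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_data = np.array([2.1, 3.1, 4.1, 5.1, 6.1])   # y = x + 1.1

W, b = 0.0, 0.0
learning_rate = 0.01

for step in range(2001):
    error = W * x_data + b - y_data
    W -= learning_rate * 2 * np.mean(error * x_data)
    b -= learning_rate * 2 * np.mean(error)

# "Feed" a new input, as a placeholder would allow
prediction = W * 6.0 + b
print(prediction)   # close to 7.1, since the fit approaches y = x + 1.1
```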