文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

TensorFlow实现简单线性回归

2024-04-02 19:55

关注

本文实例为大家分享了TensorFlow实现简单线性回归的具体代码,供大家参考,具体内容如下

简单的一元线性回归

一元线性回归公式:

其中x是特征:[x1,x2,x3,…,xn,]T
w是权重,b是偏置值

代码实现

导入必须的包

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

# 屏蔽warning以下的日志信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

产生模拟数据

def generate_data():
    x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
    y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
    return x, y

x是100行1列的数据,tf.matmul是矩阵相乘,所以权值设置成二维的。
设置的w是1.3, b是1

实现回归

def myregression():
    """
    自实现线性回归
    :return:
    """
    x, y = generate_data()
    #     建立模型  y = x * w + b
    # w 1x1的二维数据
    w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
    b = tf.Variable(0.0, name='bias_b')

    y_predict = tf.matmul(x, a) + b

    # 建立损失函数
    loss = tf.reduce_mean(tf.square(y_predict - y))
    
    # 训练
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss=loss)

    # 初始化全局变量
    init_op = tf.global_variables_initializer()

  
    with tf.Session() as sess:
        sess.run(init_op)
        print('初始的权重:%f偏置值:%f' % (a.eval(), b.eval()))
    
        # 训练优化
        for i in range(1, 100):
            sess.run(train_op)
            print('第%d次优化的权重:%f偏置值:%f' % (i, a.eval(), b.eval()))
        # 显示回归效果
        show_img(x.eval(), y.eval(), y_predict.eval())

使用matplotlib查看回归效果

def show_img(x, y, y_pre):
    plt.scatter(x, y)
    plt.plot(x, y_pre)
    plt.show()

完整代码

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def generate_data():
    x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
    y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
    return x, y


def myregression():
    """
    自实现线性回归
    :return:
    """
    x, y = generate_data()
    # 建立模型  y = x * w + b
    w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
    b = tf.Variable(0.0, name='bias_b')

    y_predict = tf.matmul(x, w) + b

    # 建立损失函数
    loss = tf.reduce_mean(tf.square(y_predict - y))
    # 训练
    train_op = tf.train.GradientDescentOptimizer(0.0001).minimize(loss=loss)

    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        print('初始的权重:%f偏置值:%f' % (w.eval(), b.eval()))
        # 训练优化
        for i in range(1, 35000):
            sess.run(train_op)
            print('第%d次优化的权重:%f偏置值:%f' % (i, w.eval(), b.eval()))
        show_img(x.eval(), y.eval(), y_predict.eval())


def show_img(x, y, y_pre):
    plt.scatter(x, y)
    plt.plot(x, y_pre)
    plt.show()


if __name__ == '__main__':
    myregression()

看看训练的结果(因为数据是随机产生的,每次的训练结果都会不同,可适当调节梯度下降的学习率和训练步数)

35000次的训练结果

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持编程网。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     807人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     351人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     314人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     433人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     221人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯