文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Python3 完整实现DNN

2023-01-31 06:58

关注
    完整实现DNN,包括前向传播和反向传播。实现一个2次函数的拟合。
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun May  6 16:11:40 2018

@author: wsw
"""

# construct simple DNN
import numpy as np
import matplotlib.pyplot as plt
import sys

def generate_data():
    x = np.linspace(-2,2,100)[np.newaxis,:]
    noise = np.random.normal(0.0,0.5,size=(1,100))
    y = x**2+noise
    return x,y


class DNN():
    
    def __init__(self,input_nodes=1,hidden1_nodes=4,hidden2_nodes=4,output_nodes=1):
        self.input_nodes = input_nodes
        self.hidden1_nodes = hidden1_nodes
        self.hidden2_nodes = hidden2_nodes
        self.output_nodes = output_nodes
        self.build_DNN()
        
    def build_DNN(self):
        np.random.seed(1)
        # Layer1 parameter
        self.w1 = np.random.normal(0.0,0.1,size=(self.hidden1_nodes,self.input_nodes))
        self.b1 = np.zeros(shape=(self.hidden1_nodes,1))
        # Layer2 parameter
        self.w2 = np.random.normal(0.0,0.2,size=(self.hidden2_nodes,self.hidden1_nodes))
        self.b2 = np.ones(shape=(self.hidden2_nodes,1))
        # Layer3 parameter
        self.w3 = np.random.normal(0.0,0.5,size=(self.output_nodes,self.hidden2_nodes))
        self.b3 = np.zeros(shape=(self.output_nodes,1))
        
    def forwardPropagation(self,inputs):
        self.z1 = np.matmul(self.w1,inputs) + self.b1
        self.a1 = 1/(1+np.exp(-self.z1))
        self.z2 = np.matmul(self.w2,self.a1) + self.b2
        self.a2 = 1/(1+np.exp(-self.z2))
        self.z3 = np.matmul(self.w3,self.a2) + self.b3
        self.a3 = self.z3
    
    def backwardPropagation(self,da,a,a_1,w,b,last=False):
        '''
        da:current layer activation output partial devirate result
        a:current layer activation output
        a_1:previous layer of current layer activation output
        w:current parameter
        b:current bias
        '''
        # dz = da/dz
        if last:
            dz = da
        else:
            dz = a*(1-a)*da
        # dw = dz/dw
        nums = da.shape[1]
        dw = np.matmul(dz,a_1.T)/nums
        db = np.mean(dz,axis=1,keepdims=True)
        # da_1 = dz/da_1
        da_1 = np.matmul(w.T,dz)
        
        w -= 0.5*dw
        b -= 0.5*db
        return da_1
    
    def train(self,x,y,max_iter=50000):
        for i in range(max_iter):
            self.forwardPropagation(x)
            #print(self.a3)
            loss = 0.5*np.mean((self.a3-y)**2)
            da = self.a3-y
            da_2 = self.backwardPropagation(da,self.a3,self.a2,self.w3,self.b3,True)
            da_1 = self.backwardPropagation(da_2,self.a2,self.a1,self.w2,self.b2)
            da_0 = self.backwardPropagation(da_1,self.a1,x,self.w1,self.b1)
            self.view_bar(i+1,max_iter,loss)
        return self.a3
    
    def view_bar(self,step,total,loss):
        rate = step/total
        rate_num = int(rate*40)
        r = '\rstep-%d loss value-%.4f[%s%s]\t%d%% %d/%d'%(step,loss,'>'*rate_num,'-'*(40-rate_num),
                                      int(rate*100),step,total)
        sys.stdout.write(r)
        sys.stdout.flush()
        
if __name__ == '__main__':
    x,y = generate_data()
    plt.scatter(x,y,c='r')
    plt.ion()
    print('plot') 
    dnn = DNN()
    predict = dnn.train(x,y)
    plt.plot(x.flatten(),predict.flatten(),'-')
    plt.show()
    

      
        
运行结果:
阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯