首页 分享 Python 梯度下降法

Python 梯度下降法

来源:花匠小妙招 时间:2024-12-04 04:34

题目描述:

自定义一个可微并且存在最小值的一元函数,用梯度下降法求其最小值。并绘制出学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数的关系曲线,根据该曲线给出简单的分析。

代码:

# -*- coding: utf-8 -*- ''' 遇到问题没人解答?小编创建了一个Python学习交流QQ群:778463939 寻找有志同道合的小伙伴,互帮互助,群里还有不错的视频学习教程和PDF电子书! ''' import numpy as np import matplotlib.pyplot as plt plot_x=np.linspace(-1,6,150) #在-1到6之间等距的生成150个数 plot_y=(plot_x-2.5)**2+3 # 同时根据plot_x来生成plot_y(y=(x-2.5)²+3) plt.plot(plot_x,plot_y) plt.show() ###定义一个求二次函数导数的函数dJ def dJ(x): return 2*(x-2.5) ###定义一个求函数值的函数J def J(x): try: return (x-2.5)**2+3 except: return float('inf') x=0.0 #随机选取一个起始点 eta=0.1 #eta是学习率,用来控制步长的大小 epsilon=1e-8 #用来判断是否到达二次函数的最小值点的条件 history_x=[x] #用来记录使用梯度下降法走过的点的X坐标 count=0 min=0 while True: gradient=dJ(x) #梯度(导数) last_x=x x=x-eta*gradient history_x.append(x) count=count+1 if (abs(J(last_x)-J(x)) <epsilon): #用来判断是否逼近最低点 min=x break plt.plot(plot_x,plot_y) plt.plot(np.array(history_x),J(np.array(history_x)),color='r',marker='*') #绘制x的轨迹 plt.show() print'min_x =',(min) print'min_y =',(J(min)) #打印到达最低点时y的值 print'count =',(count) sum_eta=[] result=[] for i in range(1,10,1): x=0.0 #随机选取一个起始点 eta=i*0.1 sum_eta.append(eta) epsilon=1e-8 #用来判断是否到达二次函数的最小值点的条件 num=0 min=0 while True: gradient=dJ(x) #梯度(导数) last_x=x x=x-eta*gradient num=num+1 if (abs(J(last_x)-J(x)) <epsilon): #用来判断是否逼近最低点 min=x break result.append(num)#记录学习率从0.1到0.9(步长0.1)时,达到最小值时所迭代的次数 plt.scatter(sum_eta,result,c='r') plt.plot(sum_eta,result,c='r') plt.title("relation") plt.xlabel("eta") plt.ylabel("count") plt.show print(result)

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475

运行结果:
在这里插入图片描述
在这里插入图片描述
结果分析:
函数y=(x-2.5)²+3从学习率和迭代次数的关系图上我们可以知道当学习率较低时迭代次数较多,随着学习率的增大,迭代次数开始逐渐减少,当学习率为0.5时迭代次数最少,之后随着学习率的增加,迭代次数开始增加,当学习率为0.9时迭代次数和0.1时相等。关于0.5成对称分布。

相关知识

python 线性回归
神经网络与深度学习
如何翻译和解释机器学习术语?请看 Google 官方答案 下
【免费】TensorFlow0.12.1版本的mac操作系统下载资源
基于python的鸢尾花二分类
神经网络与深度学习(五)前馈神经网络(3)鸢尾花分类
干货来袭,谷歌最新机器学习术语表(下)
谷歌出品!机器学习常用术语总结
卷积神经网络实现鸢尾花数据分类python代码实现
学好Python=基础学科能力+业务知识+ IT技术

网址: Python 梯度下降法 https://www.huajiangbk.com/newsview859408.html

所属分类:花卉
上一篇: 2024上海外滩旅游攻略之外滩历
下一篇: 纪念花木兰颁奖词.doc

推荐分享