二分法、梯度下降法、牛顿法求解根号
二分法、梯度下降法、牛顿法求解根号
本文将介绍三种求解根号的算法:梯度下降法、牛顿迭代法和二分法。这些算法在数学和计算机科学中都有广泛的应用,通过具体的代码实现和示例,帮助读者更好地理解这些算法的原理和应用。
1. 梯度下降法(Gradient descent)
若要求$\sqrt{2}$,即要求解$x^2 - 2 = 0$的根,也就是函数
$$L = (x^2 - 2)^2$$
取极小值时 $x$ 的取值。这个也就对应机器学习中的损失函数。
要寻找损失函数的最低点就是找到曲线的最低点。在这里,我们使用了微积分里导数,通过求出函数导数的值,从而找到函数下降的方向或者是最低点(极值点)。
$$g(x) = \frac{dL}{dx} = 4x^3 - 8x = 4x(x^2 - 2)$$
给$x$一个初始值,然后不断通过下式来更新$x$就可以逐渐逼近最优的$x$,这里的$a$代表步长,也就是学习率。
$$x^{n+1} = x^n - a g(x^n)$$
import random
import matplotlib.pyplot as plt
class Solution():
def gradient_descent(self, n):
# 随机初始化
x = float(random.randint(1, 100))
# 学习率
lr = 0.00001
# 记录损失
loss = []
# 损失阈值
while (abs(x ** 2 - n) > 0.0000000001):
# x(n+1) = x(n) - lr * g(x(n))
x = x - lr * 4 * x * (x ** 2 - n)
# 记录损失
loss.append((x ** 2 - n)**2)
return loss, x
if __name__ == '__main__':
n = 100
loss, a = Solution().gradient_descent(n)
print(a)
# 画损失图
x = range(len(loss))
plt.plot(x, loss, color='b')
plt.xlim(0, 1000)
plt.show()
损失变化图
2. 牛顿迭代法(Newton’s method)
它是牛顿在17世纪提出的一种在实数域和复数域上近似求解方程的方法。
多数方程不存在求根公式,因此求精确根非常困难,甚至不可解,从而寻找方程的近似根就显得特别重要。方法使用函数的泰勒级数的前面几项来寻找方程的根。牛顿迭代法是求方程根的重要方法之一,其最大优点是在方程的单根附近具有平方收敛,而且该法还可以用来求方程的重根、复根,此时线性收敛,但是可通过一些方法变成超线性收敛。另外该方法广泛用于计算机编程中。
把$f(x)$在点$x_0$的某邻域内展开成泰勒级数。
$$f(x) = f(x_0) + f'(x_0)(x-x_0) + \frac{f''(x_0)(x-x_0)^2}{2!} + \cdots + \frac{f^{(n)}(x_0)(x-x_0)^n}{n!} + R_n(x)$$
取其线性部分(即泰勒展开的前两项),这里我们要求解$f(x) = 0$带入得
$$0 = f(x_0) + f'(x_0)(x-x_0)$$
得
$$x = x_0 - \frac{f(x_0)}{f'(x_0)}$$
这样,得到牛顿迭代法的一个迭代关系式:
$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$
由$f(x) = x^2 - a$,可以得到
$$x_{n+1} = x_n - \frac{x_n^2 - a}{2x_n} = \frac{x_n^2 + a}{2x_n} = \frac{x_n + a/x_n}{2}$$
这种方法可以很有效地求出根号$a$的近似值:首先随便猜一个近似值$x$,然后不断令$x$等于$x$和$a/x$的平均数,迭代个六七次后$x$的值就已经相当精确了。
例如,我想求根号$2$等于多少。假如我猜测的结果为$4$,虽然错的离谱,但你可以看到使用牛顿迭代法后这个值很快就趋近于根号$2$了:
$$(4 + 2/4) / 2 = 2.25$$
$$(2.25 + 2/2.25) / 2 = 1.56944…$$
$$(1.56944…+ 2/1.56944…) / 2 = 1.42189…$$
$$(1.42189…+ 2/1.42189…) / 2 = 1.41423…$$
$…$
这种算法的原理很简单,我们仅仅是不断用$(x, f(x))$的切线来逼近方程的根。
class sqrt(object):
def s(self, x):
a = x
while a * a > x:
a = (a + x / a) / 2
print(a)
if __name__ == '__main__':
x = 169
sqrt().s(x)
85.0
43.49411764705882
23.68985027605849
15.411853548944432
13.188719595702175
13.00135021013767
13.000000070110696
13.0
3. 二分法
一个数$a$的平方根小于等于$a$,使用二分法解决如下
class Solution3():
def mySqrt(self, x):
if x == 0:
return 0
if x == 1:
return 1
left = 1
right = x
while left <= right:
mid = left + (right - left) // 2
if mid * mid == x:
return mid
elif mid * mid > x:
right = mid - 0.0001
else:
left = mid + 0.0001
return right
if __name__ == '__main__':
print(Solution3().mySqrt(3))