神经网络常见激活函数:Swish函数详解
神经网络常见激活函数:Swish函数详解
Swish函数是深度学习领域中一种重要的激活函数,它结合了ReLU的稀疏性和Sigmoid的平滑性,具有自适应性和非单调性等特点。本文将详细介绍Swish函数的定义、导数、图像表示、优缺点,并提供在PyTorch和TensorFlow中的实现代码。
Swish函数
函数+导函数
Swish函数的定义如下:
$$
\begin{aligned}
\rm Swish(x) &= x \cdot \sigma(\beta x) \
&= \frac{x}{1 + e^{-\beta x}}
\end{aligned}
$$
其中,$\sigma(\cdot)$是sigmoid函数,$\beta$是一个可学习的参数或固定的超参数,控制着函数的形状。
Swish函数的导数为:
$$
\begin{aligned}
\frac{d}{dx} \rm Swish &= \left(x \cdot\sigma(\beta x) \right)' \
&=\sigma(\beta x) + x\cdot \left(\sigma(\beta x) \right)'\
\quad \
\because &\quad \sigma'(u)= \sigma(u)(1-\sigma(u)) \
\quad \
\therefore &=\sigma(\beta x) + \beta x \cdot \sigma(\beta x) \cdot (1 - \sigma(\beta x)) \
&= \frac{1 + e^{-\beta x} + \beta x e^{-\beta x}}{(1 + e^{-\beta x})^2}
\end{aligned}
$$
函数和导函数图像
以下是Swish函数及其导数的图像表示:
import numpy as np
from matplotlib import pyplot as plt
# 定义 Swish 函数
def swish(x, beta=1.0):
return x * (1 / (1 + np.exp(-beta * x)))
# 定义 Swish 的导数
def swish_derivative(x, beta=1.0):
sigmoid = 1 / (1 + np.exp(-beta * x))
return sigmoid * (1 + x * beta * (1 - sigmoid))
# 生成数据
x = np.linspace(-5, 5, 1000)
beta = 1.0 # 可以调整 beta 的值
y = swish(x, beta)
y1 = swish_derivative(x, beta)
# 绘制图形
plt.figure(figsize=(12, 8))
ax = plt.gca()
plt.plot(x, y, label=f'Swish (β={beta})')
plt.plot(x, y1, label='Derivative')
plt.title(f'Swish (β={beta}) and Derivative')
# 设置上边和右边无边框
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
# 设置 x 坐标刻度数字或名称的位置
ax.xaxis.set_ticks_position('bottom')
# 设置边框位置
ax.spines['bottom'].set_position(('data', 0))
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data', 0))
plt.legend(loc=2)
plt.savefig('./swish.jpg')
plt.show()
优缺点
Swish函数的优点
- 非单调性:Swish是一个非单调函数,这使得它在负输入时仍然可以保持小梯度,同时在正输入时能够有效激活。
- 自适应性:Swish是一个自适应激活函数,能够根据输入动态调整输出,这有助于改善优化和泛化。
- 性能优势:在某些深度学习模型中,Swish已经被证明比ReLU及其变体表现更好,尤其是在深层网络中。
- 平滑过渡:Swish函数在整个定义域上是平滑的,这有助于避免梯度消失问题,同时提供更稳定的梯度流动。
Swish的缺点
- 计算复杂度高:与ReLU相比,Swish的计算成本更高,因为它涉及到sigmoid函数的计算。
- 训练时间增加:由于计算复杂度的增加,Swish可能会导致训练时间变长。
- 适用性有限:虽然Swish在某些任务中表现出色,但并非所有任务都能从中受益。需要根据具体任务和数据集进行实验和调整。
- Swish函数是一种较新的激活函数,其非单调性和自适应性使其在某些深度学习模型中表现优异,尤其是在深层网络中。然而,其计算复杂度较高,可能会增加训练时间,因此在实际应用中需要根据具体任务进行权衡和选择。
PyTorch中的Swish函数
在PyTorch中实现Swish函数的代码如下:
import torch
import torch.nn.functional as F
# 定义 Swish 函数
def swish(x, beta=1):
return x * torch.sigmoid(beta * x)
x = torch.randn(2) # 生成一个随机张量作为输入
swish_x = swish(x, beta=1) # 应用 Swish 函数
print(f"x: \n{x}")
print(f"swish_x:\n{swish_x}")
"""
输出:
x:
tensor([ 1.3991, -0.1989])
swish_x:
tensor([ 1.1221, -0.0896])
"""
TensorFlow中的Swish函数
在TensorFlow中实现Swish函数的代码如下:
import tensorflow as tf
# 创建 SWISH 激活函数
swish = tf.keras.activations.swish
# 生成随机输入
# x = tf.random.normal([2])
x = [1.3991, -0.1989]
# 应用 SWISH 激活函数
beta = 1
swish_x = swish(beta * x)
print(f"x: \n{x}")
print(f"swish_x:\n{swish_x}")
"""
输出:
x:
[1.3991, -0.1989]
swish_x:
[ 1.1221356 -0.08959218]
"""