问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

神经网络常见激活函数:Swish函数详解

创作时间:
作者:
@小白创作中心

神经网络常见激活函数:Swish函数详解

引用
CSDN
1.
https://blog.csdn.net/hbkybkzw/article/details/145657284

Swish函数是深度学习领域中一种重要的激活函数,它结合了ReLU的稀疏性和Sigmoid的平滑性,具有自适应性和非单调性等特点。本文将详细介绍Swish函数的定义、导数、图像表示、优缺点,并提供在PyTorch和TensorFlow中的实现代码。

Swish函数

函数+导函数

Swish函数的定义如下:

$$
\begin{aligned}
\rm Swish(x) &= x \cdot \sigma(\beta x) \
&= \frac{x}{1 + e^{-\beta x}}
\end{aligned}
$$

其中,$\sigma(\cdot)$是sigmoid函数,$\beta$是一个可学习的参数或固定的超参数,控制着函数的形状。

Swish函数的导数为:

$$
\begin{aligned}
\frac{d}{dx} \rm Swish &= \left(x \cdot\sigma(\beta x) \right)' \
&=\sigma(\beta x) + x\cdot \left(\sigma(\beta x) \right)'\
\quad \
\because &\quad \sigma'(u)= \sigma(u)(1-\sigma(u)) \
\quad \
\therefore &=\sigma(\beta x) + \beta x \cdot \sigma(\beta x) \cdot (1 - \sigma(\beta x)) \
&= \frac{1 + e^{-\beta x} + \beta x e^{-\beta x}}{(1 + e^{-\beta x})^2}
\end{aligned}
$$

函数和导函数图像

以下是Swish函数及其导数的图像表示:

import numpy as np
from matplotlib import pyplot as plt

# 定义 Swish 函数
def swish(x, beta=1.0):
    return x * (1 / (1 + np.exp(-beta * x)))

# 定义 Swish 的导数
def swish_derivative(x, beta=1.0):
    sigmoid = 1 / (1 + np.exp(-beta * x))
    return sigmoid * (1 + x * beta * (1 - sigmoid))

# 生成数据
x = np.linspace(-5, 5, 1000)
beta = 1.0  # 可以调整 beta 的值
y = swish(x, beta)
y1 = swish_derivative(x, beta)

# 绘制图形
plt.figure(figsize=(12, 8))
ax = plt.gca()
plt.plot(x, y, label=f'Swish (β={beta})')
plt.plot(x, y1, label='Derivative')
plt.title(f'Swish (β={beta}) and Derivative')

# 设置上边和右边无边框
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')

# 设置 x 坐标刻度数字或名称的位置
ax.xaxis.set_ticks_position('bottom')

# 设置边框位置
ax.spines['bottom'].set_position(('data', 0))
ax.yaxis.set_ticks_position('left')
ax.spines['left'].set_position(('data', 0))

plt.legend(loc=2)
plt.savefig('./swish.jpg')
plt.show()

优缺点

  • Swish函数的优点

    1. 非单调性:Swish是一个非单调函数,这使得它在负输入时仍然可以保持小梯度,同时在正输入时能够有效激活。
    2. 自适应性:Swish是一个自适应激活函数,能够根据输入动态调整输出,这有助于改善优化和泛化。
    3. 性能优势:在某些深度学习模型中,Swish已经被证明比ReLU及其变体表现更好,尤其是在深层网络中。
    4. 平滑过渡:Swish函数在整个定义域上是平滑的,这有助于避免梯度消失问题,同时提供更稳定的梯度流动。
  • Swish的缺点

    1. 计算复杂度高:与ReLU相比,Swish的计算成本更高,因为它涉及到sigmoid函数的计算。
    2. 训练时间增加:由于计算复杂度的增加,Swish可能会导致训练时间变长。
    3. 适用性有限:虽然Swish在某些任务中表现出色,但并非所有任务都能从中受益。需要根据具体任务和数据集进行实验和调整。
    4. Swish函数是一种较新的激活函数,其非单调性和自适应性使其在某些深度学习模型中表现优异,尤其是在深层网络中。然而,其计算复杂度较高,可能会增加训练时间,因此在实际应用中需要根据具体任务进行权衡和选择。

PyTorch中的Swish函数

在PyTorch中实现Swish函数的代码如下:

import torch
import torch.nn.functional as F

# 定义 Swish 函数
def swish(x, beta=1):
    return x * torch.sigmoid(beta * x)

x = torch.randn(2)  # 生成一个随机张量作为输入
swish_x = swish(x, beta=1)  # 应用 Swish 函数
print(f"x: \n{x}")
print(f"swish_x:\n{swish_x}")

"""
输出:
x: 
tensor([ 1.3991, -0.1989])
swish_x:
tensor([ 1.1221, -0.0896])
"""

TensorFlow中的Swish函数

在TensorFlow中实现Swish函数的代码如下:

import tensorflow as tf

# 创建 SWISH 激活函数
swish = tf.keras.activations.swish

# 生成随机输入
# x = tf.random.normal([2])
x = [1.3991, -0.1989]

# 应用 SWISH 激活函数
beta = 1
swish_x = swish(beta * x)
print(f"x: \n{x}")
print(f"swish_x:\n{swish_x}")

"""
输出:
x: 
[1.3991, -0.1989]
swish_x:
[ 1.1221356  -0.08959218]
"""
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号