使用WGAN-GP生成一维滚动轴承振动数据样本
创作时间:
作者:
@小白创作中心
使用WGAN-GP生成一维滚动轴承振动数据样本
引用
CSDN
1.
https://blog.csdn.net/QQ_1309399183/article/details/144684295
本文将介绍如何使用WGAN-GP(Wasserstein GAN with Gradient Penalty)生成一维滚动轴承振动数据样本。我们将以西储大学(CWRU)数据集为例,并提供一个基于训练好的权重参数文件进行测试的代码。
步骤概述
- 数据集准备
- 构建WGAN-GP模型
- 加载预训练权重
- 生成指定故障类型的数据
- 可视化生成的数据
详细步骤
1. 数据集准备
确保你的数据集已经按照上述格式准备好,并且包含相应的文件目录结构。
bearing_datasets/
├── CWRU/
│ ├── normal.mat
│ ├── inner_race_fault.mat
│ └── ...
└── generated_data/
├── normal.npy
├── inner_race_fault.npy
└── ...
2. 构建WGAN-GP模型
使用Keras构建一个简单的WGAN-GP模型。
import os
import numpy as np
import pandas as pd
import scipy.io as sio
from sklearn.preprocessing import StandardScaler, LabelEncoder
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.constraints import Constraint
import tensorflow as tf
# Step 1: Data Preparation
# Ensure your dataset is organized as described above.
# Load and preprocess data
def load_cwru_data(dataset_path):
features = []
labels = []
for filename in os.listdir(dataset_path):
if filename.endswith('.mat'):
data = sio.loadmat(os.path.join(dataset_path, filename))
signal = data[list(data.keys())[-1]].flatten()
label = filename.split('_')[0] # Assuming label is part of the filename
features.append(signal)
labels.append(label)
return np.array(features), np.array(labels)
cwru_features, cwru_labels = load_cwru_data('bearing_datasets/CWRU')
# Encode labels
label_encoder = LabelEncoder()
cwru_labels_encoded = label_encoder.fit_transform(cwru_labels)
# Normalize features
scaler = StandardScaler()
cwru_features_normalized = scaler.fit_transform(cwru_features)
# Reshape features to include time dimension
cwru_features_reshaped = cwru_features_normalized.reshape(-1, cwru_features_normalized.shape[1], 1)
# Step 2: Build WGAN-GP Model
class ClipConstraint(Constraint):
def __init__(self, clip_value):
self.clip_value = clip_value
def __call__(self, weights):
return tf.clip_by_value(weights, -self.clip_value, self.clip_value)
def build_generator(latent_dim, output_shape):
model = Sequential()
model.add(Dense(128, activation='relu', input_dim=latent_dim))
model.add(Dense(256, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(output_shape, activation='tanh'))
return model
def build_discriminator(input_shape):
model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(512, activation='relu', kernel_constraint=ClipConstraint(0.01)))
model.add(Dense(256, activation='relu', kernel_constraint=ClipConstraint(0.01)))
model.add(Dense(1))
return model
def wasserstein_loss(y_true, y_pred):
return tf.reduce_mean(y_true * y_pred)
def gradient_penalty_loss(y_true, y_pred, averaged_samples, weight):
gradients = tf.gradients(y_pred, averaged_samples)[0]
gradients_sqr = tf.square(gradients)
gradient_penalty = tf.reduce_mean(tf.reduce_sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape))))
return weight * gradient_penalty
latent_dim = 100
output_shape = cwru_features_reshaped.shape[1]
generator = build_generator(latent_dim, output_shape)
discriminator = build_discriminator((output_shape, 1))
discriminator.compile(loss=wasserstein_loss, optimizer=Adam(lr=0.0001, beta_1=0.5), metrics=['accuracy'])
discriminator.trainable = False
gan_input = Input(shape=(latent_dim,))
generated_signal = generator(gan_input)
validity = discriminator(generated_signal)
combined = Model(gan_input, validity)
combined.compile(loss=wasserstein_loss, optimizer=Adam(lr=0.0001, beta_1=0.5))
3. 加载预训练权重
假设你已经有了预训练的权重文件 generator_weights.h5 和 discriminator_weights.h5。
# Load pre-trained weights
generator.load_weights('generator_weights.h5')
discriminator.load_weights('discriminator_weights.h5')
4. 生成指定故障类型的数据
生成指定故障类型的数据,并保存到 generated_data 目录中。
# Function to generate data
def generate_data(generator, latent_dim, num_samples, fault_type, label_encoder, output_dir):
noise = np.random.normal(0, 1, (num_samples, latent_dim))
generated_signals = generator.predict(noise)
# Decode labels to get fault type index
fault_index = label_encoder.transform([fault_type])[0]
# Save generated signals
np.save(os.path.join(output_dir, f"{fault_type}.npy"), generated_signals)
return generated_signals
# Generate data for a specific fault type
fault_type = 'inner_race_fault' # Change this to any fault type you want to generate
num_samples = 1000 # Number of samples to generate
generated_signals = generate_data(generator, latent_dim, num_samples, fault_type, label_encoder, 'generated_data')
5. 可视化生成的数据
可视化生成的数据并与真实数据进行对比。
# Plot real and generated data
import matplotlib.pyplot as plt
# Select a random sample from real data
real_sample_idx = np.random.randint(0, len(cwru_features_reshaped))
real_sample = cwru_features_reshaped[real_sample_idx].flatten()
# Select a random sample from generated data
generated_sample_idx = np.random.randint(0, len(generated_signals))
generated_sample = generated_signals[generated_sample_idx].flatten()
# Plot
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(real_sample)
plt.title('Real Sample')
plt.xlabel('Time')
plt.ylabel('Amplitude')
plt.subplot(1, 2, 2)
plt.plot(generated_sample)
plt.title('Generated Sample')
plt.xlabel('Time')
plt.ylabel('Amplitude')
plt.tight_layout()
plt.show()
完整代码
以下是完整的代码示例,包含了从数据加载、模型构建、加载预训练权重、生成数据到结果可视化的所有步骤。
运行脚本
在终端中运行以下命令来执行整个流程:
python main.py
总结
以上文档包含了从数据集准备、模型构建、加载预训练权重、生成数据到结果可视化的所有步骤。希望这些详细的信息和代码能够帮助你顺利实施和优化你的滚动轴承故障诊断系统。
热门推荐
低温环境下柴油发电机组如何防止冷启动困难
一公顷与平方公里的关系:面积单位转换解析
C语言中计算矩阵对角线元素和的完整指南
春日亲子运动安排
国防科大发布首个SAR图像目标识别基础模型SARATR-X 1.0
MySQL 8教程:详解如何更改和重置用户密码
分宜县十大旅游景点
玉米田里的“害草”地锦草:既是止血良方,又是通乳妙药
养生必备:藏红花泡水秘籍,用量水温巧掌控,正确吃法轻松懂!
国足惨遭两连败,世预赛出线渺茫——背后的原因和未来展望
DeepSeek对律师行业的影响
以案释法:一起邻里纠纷引发的命案如何妥善化解
哪些因素会影响班级考核制度的效果?
银行理财产品年化收益率超5%,甚至10%,怎么做到的?能买吗?
探寻科学与哲学的交集:拓扑量子场论的深刻启示
美丽or伤害?有没有染发而不伤发的办法?
1996年NBA选秀:科比争夺战背后的智慧与勇气
泪道阻塞怎么办?
新学期当班主任,有点慌?五个策略助你底气十足
确保锂电在数据中心中安全应用的探讨
石正丽团队新发现:两种冠状病毒通过ACE2受体入侵宿主
探视行政拘留人员的手续及注意事项
软件开发生命周期阶段:从需求分析到维护的详细指南
《西游记》中的“花果山”就在连云港!
摩尔定律再进化,2纳米之后芯片如何继续突破物理极限
怎么解析一段js代码
用中医智慧调理情志:恬淡虚无,精神内守
东京出发!静冈东海道沿线景点全攻略
探讨"我行我素"个性在现代社会中的适应与挑战。
水结冰后为何体积变大?