使用WGAN-GP生成一维滚动轴承振动数据样本
创作时间:
作者:
@小白创作中心
使用WGAN-GP生成一维滚动轴承振动数据样本
引用
CSDN
1.
https://blog.csdn.net/QQ_1309399183/article/details/144684295
本文将介绍如何使用WGAN-GP(Wasserstein GAN with Gradient Penalty)生成一维滚动轴承振动数据样本。我们将以西储大学(CWRU)数据集为例,并提供一个基于训练好的权重参数文件进行测试的代码。
步骤概述
- 数据集准备
- 构建WGAN-GP模型
- 加载预训练权重
- 生成指定故障类型的数据
- 可视化生成的数据
详细步骤
1. 数据集准备
确保你的数据集已经按照上述格式准备好,并且包含相应的文件目录结构。
bearing_datasets/
├── CWRU/
│ ├── normal.mat
│ ├── inner_race_fault.mat
│ └── ...
└── generated_data/
├── normal.npy
├── inner_race_fault.npy
└── ...
2. 构建WGAN-GP模型
使用Keras构建一个简单的WGAN-GP模型。
import os
import numpy as np
import pandas as pd
import scipy.io as sio
from sklearn.preprocessing import StandardScaler, LabelEncoder
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.constraints import Constraint
import tensorflow as tf
# Step 1: Data Preparation
# Ensure your dataset is organized as described above.
# Load and preprocess data
def load_cwru_data(dataset_path):
features = []
labels = []
for filename in os.listdir(dataset_path):
if filename.endswith('.mat'):
data = sio.loadmat(os.path.join(dataset_path, filename))
signal = data[list(data.keys())[-1]].flatten()
label = filename.split('_')[0] # Assuming label is part of the filename
features.append(signal)
labels.append(label)
return np.array(features), np.array(labels)
cwru_features, cwru_labels = load_cwru_data('bearing_datasets/CWRU')
# Encode labels
label_encoder = LabelEncoder()
cwru_labels_encoded = label_encoder.fit_transform(cwru_labels)
# Normalize features
scaler = StandardScaler()
cwru_features_normalized = scaler.fit_transform(cwru_features)
# Reshape features to include time dimension
cwru_features_reshaped = cwru_features_normalized.reshape(-1, cwru_features_normalized.shape[1], 1)
# Step 2: Build WGAN-GP Model
class ClipConstraint(Constraint):
def __init__(self, clip_value):
self.clip_value = clip_value
def __call__(self, weights):
return tf.clip_by_value(weights, -self.clip_value, self.clip_value)
def build_generator(latent_dim, output_shape):
model = Sequential()
model.add(Dense(128, activation='relu', input_dim=latent_dim))
model.add(Dense(256, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(output_shape, activation='tanh'))
return model
def build_discriminator(input_shape):
model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(512, activation='relu', kernel_constraint=ClipConstraint(0.01)))
model.add(Dense(256, activation='relu', kernel_constraint=ClipConstraint(0.01)))
model.add(Dense(1))
return model
def wasserstein_loss(y_true, y_pred):
return tf.reduce_mean(y_true * y_pred)
def gradient_penalty_loss(y_true, y_pred, averaged_samples, weight):
gradients = tf.gradients(y_pred, averaged_samples)[0]
gradients_sqr = tf.square(gradients)
gradient_penalty = tf.reduce_mean(tf.reduce_sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape))))
return weight * gradient_penalty
latent_dim = 100
output_shape = cwru_features_reshaped.shape[1]
generator = build_generator(latent_dim, output_shape)
discriminator = build_discriminator((output_shape, 1))
discriminator.compile(loss=wasserstein_loss, optimizer=Adam(lr=0.0001, beta_1=0.5), metrics=['accuracy'])
discriminator.trainable = False
gan_input = Input(shape=(latent_dim,))
generated_signal = generator(gan_input)
validity = discriminator(generated_signal)
combined = Model(gan_input, validity)
combined.compile(loss=wasserstein_loss, optimizer=Adam(lr=0.0001, beta_1=0.5))
3. 加载预训练权重
假设你已经有了预训练的权重文件 generator_weights.h5
和 discriminator_weights.h5
。
# Load pre-trained weights
generator.load_weights('generator_weights.h5')
discriminator.load_weights('discriminator_weights.h5')
4. 生成指定故障类型的数据
生成指定故障类型的数据,并保存到 generated_data
目录中。
# Function to generate data
def generate_data(generator, latent_dim, num_samples, fault_type, label_encoder, output_dir):
noise = np.random.normal(0, 1, (num_samples, latent_dim))
generated_signals = generator.predict(noise)
# Decode labels to get fault type index
fault_index = label_encoder.transform([fault_type])[0]
# Save generated signals
np.save(os.path.join(output_dir, f"{fault_type}.npy"), generated_signals)
return generated_signals
# Generate data for a specific fault type
fault_type = 'inner_race_fault' # Change this to any fault type you want to generate
num_samples = 1000 # Number of samples to generate
generated_signals = generate_data(generator, latent_dim, num_samples, fault_type, label_encoder, 'generated_data')
5. 可视化生成的数据
可视化生成的数据并与真实数据进行对比。
# Plot real and generated data
import matplotlib.pyplot as plt
# Select a random sample from real data
real_sample_idx = np.random.randint(0, len(cwru_features_reshaped))
real_sample = cwru_features_reshaped[real_sample_idx].flatten()
# Select a random sample from generated data
generated_sample_idx = np.random.randint(0, len(generated_signals))
generated_sample = generated_signals[generated_sample_idx].flatten()
# Plot
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(real_sample)
plt.title('Real Sample')
plt.xlabel('Time')
plt.ylabel('Amplitude')
plt.subplot(1, 2, 2)
plt.plot(generated_sample)
plt.title('Generated Sample')
plt.xlabel('Time')
plt.ylabel('Amplitude')
plt.tight_layout()
plt.show()
完整代码
以下是完整的代码示例,包含了从数据加载、模型构建、加载预训练权重、生成数据到结果可视化的所有步骤。
运行脚本
在终端中运行以下命令来执行整个流程:
python main.py
总结
以上文档包含了从数据集准备、模型构建、加载预训练权重、生成数据到结果可视化的所有步骤。希望这些详细的信息和代码能够帮助你顺利实施和优化你的滚动轴承故障诊断系统。
热门推荐
揭秘硼砂:非法食品添加剂的使用现状与识别方法
2025年6G技术路线图:四大关键领域取得重要进展
养狗全方位提升身心健康,还能促进社交与培养责任感
中国启动孤独症辅助犬服务,首批三只完成家庭匹配
北海市:海上丝绸之路起点的南珠传奇
北海市的气候秘密:四季旅游全攻略
极端环境下的生存密码:生物适应机制与农业育种新思路
国有独资公司薪资发放流程与规范
姜黄肉桂巧搭配,抗炎代谢双管齐下
国企薪酬制度改革方案怎么实施?
北大医药获批新型抗抑郁药,肠溶缓释技术提升疗效安全性
《道德经》四智慧破解职场压力:从无为到知足的东方解决方案
超越表象,顺应自然:《道德经》第四十一章的现代启示
《道德经》现代实践:从个人修道到校园教育
武隆天生三桥:三座天桥各具特色,电影取景地里的东方瑞士
重庆仙女山:东方瑞士的四季美景与多元玩法全攻略
汕尾的“茶”文化:从开灯日到咸茶
汕尾美食打卡:金町湾必吃清单
汕尾二马路:美食与海浴的完美邂逅
汕尾必打卡人气餐厅大揭秘:友谊大酒楼、明叔捞面、宝木肠粉
冯家江治理:北海生态焕发新生的绿色答卷
涠洲岛:北海旅游必打卡的自然奇观
北海:国家历史文化名城的魅力
北海旅游正当时,你还在等啥?
东方市水果批发站:严把食品安全关,守护百姓“果盘子”
二手车轮胎寿命评估攻略:从磨损到生产日期,全面检查要点详解
多汗症患者的饮食调理指南
电影《第二十条》引爆正当防卫热议
夏季多汗症高发,如何科学防汗?
多汗症与焦虑情绪的“致命”关联揭秘