rp
2025-04-14 9036cd24c358c7b65963ebf59e2897b46d70a470
备份
已添加1个文件
210 ■■■■■ 文件已修改
testonly.py 210 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
testonly.py
对比新文件
@@ -0,0 +1,210 @@
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error
# Version strings must be compared numerically: as plain strings,
# '1.9.0' sorts after '1.23.5' and '2.9.0' sorts after '2.10.0', so the
# warnings below would never fire for exactly the versions they target.
import re

_VERSION_RE = re.compile(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?')


def _version_tuple(version):
    """Return the leading ``major.minor.patch`` of *version* as a tuple of ints.

    Missing components count as 0 and any suffix (``rc1``, ``.dev0``,
    ``+cpu`` ...) is ignored. A string that does not start with a digit
    yields ``(0, 0, 0)``, which makes the checks below err on the side of
    warning.
    """
    match = _VERSION_RE.match(version)
    if match is None:
        return (0, 0, 0)
    return tuple(int(group or 0) for group in match.groups())


# Check numpy version
if _version_tuple(np.__version__) < (1, 23, 5):
    print(f"Warning: Current numpy version {np.__version__} may cause compatibility issues.")
    print("Please upgrade numpy to version 1.23.5 or higher.")
# TensorFlow is mandatory: abort with exit status 1 when it cannot be imported.
try:
    import tensorflow as tf
    print(f"TensorFlow version: {tf.__version__}")
except ImportError as e:
    print("Error importing TensorFlow:", e)
    sys.exit(1)
# Check TensorFlow version
if _version_tuple(tf.__version__) < (2, 10, 0):
    print(f"Warning: Current TensorFlow version {tf.__version__} may cause compatibility issues.")
    print("Please upgrade TensorFlow to version 2.10.0 or higher.")
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv1D, LSTM, Dense, Layer
from tensorflow.keras.callbacks import EarlyStopping
# ================================
# 1. Data preprocessing module
# ================================
class DataPreprocessor:
    """Load a timestamped CSV and turn one column into scaled training windows.

    Attributes:
        df: DataFrame read from the CSV, indexed by the parsed datetime column.
        target_col: Name of the column that is resampled and scaled.
        scaler: ``MinMaxScaler`` with range (0, 1); fitted by ``preprocess``
            and needed afterwards to undo the scaling.
        scaled_data: Scaled values from the last ``preprocess`` call
            (``None`` before the first call).
        dates: Index matching ``scaled_data`` (``None`` before the first call).
    """

    def __init__(self, file_path, target_col='Value', datetime_col='DateTime'):
        """Read the CSV at *file_path*.

        Args:
            file_path: Path of the CSV file.
            target_col: Column holding the values to model.
            datetime_col: Column parsed as timestamps and used as the index.
        """
        # 'utf-8-sig' also accepts files that start with a UTF-8 BOM.
        self.df = pd.read_csv(
            file_path,
            encoding='utf-8-sig',
            parse_dates=[datetime_col],
            index_col=datetime_col,
        )
        self.target_col = target_col
        self.scaler = MinMaxScaler(feature_range=(0, 1))
        self.scaled_data = None
        self.dates = None

    def preprocess(self, resample_freq='h'):
        """Resample the target column to a regular grid and scale it to [0, 1].

        Args:
            resample_freq: pandas offset alias of the resampling grid. The
                default is the lowercase ``'h'`` rather than the uppercase
                ``'H'``.

        Returns:
            Tuple ``(scaled_data, dates)``: the scaled values as an array of
            shape ``(n, 1)`` and the index of the resampled series.
        """
        # Only the target column is processed.
        value_series = self.df[self.target_col]
        # Averaging per bin puts irregularly sampled readings on a regular grid.
        df_resampled = value_series.resample(resample_freq).mean()
        # Fill gaps forward, then backward to cover a gap at the very start.
        # ffill()/bfill() replace the deprecated fillna(method=...).
        df_filled = df_resampled.ffill().bfill()
        # Normalisation; the fitted scaler is kept for inverse_transform later.
        self.scaled_data = self.scaler.fit_transform(df_filled.values.reshape(-1, 1))
        self.dates = df_filled.index
        return self.scaled_data, self.dates

    def create_sequences(self, data, look_back=72, pred_steps=120):
        """Slice *data* into (input window, target window) pairs.

        Args:
            data: Array indexed along its first axis by time step.
            look_back: Number of steps in each input window.
            pred_steps: Number of steps in each target window.

        Returns:
            Tuple ``(X, Y)`` of arrays where ``X[i]`` is
            ``data[i:i + look_back]`` and ``Y[i]`` holds the ``pred_steps``
            values that follow it. Both are empty when *data* has fewer than
            ``look_back + pred_steps`` steps.
        """
        # +1 so that the window whose target ends on the final sample is kept.
        n_windows = len(data) - look_back - pred_steps + 1
        starts = range(max(n_windows, 0))
        X = [data[i:i + look_back] for i in starts]
        Y = [data[i + look_back:i + look_back + pred_steps] for i in starts]
        return np.array(X), np.array(Y)
# ================================
# 2. Model building module
# ================================
class TemporalAttention(Layer):
    """Temporal attention layer (additive scoring over the time axis).

    Each time step of the encoder output is scored with
    ``V(tanh(W1(encoder_output) + W2(lstm_output)))``; the scores are
    softmax-normalised over axis 1 and used to form a weighted sum of the
    encoder output.
    """

    def __init__(self, units, **kwargs):
        """Create the three projection layers.

        Args:
            units: Width of the ``W1`` and ``W2`` projections.
            **kwargs: Standard Keras ``Layer`` arguments (``name``,
                ``dtype``, ...). Accepting them is required for the layer
                to be rebuilt from its saved config.
        """
        super().__init__(**kwargs)
        self.units = units
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, encoder_output, lstm_output):
        """Weight the encoder time steps by their relevance to *lstm_output*.

        Args:
            encoder_output: Sequence tensor whose axis 1 is treated as time
                (presumably ``(batch, time, features)`` - confirm against
                the caller).
            lstm_output: Summary tensor without a time axis (presumably
                ``(batch, features)``).

        Returns:
            Tuple ``(context_vector, attention_weights)``: the weighted sum
            of ``encoder_output`` over axis 1, and the softmax weights.
        """
        # Insert a length-1 axis at position 1 so the summary broadcasts
        # against every time step of the encoder output.
        lstm_output = tf.expand_dims(lstm_output, 1)
        score = self.V(tf.nn.tanh(
            self.W1(encoder_output) + self.W2(lstm_output)))
        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = attention_weights * encoder_output
        context_vector = tf.reduce_sum(context_vector, axis=1)
        return context_vector, attention_weights

    def get_config(self):
        """Return the layer config including ``units``.

        Without this entry a saved model cannot re-create the layer, because
        ``units`` is a required constructor argument.
        """
        config = super().get_config()
        config.update({'units': self.units})
        return config
class SalinityPredictor:
    """Build, train and run the CNN-LSTM-attention forecasting model.

    Attributes:
        look_back: Length of the input window (hours).
        pred_steps: Forecast horizon (5 days = 120 hours by default).
        model: Compiled Keras model, or ``None`` until ``build_model`` runs.
    """

    def __init__(self, look_back=72, pred_steps=120):
        """Store the window sizes; the model itself is built separately."""
        self.look_back = look_back    # history window (hours)
        self.pred_steps = pred_steps  # forecast horizon (5 days = 120 hours)
        self.model = None             # set by build_model()

    def build_model(self):
        """Build and compile the hybrid CNN-LSTM-attention model.

        Returns:
            The compiled model (also stored on ``self.model``). It takes
            inputs of shape ``(look_back, 1)`` and emits ``pred_steps``
            values per sample.
        """
        inputs = Input(shape=(self.look_back, 1))
        # CNN feature extraction; padding='same' keeps the sequence length.
        cnn = Conv1D(64, 3, activation='relu', padding='same')(inputs)
        cnn = Conv1D(32, 3, activation='relu', padding='same')(cnn)
        # LSTM temporal modelling; the second LSTM returns a single vector.
        lstm_out = LSTM(128, return_sequences=True)(cnn)
        lstm_out = LSTM(64, return_sequences=False)(lstm_out)
        # Attention over the CNN feature sequence, queried by the LSTM summary.
        context_vector, _ = TemporalAttention(64)(cnn, lstm_out)
        # Output layer: one unit per forecast step.
        outputs = Dense(self.pred_steps)(context_vector)
        self.model = Model(inputs=inputs, outputs=outputs)
        self.model.compile(optimizer='adam', loss='mse')
        return self.model

    def dynamic_split(self, data, dates, cutoff_date):
        """Return the part of *data* used for training, given a cutoff.

        The split index is the ``look_back``-th from last position whose
        date is at or before *cutoff_date*; everything before that index is
        returned.

        Args:
            data: Array aligned with *dates* along its first axis.
            dates: Array-like of timestamps comparable with *cutoff_date*.
            cutoff_date: Last timestamp considered part of the history.

        Returns:
            ``data[:cutoff_idx]``.

        Raises:
            IndexError: Fewer than ``look_back`` samples lie at or before
                *cutoff_date*.
        """
        eligible = np.where(dates <= cutoff_date)[0]
        if len(eligible) < self.look_back:
            # Same exception type as the bare negative index would raise,
            # but with a message that names the actual problem.
            raise IndexError(
                f"only {len(eligible)} samples at or before {cutoff_date}; "
                f"at least look_back={self.look_back} are required"
            )
        cutoff_idx = eligible[-self.look_back]
        return data[:cutoff_idx]

    def train(self, X_train, y_train, epochs=200, batch_size=32):
        """Fit the model with early stopping on a 20 % validation split.

        Args:
            X_train: Input windows.
            y_train: Target windows.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size.

        Returns:
            The Keras ``History`` object returned by ``fit``.
        """
        # restore_best_weights: without it the model keeps the weights from
        # `patience` epochs after the best validation loss.
        early_stop = EarlyStopping(
            monitor='val_loss',
            patience=20,
            restore_best_weights=True,
        )
        history = self.model.fit(
            X_train, y_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_split=0.2,
            callbacks=[early_stop],
            verbose=1
        )
        return history

    def predict(self, last_sequence):
        """Recursive multi-step forecast.

        Each iteration runs the model on the current window, keeps only the
        first of its outputs, and slides the window forward by that value.

        Args:
            last_sequence: Array holding the most recent input window; it
                is copied, not modified.

        Returns:
            Array with ``pred_steps`` predicted values (still in the scaled
            domain of the training data).
        """
        predictions = []
        current_seq = last_sequence.copy()
        for _ in range(self.pred_steps):
            pred = self.model.predict(current_seq[np.newaxis, :, :], verbose=0)
            next_value = pred[0][0]  # only the first predicted step is used
            predictions.append(next_value)
            # Shift the window left by one and put the new value at the end.
            current_seq = np.roll(current_seq, -1, axis=0)
            current_seq[-1] = next_value
        return np.array(predictions)
# ================================
# 3. Full pipeline execution
# ================================
if __name__ == "__main__":
    # End-to-end run: load and scale the data, train on the history before
    # the cutoff, forecast PRED_STEPS hours, then plot and score the result.

    # Parameter configuration
    # Raw string: '\o' and '\.' are not valid escape sequences and trigger a
    # warning in an ordinary literal; the resulting path is unchanged.
    DATA_PATH = r'D:\opencv\.venv\一取水.csv'
    CUTOFF_DATE = '2024-12-20 00:00'  # user-specified split time
    LOOK_BACK = 72    # 3 days of history
    PRED_STEPS = 120  # forecast 5 days
    cutoff_ts = pd.to_datetime(CUTOFF_DATE)

    # Data preprocessing
    # NOTE(review): preprocess() fits the scaler on the whole series, so the
    # scaling range also reflects values after the cutoff.
    preprocessor = DataPreprocessor(DATA_PATH)
    scaled_data, dates = preprocessor.preprocess()

    # Model construction
    predictor = SalinityPredictor(LOOK_BACK, PRED_STEPS)
    model = predictor.build_model()
    model.summary()

    # Training on the history selected by dynamic_split
    train_data = predictor.dynamic_split(scaled_data, dates, cutoff_ts)
    X_train, y_train = preprocessor.create_sequences(train_data, LOOK_BACK, PRED_STEPS)
    history = predictor.train(X_train, y_train)

    # Forecast and validation
    cutoff_idx = np.where(dates <= cutoff_ts)[0][-1]
    # The LOOK_BACK samples right before the split point are the model input.
    last_seq = scaled_data[cutoff_idx-LOOK_BACK:cutoff_idx]
    scaled_pred = predictor.predict(last_seq)
    predictions = preprocessor.scaler.inverse_transform(scaled_pred.reshape(-1, 1))

    # Observations after the split point; fewer than PRED_STEPS may exist.
    actual_dates = dates[cutoff_idx:cutoff_idx+PRED_STEPS]
    true_values = preprocessor.scaler.inverse_transform(
        scaled_data[cutoff_idx:cutoff_idx+PRED_STEPS])
    # Clamp at 0: a negative slice start would count from the end of the array.
    hist_start = max(cutoff_idx - PRED_STEPS, 0)

    # Result visualisation
    true_dates = pd.date_range(start=cutoff_ts, periods=PRED_STEPS, freq='h')
    plt.figure(figsize=(15, 6))
    # History before the split point
    plt.plot(dates[hist_start:cutoff_idx],
             preprocessor.scaler.inverse_transform(scaled_data[hist_start:cutoff_idx]),
             'b-', label='历史数据(分割点前)')
    # Actual data after the split point
    plt.plot(actual_dates, true_values,
             'g-', label='实际数据(分割点后)')
    # Forecast
    plt.plot(true_dates, predictions, 'r--', label='预测数据')
    # Split marker
    plt.axvline(x=cutoff_ts, color='k', linestyle='--', label='分割时间点')
    plt.title(f'盐度预测对比({CUTOFF_DATE}后5天)')
    plt.xlabel('时间')
    plt.ylabel('盐度值')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    # Performance metrics
    # Score only the overlap so both metric inputs have the same length even
    # when the series ends less than PRED_STEPS hours after the cutoff.
    n_eval = len(true_values)
    mae = mean_absolute_error(true_values, predictions[:n_eval])
    rmse = np.sqrt(mean_squared_error(true_values, predictions[:n_eval]))
    print(f'验证指标 => MAE: {mae:.3f}, RMSE: {rmse:.3f}')
# ================================
# 4. Model saving and loading (optional)
# ================================
# model.save('salinity_predictor.h5')
# loaded_model = tf.keras.models.load_model('salinity_predictor.h5', custom_objects={'TemporalAttention': TemporalAttention})