From 9036cd24c358c7b65963ebf59e2897b46d70a470 Mon Sep 17 00:00:00 2001
From: rp <rp@outlook.com>
Date: 星期一, 14 四月 2025 00:26:24 +0800
Subject: [PATCH] 备份

---
 testonly.py |  210 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 210 insertions(+), 0 deletions(-)

diff --git a/testonly.py b/testonly.py
new file mode 100644
index 0000000..4527a7f
--- /dev/null
+++ b/testonly.py
@@ -0,0 +1,210 @@
+import sys
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.preprocessing import MinMaxScaler
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+
def _version_tuple(version):
    """Parse a dotted version string into a tuple of ints for comparison.

    Trailing non-numeric suffixes are truncated at the first component
    with no leading digits (e.g. '2.10.0rc1' -> (2, 10, 0)).
    """
    parts = []
    for piece in version.split('.'):
        digits = ''
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


# Check numpy version.
# NOTE: plain string comparison orders versions lexicographically
# ('1.9.0' > '1.23.5'), so compare parsed integer tuples instead.
if _version_tuple(np.__version__) < (1, 23, 5):
    print(f"Warning: Current numpy version {np.__version__} may cause compatibility issues.")
    print("Please upgrade numpy to version 1.23.5 or higher.")

try:
    import tensorflow as tf
    print(f"TensorFlow version: {tf.__version__}")
except ImportError as e:
    # TensorFlow is mandatory for the model below; abort early if missing.
    print("Error importing TensorFlow:", e)
    sys.exit(1)

# Check TensorFlow version (same numeric comparison as above).
if _version_tuple(tf.__version__) < (2, 10, 0):
    print(f"Warning: Current TensorFlow version {tf.__version__} may cause compatibility issues.")
    print("Please upgrade TensorFlow to version 2.10.0 or higher.")

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv1D, LSTM, Dense, Layer
from tensorflow.keras.callbacks import EarlyStopping
+
+# ================================
+# 1. 鏁版嵁棰勫鐞嗘ā鍧�
+# ================================
class DataPreprocessor:
    """Load, resample, scale, and window a univariate time series from CSV."""

    def __init__(self, file_path, target_col='Value', datetime_col='DateTime'):
        # utf-8-sig strips the BOM that Excel-exported CSVs often carry.
        # datetime_col is parameterized (default 'DateTime') for CSVs with a
        # differently named timestamp column; callers of the old signature
        # are unaffected.
        self.df = pd.read_csv(file_path, encoding='utf-8-sig',
                              parse_dates=[datetime_col], index_col=datetime_col)
        self.target_col = target_col
        self.scaler = MinMaxScaler(feature_range=(0, 1))

    def preprocess(self, resample_freq='h'):  # lowercase 'h': uppercase 'H' is deprecated
        """Resample onto a regular grid and min-max scale to [0, 1].

        Returns:
            (scaled_data, dates): scaled values of shape (T, 1) and the
            matching DatetimeIndex, also stored on self.
        """
        value_series = self.df[self.target_col]
        # Mean-aggregate irregular samples onto the grid, then fill gaps in
        # both directions. fillna(method=...) is deprecated since pandas 2.1;
        # ffill()/bfill() are the supported equivalents.
        df_resampled = value_series.resample(resample_freq).mean()
        df_filled = df_resampled.ffill().bfill()

        self.scaled_data = self.scaler.fit_transform(df_filled.values.reshape(-1, 1))
        self.dates = df_filled.index
        return self.scaled_data, self.dates

    def create_sequences(self, data, look_back=72, pred_steps=120):
        """Build supervised learning windows.

        X[i] holds data[i:i+look_back]; Y[i] holds the next pred_steps values.
        Returns empty arrays when len(data) <= look_back + pred_steps.
        """
        X, Y = [], []
        for i in range(len(data) - look_back - pred_steps):
            X.append(data[i:(i + look_back)])
            Y.append(data[(i + look_back):(i + look_back + pred_steps)])
        return np.array(X), np.array(Y)
+
+# ================================
+# 2. 妯″瀷鏋勫缓妯″潡
+# ================================
class TemporalAttention(Layer):
    """Additive (Bahdanau-style) attention over the time axis."""

    def __init__(self, units):
        super(TemporalAttention, self).__init__()
        # Projections for the encoder features (W1), the query (W2),
        # and the scalar score head (V).
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, encoder_output, lstm_output):
        """Return (context_vector, attention_weights).

        encoder_output is a time-major feature sequence; lstm_output is a
        single summary vector used as the attention query.
        """
        # Insert a time axis so the query broadcasts against every step.
        query = tf.expand_dims(lstm_output, 1)
        energy = tf.nn.tanh(self.W1(encoder_output) + self.W2(query))
        weights = tf.nn.softmax(self.V(energy), axis=1)
        # Weighted sum of encoder features collapses the time dimension.
        context = tf.reduce_sum(weights * encoder_output, axis=1)
        return context, weights
+
class SalinityPredictor:
    """Builds, trains, and runs a CNN-LSTM-attention forecaster that maps a
    window of `look_back` hourly values to the next `pred_steps` values."""

    def __init__(self, look_back=72, pred_steps=120):
        self.look_back = look_back    # history window length (hours)
        self.pred_steps = pred_steps  # forecast horizon (5 days = 120 hours)
        
    def build_model(self):
        """Build and compile the CNN-LSTM-Attention hybrid model.

        Returns:
            The compiled tf.keras Model (also stored as self.model).
        """
        inputs = Input(shape=(self.look_back, 1))
        
        # CNN feature extraction along the time axis.
        cnn = Conv1D(64, 3, activation='relu', padding='same')(inputs)
        cnn = Conv1D(32, 3, activation='relu', padding='same')(cnn)
        
        # LSTM temporal modelling; the second LSTM collapses to one vector.
        lstm_out = LSTM(128, return_sequences=True)(cnn)
        lstm_out = LSTM(64, return_sequences=False)(lstm_out)
        
        # Attention: the LSTM summary vector queries the CNN feature sequence.
        context_vector, _ = TemporalAttention(64)(cnn, lstm_out)
        
        # Output head emits all pred_steps values in a single shot.
        outputs = Dense(self.pred_steps)(context_vector)
        
        self.model = Model(inputs=inputs, outputs=outputs)
        self.model.compile(optimizer='adam', loss='mse')
        return self.model
    
    def dynamic_split(self, data, dates, cutoff_date):
        """Return the training slice ending look_back positions before the
        last timestamp <= cutoff_date.

        NOTE(review): the [-self.look_back] index raises IndexError when
        fewer than look_back samples precede the cutoff — confirm callers
        guarantee enough history.
        """
        cutoff_idx = np.where(dates <= cutoff_date)[0][-self.look_back]
        train_data = data[:cutoff_idx]
        return train_data
    
    def train(self, X_train, y_train, epochs=200, batch_size=32):
        """Fit self.model with a 20% validation split and early stopping
        (patience 20 on val_loss). Returns the Keras History object."""
        early_stop = EarlyStopping(monitor='val_loss', patience=20)
        history = self.model.fit(
            X_train, y_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_split=0.2,
            callbacks=[early_stop],
            verbose=1
        )
        return history
    
    def predict(self, last_sequence):
        """Recursive multi-step forecast starting from the latest window.

        last_sequence is expected to be shaped (look_back, 1) — the model's
        input window. Returns a 1-D array of pred_steps scaled predictions.

        NOTE(review): build_model's head outputs all pred_steps values per
        call, yet each iteration keeps only pred[0][0] and rolls the window
        by one step, i.e. pred_steps separate model calls. Confirm a single
        one-shot prediction was not intended.
        """
        predictions = []
        current_seq = last_sequence.copy()
        
        for _ in range(self.pred_steps):
            pred = self.model.predict(current_seq[np.newaxis, :, :], verbose=0)
            predictions.append(pred[0][0])  # keep only the first predicted value
            current_seq = np.roll(current_seq, -1, axis=0)
            current_seq[-1] = pred[0][0]  # feed the prediction back into the window
        
        return np.array(predictions)
+
+# ================================
+# 3. 瀹屾暣娴佺▼鎵ц
+# ================================
+if __name__ == "__main__":
+    # 鍙傛暟閰嶇疆
+    DATA_PATH = 'D:\opencv\.venv\涓�鍙栨按.csv'
+    CUTOFF_DATE = '2024-12-20 00:00'  # 鐢ㄦ埛鎸囧畾鍒嗗壊鏃堕棿鐐�
+    LOOK_BACK = 72    # 3澶╁巻鍙叉暟鎹�
+    PRED_STEPS = 120  # 棰勬祴5澶�
+    
+    # 鏁版嵁棰勫鐞�
+    preprocessor = DataPreprocessor(DATA_PATH)
+    scaled_data, dates = preprocessor.preprocess()
+    X, Y = preprocessor.create_sequences(scaled_data, LOOK_BACK, PRED_STEPS)
+    
+    # 妯″瀷鏋勫缓
+    predictor = SalinityPredictor(LOOK_BACK, PRED_STEPS)
+    model = predictor.build_model()
+    model.summary()
+    
+    # 鍔ㄦ�佽缁�
+    train_data = predictor.dynamic_split(scaled_data, dates, pd.to_datetime(CUTOFF_DATE))
+    X_train, y_train = preprocessor.create_sequences(train_data, LOOK_BACK, PRED_STEPS)
+    history = predictor.train(X_train, y_train)
+    
+    # 棰勬祴楠岃瘉
+    cutoff_idx = np.where(dates <= pd.to_datetime(CUTOFF_DATE))[0][-1]
+    last_seq = scaled_data[cutoff_idx-LOOK_BACK:cutoff_idx]  # 浣跨敤鍒嗗壊鐐瑰墠鐨勬暟鎹綔涓鸿緭鍏�
+    scaled_pred = predictor.predict(last_seq)
+    predictions = preprocessor.scaler.inverse_transform(scaled_pred.reshape(-1, 1))
+    
+    # 缁撴灉鍙鍖�
+    true_dates = pd.date_range(start=pd.to_datetime(CUTOFF_DATE), periods=PRED_STEPS, freq='h')
+    plt.figure(figsize=(15, 6))
+    
+    # 缁樺埗鍒嗗壊鐐瑰墠鐨勫巻鍙叉暟鎹�
+    plt.plot(dates[cutoff_idx-PRED_STEPS:cutoff_idx], 
+             preprocessor.scaler.inverse_transform(scaled_data[cutoff_idx-PRED_STEPS:cutoff_idx]), 
+             'b-', label='鍘嗗彶鏁版嵁锛堝垎鍓茬偣鍓嶏級')
+    
+    # 缁樺埗鍒嗗壊鐐瑰悗鐨勫疄闄呮暟鎹�
+    plt.plot(dates[cutoff_idx:cutoff_idx+PRED_STEPS], 
+             preprocessor.scaler.inverse_transform(scaled_data[cutoff_idx:cutoff_idx+PRED_STEPS]), 
+             'g-', label='瀹為檯鏁版嵁锛堝垎鍓茬偣鍚庯級')
+    
+    # 缁樺埗棰勬祴鏁版嵁
+    plt.plot(true_dates, predictions, 'r--', label='棰勬祴鏁版嵁')
+    
+    # 娣诲姞鍒嗗壊绾�
+    plt.axvline(x=pd.to_datetime(CUTOFF_DATE), color='k', linestyle='--', label='鍒嗗壊鏃堕棿鐐�')
+    
+    plt.title(f'鐩愬害棰勬祴瀵规瘮锛坽CUTOFF_DATE}鍚�5澶╋級')
+    plt.xlabel('鏃堕棿')
+    plt.ylabel('鐩愬害鍊�')
+    plt.legend()
+    plt.grid(True)
+    plt.xticks(rotation=45)
+    plt.tight_layout()
+    plt.show()
+    
+    # 鎬ц兘鎸囨爣
+    true_values = preprocessor.scaler.inverse_transform(scaled_data[cutoff_idx:cutoff_idx+PRED_STEPS])
+    mae = mean_absolute_error(true_values, predictions)
+    rmse = np.sqrt(mean_squared_error(true_values, predictions))
+    print(f'楠岃瘉鎸囨爣 => MAE: {mae:.3f}, RMSE: {rmse:.3f}')
+
+# ================================
+# 4. 妯″瀷淇濆瓨涓庡姞杞斤紙鍙�夛級
+# ================================
+# model.save('salinity_predictor.h5') 
+# loaded_model = tf.keras.models.load_model('salinity_predictor.h5', custom_objects={'TemporalAttention': TemporalAttention})
\ No newline at end of file

--
Gitblit v1.9.3