From 9036cd24c358c7b65963ebf59e2897b46d70a470 Mon Sep 17 00:00:00 2001 From: rp <rp@outlook.com> Date: 星期一, 14 四月 2025 00:26:24 +0800 Subject: [PATCH] 备份 --- testonly.py | 210 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 files changed, 210 insertions(+), 0 deletions(-) diff --git a/testonly.py b/testonly.py new file mode 100644 index 0000000..4527a7f --- /dev/null +++ b/testonly.py @@ -0,0 +1,210 @@ +import sys +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from sklearn.preprocessing import MinMaxScaler +from sklearn.metrics import mean_absolute_error, mean_squared_error + +# Check numpy version +if np.__version__ < '1.23.5': + print(f"Warning: Current numpy version {np.__version__} may cause compatibility issues.") + print("Please upgrade numpy to version 1.23.5 or higher.") + +try: + import tensorflow as tf + print(f"TensorFlow version: {tf.__version__}") +except ImportError as e: + print("Error importing TensorFlow:", e) + sys.exit(1) + +# Check TensorFlow version +if tf.__version__ < '2.10.0': + print(f"Warning: Current TensorFlow version {tf.__version__} may cause compatibility issues.") + print("Please upgrade TensorFlow to version 2.10.0 or higher.") + +from tensorflow.keras.models import Model +from tensorflow.keras.layers import Input, Conv1D, LSTM, Dense, Layer +from tensorflow.keras.callbacks import EarlyStopping + +# ================================ +# 1. 
鏁版嵁棰勫鐞嗘ā鍧� +# ================================ +class DataPreprocessor: + def __init__(self, file_path, target_col='Value'): + # 浣跨敤姝g‘鐨勭紪鐮佸拰鍒楀悕 + self.df = pd.read_csv(file_path, encoding='utf-8-sig', parse_dates=['DateTime'], index_col='DateTime') + self.target_col = target_col + self.scaler = MinMaxScaler(feature_range=(0, 1)) + + def preprocess(self, resample_freq='h'): # 浣跨敤灏忓啓鐨�'h'浠f浛澶у啓鐨�'H' + """鏁版嵁閲嶉噰鏍蜂笌褰掍竴鍖�""" + # 鍙�夋嫨Value鍒楄繘琛屽鐞� + value_series = self.df[self.target_col] + + # 澶勭悊闈炵瓑闂撮殧閲囨牱 + df_resampled = value_series.resample(resample_freq).mean() + df_filled = df_resampled.fillna(method='ffill').fillna(method='bfill') # 鍙屽悜濉厖 + + # 褰掍竴鍖栧鐞� + self.scaled_data = self.scaler.fit_transform(df_filled.values.reshape(-1, 1)) + self.dates = df_filled.index + return self.scaled_data, self.dates + + def create_sequences(self, data, look_back=72, pred_steps=120): + """鍒涘缓鐩戠潱瀛︿範搴忓垪""" + X, Y = [], [] + for i in range(len(data) - look_back - pred_steps): + X.append(data[i:(i + look_back)]) + Y.append(data[(i + look_back):(i + look_back + pred_steps)]) + return np.array(X), np.array(Y) + +# ================================ +# 2. 
妯″瀷鏋勫缓妯″潡 +# ================================ +class TemporalAttention(Layer): + """鏃堕棿娉ㄦ剰鍔涙満鍒跺眰""" + def __init__(self, units): + super(TemporalAttention, self).__init__() + self.W1 = Dense(units) + self.W2 = Dense(units) + self.V = Dense(1) + + def call(self, encoder_output, lstm_output): + lstm_output = tf.expand_dims(lstm_output, 1) + score = self.V(tf.nn.tanh( + self.W1(encoder_output) + self.W2(lstm_output))) + attention_weights = tf.nn.softmax(score, axis=1) + context_vector = attention_weights * encoder_output + context_vector = tf.reduce_sum(context_vector, axis=1) + return context_vector, attention_weights + +class SalinityPredictor: + def __init__(self, look_back=72, pred_steps=120): + self.look_back = look_back # 鍘嗗彶绐楀彛锛堝皬鏃讹級 + self.pred_steps = pred_steps # 棰勬祴姝ラ暱锛�5澶�=120灏忔椂锛� + + def build_model(self): + """鏋勫缓CNN-LSTM-Attention娣峰悎妯″瀷""" + inputs = Input(shape=(self.look_back, 1)) + + # CNN鐗瑰緛鎻愬彇 + cnn = Conv1D(64, 3, activation='relu', padding='same')(inputs) + cnn = Conv1D(32, 3, activation='relu', padding='same')(cnn) + + # LSTM鏃跺簭寤烘ā + lstm_out = LSTM(128, return_sequences=True)(cnn) + lstm_out = LSTM(64, return_sequences=False)(lstm_out) + + # 娉ㄦ剰鍔涙満鍒� + context_vector, _ = TemporalAttention(64)(cnn, lstm_out) + + # 杈撳嚭灞� + outputs = Dense(self.pred_steps)(context_vector) + + self.model = Model(inputs=inputs, outputs=outputs) + self.model.compile(optimizer='adam', loss='mse') + return self.model + + def dynamic_split(self, data, dates, cutoff_date): + """鍔ㄦ�佸垝鍒嗚缁冮泦""" + cutoff_idx = np.where(dates <= cutoff_date)[0][-self.look_back] + train_data = data[:cutoff_idx] + return train_data + + def train(self, X_train, y_train, epochs=200, batch_size=32): + """妯″瀷璁粌""" + early_stop = EarlyStopping(monitor='val_loss', patience=20) + history = self.model.fit( + X_train, y_train, + epochs=epochs, + batch_size=batch_size, + validation_split=0.2, + callbacks=[early_stop], + verbose=1 + ) + return history + + def predict(self, last_sequence): + 
"""閫掑綊澶氭棰勬祴""" + predictions = [] + current_seq = last_sequence.copy() + + for _ in range(self.pred_steps): + pred = self.model.predict(current_seq[np.newaxis, :, :], verbose=0) + predictions.append(pred[0][0]) # 鍙彇绗竴涓娴嬪�� + current_seq = np.roll(current_seq, -1, axis=0) + current_seq[-1] = pred[0][0] # 浣跨敤鍗曚釜棰勬祴鍊兼洿鏂板簭鍒� + + return np.array(predictions) + +# ================================ +# 3. 瀹屾暣娴佺▼鎵ц +# ================================ +if __name__ == "__main__": + # 鍙傛暟閰嶇疆 + DATA_PATH = 'D:\opencv\.venv\涓�鍙栨按.csv' + CUTOFF_DATE = '2024-12-20 00:00' # 鐢ㄦ埛鎸囧畾鍒嗗壊鏃堕棿鐐� + LOOK_BACK = 72 # 3澶╁巻鍙叉暟鎹� + PRED_STEPS = 120 # 棰勬祴5澶� + + # 鏁版嵁棰勫鐞� + preprocessor = DataPreprocessor(DATA_PATH) + scaled_data, dates = preprocessor.preprocess() + X, Y = preprocessor.create_sequences(scaled_data, LOOK_BACK, PRED_STEPS) + + # 妯″瀷鏋勫缓 + predictor = SalinityPredictor(LOOK_BACK, PRED_STEPS) + model = predictor.build_model() + model.summary() + + # 鍔ㄦ�佽缁� + train_data = predictor.dynamic_split(scaled_data, dates, pd.to_datetime(CUTOFF_DATE)) + X_train, y_train = preprocessor.create_sequences(train_data, LOOK_BACK, PRED_STEPS) + history = predictor.train(X_train, y_train) + + # 棰勬祴楠岃瘉 + cutoff_idx = np.where(dates <= pd.to_datetime(CUTOFF_DATE))[0][-1] + last_seq = scaled_data[cutoff_idx-LOOK_BACK:cutoff_idx] # 浣跨敤鍒嗗壊鐐瑰墠鐨勬暟鎹綔涓鸿緭鍏� + scaled_pred = predictor.predict(last_seq) + predictions = preprocessor.scaler.inverse_transform(scaled_pred.reshape(-1, 1)) + + # 缁撴灉鍙鍖� + true_dates = pd.date_range(start=pd.to_datetime(CUTOFF_DATE), periods=PRED_STEPS, freq='h') + plt.figure(figsize=(15, 6)) + + # 缁樺埗鍒嗗壊鐐瑰墠鐨勫巻鍙叉暟鎹� + plt.plot(dates[cutoff_idx-PRED_STEPS:cutoff_idx], + preprocessor.scaler.inverse_transform(scaled_data[cutoff_idx-PRED_STEPS:cutoff_idx]), + 'b-', label='鍘嗗彶鏁版嵁锛堝垎鍓茬偣鍓嶏級') + + # 缁樺埗鍒嗗壊鐐瑰悗鐨勫疄闄呮暟鎹� + plt.plot(dates[cutoff_idx:cutoff_idx+PRED_STEPS], + preprocessor.scaler.inverse_transform(scaled_data[cutoff_idx:cutoff_idx+PRED_STEPS]), + 'g-', label='瀹為檯鏁版嵁锛堝垎鍓茬偣鍚庯級') + + # 
Plot the predicted series
    plt.plot(true_dates, predictions, 'r--', label='棰勬祴鏁版嵁')

    # Add the vertical line marking the train/forecast split
    plt.axvline(x=pd.to_datetime(CUTOFF_DATE), color='k', linestyle='--', label='鍒嗗壊鏃堕棿鐐�')

    plt.title(f'鐩愬害棰勬祴瀵规瘮锛坽CUTOFF_DATE}鍚�5澶╋級')
    plt.xlabel('鏃堕棿')
    plt.ylabel('鐩愬害鍊�')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

    # Performance metrics over the forecast horizon (inverse-scaled values)
    true_values = preprocessor.scaler.inverse_transform(scaled_data[cutoff_idx:cutoff_idx+PRED_STEPS])
    mae = mean_absolute_error(true_values, predictions)
    rmse = np.sqrt(mean_squared_error(true_values, predictions))
    print(f'楠岃瘉鎸囨爣 => MAE: {mae:.3f}, RMSE: {rmse:.3f}')

# ================================
# 4. Model save / load (optional)
# ================================
# model.save('salinity_predictor.h5')
# loaded_model = tf.keras.models.load_model('salinity_predictor.h5', custom_objects={'TemporalAttention': TemporalAttention})
\ No newline at end of file
-- Gitblit v1.9.3