1276eb361055c86d2107545e66ab78438fca54ba..137b119f6dd371f61b9a528a8bdf77cb32ef519d
2025-04-08 rp
模型更换
137b11 对比 | 目录
已添加1个文件
825 ■■■■■ 文件已修改
yd_lstm_test.py 825 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
yd_lstm_test.py
¶Ô±ÈÐÂÎļþ
@@ -0,0 +1,825 @@
import pandas as pd
import numpy as np
import tkinter as tk
from tkinter import ttk
from datetime import timedelta
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model, save_model
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
import pickle
import os
from time import time
import matplotlib
# é…ç½®matplotlib中文显示
# è®¾ç½®ä¸­æ–‡å­—体,使用系统提供的字体
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'SimSun', 'Arial Unicode MS']  # ä¼˜å…ˆä½¿ç”¨çš„中文字体
matplotlib.rcParams['axes.unicode_minus'] = False  # è§£å†³è´Ÿå·æ˜¾ç¤ºé—®é¢˜
matplotlib.rcParams['font.family'] = 'sans-serif'  # ä½¿ç”¨æ— è¡¬çº¿å­—体
# ç¼“存变量
cached_model = None
last_training_time = None
feature_columns = None
feature_scaler = None
target_scaler = None
# æ•°æ®åŠ è½½å‡½æ•°
def load_data(upstream_file, downstream_file, qinglong_lake_file=None):
    try:
        # ä½¿ç”¨é€—号作为分隔符读取数据
        upstream_df = pd.read_csv(upstream_file)
        downstream_df = pd.read_csv(downstream_file)
        if qinglong_lake_file:
            qinglong_lake_df = pd.read_csv(qinglong_lake_file)
    except FileNotFoundError:
        print("文件未找到,请检查文件路径")
        return None
    # å‡è®¾åˆ—名被读取为 'DateTime,TagName,Value',我们需要分割
    upstream_df.columns = ['DateTime', 'TagName', 'Value']
    downstream_df.columns = ['DateTime', 'TagName', 'Value']
    if qinglong_lake_file:
        qinglong_lake_df.columns = ['DateTime', 'TagName', 'Value']
    # å°† 'DateTime' åˆ—转换为日期格式
    upstream_df['DateTime'] = pd.to_datetime(upstream_df['DateTime'])
    downstream_df['DateTime'] = pd.to_datetime(downstream_df['DateTime'])
    if qinglong_lake_file:
        qinglong_lake_df['DateTime'] = pd.to_datetime(qinglong_lake_df['DateTime'])
    # æ£€æµ‹å¹¶å¤„理异常值 (先转换为数值型)
    upstream_df['Value'] = pd.to_numeric(upstream_df['Value'], errors='coerce')
    downstream_df['Value'] = pd.to_numeric(downstream_df['Value'], errors='coerce')
    if qinglong_lake_file:
        qinglong_lake_df['Value'] = pd.to_numeric(qinglong_lake_df['Value'], errors='coerce')
    # è¿‡æ»¤æŽ‰ç›åº¦å€¼å°äºŽ5的数据
    upstream_df = upstream_df[upstream_df['Value'] >= 5]
    downstream_df = downstream_df[downstream_df['Value'] >= 5]
    if qinglong_lake_file:
        qinglong_lake_df = qinglong_lake_df[qinglong_lake_df['Value'] >= 5]
    # å°†0值替换为NaN,以便后续进行插值处理
    upstream_df.loc[upstream_df['Value'] == 0, 'Value'] = np.nan
    downstream_df.loc[downstream_df['Value'] == 0, 'Value'] = np.nan
    if qinglong_lake_file:
        qinglong_lake_df.loc[qinglong_lake_df['Value'] == 0, 'Value'] = np.nan
    # ä½¿ç”¨åŸºæœ¬ç»Ÿè®¡æ–¹æ³•识别并替换异常值 (3倍标准差法)
    for df in [upstream_df, downstream_df]:
        mean = df['Value'].mean()
        std = df['Value'].std()
        lower_bound = mean - 3 * std
        upper_bound = mean + 3 * std
        # å°†è¶…出范围的值替换为NaN
        df.loc[(df['Value'] < lower_bound) | (df['Value'] > upper_bound), 'Value'] = np.nan
    if qinglong_lake_file:
        mean = qinglong_lake_df['Value'].mean()
        std = qinglong_lake_df['Value'].std()
        lower_bound = mean - 3 * std
        upper_bound = mean + 3 * std
        qinglong_lake_df.loc[(qinglong_lake_df['Value'] < lower_bound) | (qinglong_lake_df['Value'] > upper_bound), 'Value'] = np.nan
    # é‡å‘½å 'Value' ä¸º 'upstream' å’Œ 'downstream'
    upstream_df = upstream_df.rename(columns={'Value': 'upstream'})[['DateTime', 'upstream']]
    downstream_df = downstream_df.rename(columns={'Value': 'downstream'})[['DateTime', 'downstream']]
    if qinglong_lake_file:
        qinglong_lake_df = qinglong_lake_df.rename(columns={'Value': 'qinglong_lake'})[['DateTime', 'qinglong_lake']]
    # åˆå¹¶æ•°æ®
    merged_df = pd.merge(upstream_df, downstream_df, on='DateTime', how='inner')
    if qinglong_lake_file:
        merged_df = pd.merge(merged_df, qinglong_lake_df, on='DateTime', how='left')
    # å¤„理 NaN å’Œæ— æ•ˆå€¼
    print(f"合并前数据行数: {len(merged_df)}")
    # è®¾ç½®DateTime为索引以允许时间插值
    merged_df = merged_df.set_index('DateTime')
    # ä½¿ç”¨å¤šç§æ’值方法处理NaN值
    # 1. é¦–先使用线性插值填充短时间的NaN
    merged_df['upstream'] = merged_df['upstream'].interpolate(method='linear', limit=4)
    merged_df['downstream'] = merged_df['downstream'].interpolate(method='linear', limit=4)
    if qinglong_lake_file:
        merged_df['qinglong_lake'] = merged_df['qinglong_lake'].interpolate(method='linear', limit=4)
    # 2. å¯¹äºŽè¾ƒé•¿æ—¶é—´çš„NaN,使用时间加权插值
    merged_df['upstream'] = merged_df['upstream'].interpolate(method='time', limit=24)
    merged_df['downstream'] = merged_df['downstream'].interpolate(method='time', limit=24)
    if qinglong_lake_file:
        merged_df['qinglong_lake'] = merged_df['qinglong_lake'].interpolate(method='time', limit=24)
    # 3. å¯¹äºŽä»ç„¶å­˜åœ¨çš„NaN,使用前向填充和后向填充
    merged_df['upstream'] = merged_df['upstream'].fillna(method='ffill').fillna(method='bfill')
    merged_df['downstream'] = merged_df['downstream'].fillna(method='ffill').fillna(method='bfill')
    if qinglong_lake_file:
        merged_df['qinglong_lake'] = merged_df['qinglong_lake'].fillna(method='ffill').fillna(method='bfill')
    # 4. æ·»åŠ å¹³æ»‘å¤„ç†
    # ä½¿ç”¨ç§»åŠ¨å¹³å‡è¿›è¡Œå¹³æ»‘å¤„ç†
    merged_df['upstream_smooth'] = merged_df['upstream'].rolling(window=24, min_periods=1, center=True).mean()
    merged_df['downstream_smooth'] = merged_df['downstream'].rolling(window=24, min_periods=1, center=True).mean()
    if qinglong_lake_file:
        merged_df['qinglong_lake_smooth'] = merged_df['qinglong_lake'].rolling(window=24, min_periods=1, center=True).mean()
    # å¯¹é’龙港数据中盐度值低于50的部分进行额外平滑处理
    low_salinity_mask = merged_df['upstream'] < 50
    if low_salinity_mask.any():
        # å¯¹ä½Žç›åº¦éƒ¨åˆ†ä½¿ç”¨æ›´å¤§çš„平滑窗口
        merged_df.loc[low_salinity_mask, 'upstream_smooth'] = merged_df.loc[low_salinity_mask, 'upstream'].rolling(
            window=48, min_periods=1, center=True).mean()
    # åˆ é™¤å‰©ä½™çš„NaN和无穷大值
    merged_df = merged_df.dropna()
    merged_df = merged_df[merged_df['upstream'].apply(lambda x: np.isfinite(x))]
    merged_df = merged_df[merged_df['downstream'].apply(lambda x: np.isfinite(x))]
    if qinglong_lake_file:
        merged_df = merged_df[merged_df['qinglong_lake'].apply(lambda x: np.isfinite(x))]
    # é‡ç½®ç´¢å¼•,将DateTime重新作为列
    merged_df = merged_df.reset_index()
    # æœ€ç»ˆæ£€æŸ¥æ•°æ®
    print(f"清洗后数据行数: {len(merged_df)}")
    print(f"上游盐度范围: {merged_df['upstream'].min()} - {merged_df['upstream'].max()}")
    print(f"下游盐度范围: {merged_df['downstream'].min()} - {merged_df['downstream'].max()}")
    if qinglong_lake_file:
        print(f"青龙湖盐度范围: {merged_df['qinglong_lake'].min()} - {merged_df['qinglong_lake'].max()}")
    # ç¡®ä¿æ•°æ®æŒ‰æ—¶é—´æŽ’序
    merged_df = merged_df.sort_values('DateTime')
    return merged_df
# ç‰¹å¾å·¥ç¨‹ - LSTM版本
def create_sequences(df, look_back=96, forecast_horizon=5):
    print("开始特征工程(LSTM序列模式)...")
    start_time = time()
    # æå–主要特征列
    upstream = df['upstream'].values
    downstream = df['downstream'].values
    # é¢„先计算时间特征
    date_features = np.array([
        [x.hour/24, x.dayofweek/7, x.month/12]
        for x in df['DateTime']
    ])
    # åˆ›å»ºX和y序列
    X = []
    y = []
    # è®¡ç®—可用的样本数
    total_samples = len(df) - look_back - forecast_horizon
    if total_samples <= 0:
        print("数据不足以创建特征")
        return np.array([]), np.array([])
    print(f"开始创建序列,总样本数: {total_samples}")
    # æ‰¹é‡å¤„理以提高效率
    batch_size = 1000
    for batch_start in range(0, total_samples, batch_size):
        batch_end = min(batch_start + batch_size, total_samples)
        print(f"处理样本批次: {batch_start}-{batch_end}/{total_samples}")
        for i in range(batch_start, batch_end):
            # èŽ·å–å½“å‰æ—¶é—´çª—å£ç´¢å¼•
            end_idx = i + look_back
            # åŸºæœ¬åºåˆ—特征
            upstream_seq = upstream[i:end_idx]
            downstream_seq = downstream[i:end_idx]
            time_seq = date_features[i:end_idx]
            # è·³è¿‡å«æœ‰NaN的窗口
            if np.isnan(upstream_seq).any() or np.isnan(downstream_seq).any():
                continue
            # åˆå¹¶ç‰¹å¾ [samples, timesteps, features]
            # æ¯ä¸ªæ—¶é—´æ­¥åŒ…含: ä¸Šæ¸¸ç›åº¦, ä¸‹æ¸¸ç›åº¦, æ—¶é—´ç‰¹å¾(小时,星期,月份)
            input_seq = np.column_stack([
                upstream_seq,
                downstream_seq,
                time_seq
            ])
            # ç›®æ ‡æ˜¯é¢„测未来forecast_horizon天的下游盐度
            target_seq = downstream[end_idx:end_idx+forecast_horizon]
            # ç¡®ä¿ç›®æ ‡æ²¡æœ‰NaN
            if not np.isnan(target_seq).any():
                X.append(input_seq)
                y.append(target_seq)
    X = np.array(X)
    y = np.array(y)
    if len(X) == 0 or len(y) == 0:
        print("警告:没有能够生成有效的特征和标签对")
        return np.array([]), np.array([])
    end_time = time()
    print(f"特征工程完成,有效样本数: {len(X)},特征形状: {X.shape},标签形状: {y.shape},耗时: {end_time-start_time:.2f}秒")
    # ä¿å­˜ç‰¹å¾åˆ—名称以备将来使用
    global feature_columns
    feature_columns = ['upstream', 'downstream', 'hour', 'day_of_week', 'month']
    return X, y
# æž„建LSTM模型
def build_lstm_model(input_shape, output_length):
    model = Sequential([
        LSTM(units=64, return_sequences=True, input_shape=input_shape),
        Dropout(0.2),
        LSTM(units=32),
        Dropout(0.2),
        Dense(output_length)
    ])
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model
# è®­ç»ƒå’Œé¢„测
def train_and_predict(df, start_time, force_retrain=False):
    global cached_model, last_training_time, feature_scaler, target_scaler
    # æ£€æŸ¥æ˜¯å¦æœ‰ç¼“存的模型和上次训练的时间
    model_cache_file = 'salinity_lstm_model.h5'
    scaler_cache_file = 'salinity_scalers.pkl'
    model_needs_training = True
    # ç‰¹å¾ç»“构已改变,强制重新训练 - åˆ é™¤æ—§æ¨¡åž‹æ–‡ä»¶
    if os.path.exists(model_cache_file) and force_retrain:
        try:
            os.remove(model_cache_file)
            if os.path.exists(scaler_cache_file):
                os.remove(scaler_cache_file)
            print("特征结构已更改,已删除旧模型缓存")
        except:
            pass
    # åªä½¿ç”¨æ—¶é—´ç‚¹ä¹‹å‰çš„æ•°æ®
    train_df = df[df['DateTime'] < start_time].copy()
    # å¦‚果有缓存模型且不需要强制重新训练
    if not force_retrain and cached_model is not None and last_training_time is not None and feature_scaler is not None and target_scaler is not None:
        # å¦‚果上次训练后没有新数据,使用缓存的模型
        if last_training_time >= train_df['DateTime'].max():
            model_needs_training = False
            print(f"使用缓存模型 (上次训练时间: {last_training_time})")
    # å¦‚果文件存在且不需要强制重新训练
    elif not force_retrain and os.path.exists(model_cache_file) and os.path.exists(scaler_cache_file):
        try:
            cached_model = load_model(model_cache_file)
            with open(scaler_cache_file, 'rb') as f:
                scalers = pickle.load(f)
                feature_scaler = scalers['feature_scaler']
                target_scaler = scalers['target_scaler']
                last_training_time = scalers['training_time']
                # å¦‚果上次训练后没有新数据,使用缓存的模型
                if last_training_time >= train_df['DateTime'].max():
                    model_needs_training = False
                    print(f"从文件加载模型 (上次训练时间: {last_training_time})")
        except Exception as e:
            print(f"加载模型失败: {e}")
    if model_needs_training:
        print("需要训练新模型...")
        if len(train_df) < 10:  # ç¡®ä¿æœ‰è¶³å¤Ÿçš„训练数据
            print("训练数据不足")
            return None, None, None
        # æ£€æŸ¥è®­ç»ƒæ•°æ®è´¨é‡
        print(f"训练数据范围: {train_df['DateTime'].min()} åˆ° {train_df['DateTime'].max()}")
        print(f"训练数据中NaN值统计:\n上游: {train_df['upstream'].isna().sum()}\n下游: {train_df['downstream'].isna().sum()}")
        # è®¡æ—¶å¼€å§‹
        start_time_training = time()
        # è¿›è¡Œç‰¹å¾å·¥ç¨‹
        X, y = create_sequences(train_df)
        # æ£€æŸ¥æ˜¯å¦æœ‰è¶³å¤Ÿçš„æ ·æœ¬
        if len(X) == 0 or len(y) == 0:
            print("没有足够的有效样本进行训练")
            return None, None, None
        # æ£€æŸ¥æ ·æœ¬æ•°é‡
        print(f"用于训练的样本数: {len(X)}")
        # åˆ›å»ºå¹¶åº”用缩放器 - ä¸ºLSTM缩放数据
        feature_scaler = MinMaxScaler(feature_range=(0, 1))
        # å¯¹æ¯ä¸ªæ ·æœ¬çš„æ¯ä¸ªæ—¶é—´ç‚¹è¿›è¡Œç¼©æ”¾
        # åŽŸå§‹å½¢çŠ¶ [samples, timesteps, features]
        n_samples, n_timesteps, n_features = X.shape
        X_reshaped = X.reshape(n_samples * n_timesteps, n_features)
        X_scaled = feature_scaler.fit_transform(X_reshaped)
        X_scaled = X_scaled.reshape(n_samples, n_timesteps, n_features)
        # å¯¹ç›®æ ‡å˜é‡è¿›è¡Œç¼©æ”¾
        target_scaler = MinMaxScaler(feature_range=(0, 1))
        y_scaled = target_scaler.fit_transform(y)
        # åˆ’分训练集和验证集
        X_train, X_val, y_train, y_val = train_test_split(X_scaled, y_scaled, test_size=0.2, random_state=42)
        # æž„建LSTM模型
        input_shape = (X_train.shape[1], X_train.shape[2])
        output_length = y_train.shape[1]
        model = build_lstm_model(input_shape, output_length)
        # è®­ç»ƒæ¨¡åž‹
        try:
            print("开始训练模型...")
            # æ—©åœè®¾ç½®
            early_stopping = EarlyStopping(
                monitor='val_loss',
                patience=20,
                restore_best_weights=True
            )
            # è®­ç»ƒæ¨¡åž‹
            history = model.fit(
                X_train, y_train,
                validation_data=(X_val, y_val),
                epochs=10,    #修改记得
                batch_size=32,
                callbacks=[early_stopping],
                verbose=1
            )
            # è®¡ç®—验证集性能
            val_pred_scaled = model.predict(X_val)
            val_pred = target_scaler.inverse_transform(val_pred_scaled)
            y_val_inv = target_scaler.inverse_transform(y_val)
            rmse = np.sqrt(np.mean((val_pred - y_val_inv) ** 2))
            print(f"验证集RMSE: {rmse:.4f}")
            # è®°å½•训练时间
            last_training_time = start_time
            # ç¼“存模型
            cached_model = model
            # ä¿å­˜æ¨¡åž‹åˆ°æ–‡ä»¶
            model.save(model_cache_file)
            with open(scaler_cache_file, 'wb') as f:
                pickle.dump({
                    'feature_scaler': feature_scaler,
                    'target_scaler': target_scaler,
                    'training_time': last_training_time,
                    'feature_columns': feature_columns,
                    'rmse': rmse
                }, f)
            print(f"模型训练完成,耗时: {time()-start_time_training:.2f}秒")
        except Exception as e:
            print(f"模型训练失败: {e}")
            return None, None, None
    else:
        # ä½¿ç”¨ç¼“存的模型
        model = cached_model
    # å‡†å¤‡æœ€æ–°æ•°æ®è¿›è¡Œé¢„测
    try:
        print("准备预测数据...")
        look_back = 96  # ç¡®ä¿ä¸Žè®­ç»ƒæ—¶ä¸€è‡´
        # èŽ·å–æ›´å¤šçš„åŽ†å²æ•°æ®ä»¥è€ƒè™‘å»¶è¿Ÿ
        latest_data = df[df['DateTime'] < start_time].tail(look_back).copy()
        if len(latest_data) < look_back:  # ç¡®ä¿æœ‰è¶³å¤Ÿçš„历史数据
            print("预测所需的历史数据不足")
            return None, None, None
        # å¤„理可能的NaN值
        if latest_data['upstream'].isna().any() or latest_data['downstream'].isna().any():
            latest_data['upstream'] = latest_data['upstream'].fillna(method='ffill').fillna(method='bfill')
            latest_data['downstream'] = latest_data['downstream'].fillna(method='ffill').fillna(method='bfill')
        # æå–特征序列
        upstream_seq = latest_data['upstream'].values
        downstream_seq = latest_data['downstream'].values
        # æ—¶é—´ç‰¹å¾
        time_seq = np.array([
            [x.hour/24, x.dayofweek/7, x.month/12]
            for x in latest_data['DateTime']
        ])
        # åˆå¹¶ç‰¹å¾
        input_seq = np.column_stack([
            upstream_seq,
            downstream_seq,
            time_seq
        ])
        # æ£€æŸ¥ç‰¹å¾æ˜¯å¦æœ‰æ•ˆ
        if np.isnan(input_seq).any() or np.isinf(input_seq).any():
            print("预测特征包含无效值")
            input_seq = np.nan_to_num(input_seq, nan=0.0, posinf=1e6, neginf=-1e6)
        # å¢žåŠ æ‰¹æ¬¡ç»´åº¦å¹¶ç¼©æ”¾
        input_seq = input_seq.reshape(1, look_back, -1)
        input_seq_reshaped = input_seq.reshape(look_back, -1)
        input_seq_scaled = feature_scaler.transform(input_seq_reshaped)
        input_seq_scaled = input_seq_scaled.reshape(1, look_back, -1)
        # é¢„测
        print("执行预测...")
        predictions_scaled = model.predict(input_seq_scaled)
        # åå‘缩放
        predictions = target_scaler.inverse_transform(predictions_scaled)[0]
        # ç”Ÿæˆæœªæ¥æ—¥æœŸ
        forecast_horizon = 5  # ç¡®ä¿ä¸Žè®­ç»ƒæ—¶ä¸€è‡´
        future_dates = [start_time + timedelta(days=i) for i in range(forecast_horizon)]
        print("预测成功完成")
        return future_dates, predictions, model
    except Exception as e:
        print(f"预测过程发生错误: {e}")
        return None, None, None
# GUI ç•Œé¢
def run_gui():
    # é…ç½®tkinter中文显示
    def configure_gui_fonts():
        # å°è¯•设置支持中文的字体
        font_names = ['微软雅黑', 'Microsoft YaHei', 'SimSun', 'SimHei']
        for font_name in font_names:
            try:
                default_font = tk.font.nametofont("TkDefaultFont")
                default_font.configure(family=font_name)
                text_font = tk.font.nametofont("TkTextFont")
                text_font.configure(family=font_name)
                fixed_font = tk.font.nametofont("TkFixedFont")
                fixed_font.configure(family=font_name)
                return True
            except:
                continue
        return False
    def on_predict():
        try:
            predict_start_time = time()
            status_label.config(text="预测中...")
            root.update()
            start_time = pd.to_datetime(entry.get())
            # æ£€æŸ¥æ¨¡åž‹ç¼“存情况
            cache_exists = os.path.exists('salinity_lstm_model.h5')
            if cache_exists and not retrain_var.get():
                try:
                    with open('salinity_scalers.pkl', 'rb') as f:
                        scalers = pickle.load(f)
                        # æ£€æŸ¥æ¨¡åž‹ç‰¹å¾æ•°é‡æ˜¯å¦ä¸€è‡´
                        model_features = scalers.get('feature_columns', [])
                        expected_features = ['upstream', 'downstream', 'hour', 'day_of_week', 'month']
                        if len(model_features) != len(expected_features):
                            status_label.config(text="特征结构已更改,请勾选'强制重新训练模型'")
                            return
                except:
                    pass
            force_retrain = retrain_var.get()
            future_dates, predictions, model = train_and_predict(df, start_time, force_retrain)
            if future_dates is None or predictions is None:
                status_label.config(text="预测失败")
                return
            # æ¸…空之前的图形
            ax.clear()
            # ç»˜åˆ¶åŽ†å²æ•°æ®
            history_end = min(start_time, df['DateTime'].max())
            history_start = history_end - timedelta(days=120)  # æ˜¾ç¤ºè¿‘30天的历史数据
            history_data = df[(df['DateTime'] >= history_start) & (df['DateTime'] <= history_end)]
            # ç»˜åˆ¶ä¸€å–æ°´(下游)的历史数据
            ax.plot(history_data['DateTime'], history_data['downstream'],
                    label='一取水(下游)盐度', color='blue', linewidth=1.5)
            # ç»˜åˆ¶é’龙港(上游)的历史数据 - ä½¿ç”¨å¹³æ»‘后的数据
            ax.plot(history_data['DateTime'], history_data['upstream_smooth'],
                    label='青龙港(上游)盐度', color='purple', linewidth=1.5, alpha=0.7)
            # ç»˜åˆ¶é’龙湖的历史数据 - ä½¿ç”¨å¹³æ»‘后的数据
            if 'qinglong_lake_smooth' in history_data.columns:
                ax.plot(history_data['DateTime'], history_data['qinglong_lake_smooth'],
                        label='青龙湖盐度', color='green', linewidth=1.5, alpha=0.7)
            # èŽ·å–é¢„æµ‹æœŸé—´çš„çœŸå®žå€¼ï¼ˆå¦‚æžœæœ‰ï¼‰
            actual_data = df[(df['DateTime'] >= start_time) &
                             (df['DateTime'] <= future_dates[-1])]
            # ç»˜åˆ¶é¢„测数据
            ax.plot(future_dates, predictions, marker='o', linestyle='--',
                    label='预测盐度', color='red', linewidth=2)
            # å¦‚果有真实值,绘制真实值
            if not actual_data.empty:
                ax.plot(actual_data['DateTime'], actual_data['downstream'],
                        marker='s', linestyle='-', label='真实盐度',
                        color='orange', linewidth=2)
                # è®¡ç®—预测误差
                # æ‰¾åˆ°æœ€æŽ¥è¿‘预测日期的实际值
                actual_values = []
                for pred_date in future_dates:
                    # æ‰¾åˆ°æœ€æŽ¥è¿‘的日期
                    closest_idx = (actual_data['DateTime'] - pred_date).abs().idxmin()
                    actual_values.append(actual_data.loc[closest_idx, 'downstream'])
                if len(actual_values) == len(predictions):
                    mse = np.mean((np.array(actual_values) - predictions) ** 2)
                    rmse = np.sqrt(mse)
                    mae = np.mean(np.abs(np.array(actual_values) - predictions))
                    # åœ¨å›¾ä¸Šæ˜¾ç¤ºè¯¯å·®æŒ‡æ ‡
                    error_text = f"RMSE: {rmse:.2f}, MAE: {mae:.2f}"
                    ax.text(0.02, 0.05, error_text, transform=ax.transAxes,
                            bbox=dict(facecolor='white', alpha=0.8))
            # æ·»åŠ ç½®ä¿¡åŒºé—´
            std_dev = history_data['downstream'].std() * 0.5  # ä½¿ç”¨åŽ†å²æ•°æ®çš„æ ‡å‡†å·®ä½œä¸ºé¢„æµ‹ä¸ç¡®å®šæ€§çš„ä¼°è®¡
            ax.fill_between(future_dates,
                            predictions - std_dev,
                            predictions + std_dev,
                            color='red', alpha=0.2)
            # è®¾ç½®æ ‡é¢˜ã€æ ‡ç­¾å’Œå›¾ä¾‹
            ax.set_xlabel('日期')
            ax.set_ylabel('盐度')
            ax.set_title(f"从 {start_time.strftime('%Y-%m-%d %H:%M:%S')} å¼€å§‹çš„盐度预测")
            ax.legend(loc='upper left')
            # è°ƒæ•´å¸ƒå±€ï¼Œé˜²æ­¢æ ‡ç­¾è¢«é®æŒ¡
            fig.tight_layout()
            # æ›´æ–°ç”»å¸ƒæ˜¾ç¤º
            canvas.draw()
            # è®¡ç®—预测总耗时
            predict_time = time() - predict_start_time
            # æ›´æ–°çŠ¶æ€
            status_label.config(text=f"预测完成 (耗时: {predict_time:.2f}秒)")
            # å±•示预测结果
            result_text = "预测结果:\n"
            for i, (date, pred) in enumerate(zip(future_dates, predictions)):
                result_text += f"第 {i+1} å¤© ({date.strftime('%Y-%m-%d')}): {pred:.2f}\n"
                # å¦‚果有真实值,添加到结果中
                if not actual_data.empty and i < len(actual_values):
                    result_text += f"   å®žé™…盐度: {actual_values[i]:.2f}\n"
                    result_text += f"   è¯¯å·®: {abs(actual_values[i] - pred):.2f}\n"
            result_label.config(text=result_text)
        except Exception as e:
            status_label.config(text=f"错误: {str(e)}")
    def on_scroll(event):
        # èŽ·å–å½“å‰åæ ‡è½´çš„èŒƒå›´
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        # è®¾ç½®æ»šè½®ç¼©æ”¾çš„增量
        zoom_factor = 1.1
        # ç¡®å®šé¼ æ ‡ä½ç½®åˆ°è½´çš„相对位置
        x_data = event.xdata
        y_data = event.ydata
        # å¦‚果鼠标不在坐标轴内,则使用轴中心
        if x_data is None:
            x_data = (xlim[0] + xlim[1]) / 2
        if y_data is None:
            y_data = (ylim[0] + ylim[1]) / 2
        # è®¡ç®—相对位置
        x_rel = (x_data - xlim[0]) / (xlim[1] - xlim[0])
        y_rel = (y_data - ylim[0]) / (ylim[1] - ylim[0])
        # æ£€æŸ¥æ»šè½®çš„æ»šåŠ¨æ–¹å‘ - ä½¿ç”¨event.step替代event.button,更准确
        # å‘上滚动(放大)为正值,向下滚动(缩小)为负值
        if event.step > 0:  # å‘上滚动 = æ”¾å¤§
            # è®¡ç®—新的区间
            new_width = (xlim[1] - xlim[0]) / zoom_factor
            new_height = (ylim[1] - ylim[0]) / zoom_factor
            # è®¡ç®—新的区间边界,保持鼠标位置相对不变
            x0 = x_data - x_rel * new_width
            y0 = y_data - y_rel * new_height
            x1 = x0 + new_width
            y1 = y0 + new_height
            ax.set_xlim([x0, x1])
            ax.set_ylim([y0, y1])
        else:  # å‘下滚动 = ç¼©å°
            # è®¡ç®—新的区间
            new_width = (xlim[1] - xlim[0]) * zoom_factor
            new_height = (ylim[1] - ylim[0]) * zoom_factor
            # è®¡ç®—新的区间边界,保持鼠标位置相对不变
            x0 = x_data - x_rel * new_width
            y0 = y_data - y_rel * new_height
            x1 = x0 + new_width
            y1 = y0 + new_height
            ax.set_xlim([x0, x1])
            ax.set_ylim([y0, y1])
        # æ›´æ–°ç”»å¸ƒæ˜¾ç¤º
        canvas.draw_idle()
    # å®šä¹‰é¼ æ ‡æ‹–动功能
    def on_mouse_press(event):
        if event.button == 1:  # å·¦é”®
            canvas.mpl_disconnect(hover_cid)
            canvas._pan_start = (event.x, event.y)
            canvas._xlim = ax.get_xlim()
            canvas._ylim = ax.get_ylim()
            canvas.mpl_connect('motion_notify_event', on_mouse_move)
    def on_mouse_release(event):
        if event.button == 1:  # å·¦é”®
            canvas.mpl_disconnect(move_cid[0])
            global hover_cid
            hover_cid = canvas.mpl_connect('motion_notify_event', update_cursor)
    def on_mouse_move(event):
        if event.button == 1 and hasattr(canvas, '_pan_start'):
            dx = event.x - canvas._pan_start[0]
            dy = event.y - canvas._pan_start[1]
            # è½¬æ¢åƒç´ ç§»åŠ¨åˆ°æ•°æ®åæ ‡ç§»åŠ¨
            x_span = canvas._xlim[1] - canvas._xlim[0]
            y_span = canvas._ylim[1] - canvas._ylim[0]
            # è½¬æ¢å› å­(用于将像素移动转换为数据范围移动)
            width, height = canvas.get_width_height()
            x_scale = x_span / width
            y_scale = y_span / height
            # è®¡ç®—新的限制
            xlim = [canvas._xlim[0] - dx * x_scale,
                    canvas._xlim[1] - dx * x_scale]
            ylim = [canvas._ylim[0] + dy * y_scale,
                    canvas._ylim[1] + dy * y_scale]
            # è®¾ç½®æ–°çš„限制
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
            canvas.draw_idle()
    # æ›´æ–°é¼ æ ‡æŒ‡é’ˆæ ·å¼
    def update_cursor(event):
        if event.inaxes == ax:
            canvas.get_tk_widget().config(cursor="fleur")  # æ‰‹å½¢å…‰æ ‡è¡¨ç¤ºå¯æ‹–动
        else:
            canvas.get_tk_widget().config(cursor="")
    # é‡ç½®è§†å›¾
    def reset_view():
        display_history()
        status_label.config(text="图表视图已重置")
    root = tk.Tk()
    root.title("青龙港-陈行盐度预测系统")  # ä¿®æ”¹ä¸ºä¸­æ–‡æ ‡é¢˜
    # å°è¯•配置中文字体
    try:
        import tkinter.font as tkfont
        configure_gui_fonts()
    except:
        print("无法配置GUI字体,可能影响中文显示")
    # åˆ›å»ºæ¡†æž¶
    input_frame = ttk.Frame(root, padding="10")
    input_frame.pack(fill=tk.X)
    control_frame = ttk.Frame(root, padding="5")
    control_frame.pack(fill=tk.X)
    result_frame = ttk.Frame(root, padding="10")
    result_frame.pack(fill=tk.BOTH, expand=True)
    # è¾“入框和预测按钮
    ttk.Label(input_frame, text="输入开始时间 (YYYY-MM-DD HH:MM:SS)").pack(side=tk.LEFT)
    entry = ttk.Entry(input_frame, width=25)
    entry.pack(side=tk.LEFT, padx=5)
    predict_button = ttk.Button(input_frame, text="预测", command=on_predict)
    predict_button.pack(side=tk.LEFT)
    # çŠ¶æ€æ ‡ç­¾
    status_label = ttk.Label(input_frame, text="提示: ç¬¬ä¸€æ¬¡è¿è¡Œè¯·å‹¾é€‰'强制重新训练模型'")
    status_label.pack(side=tk.LEFT, padx=10)
    # æŽ§åˆ¶é€‰é¡¹
    retrain_var = tk.BooleanVar(value=False)
    ttk.Checkbutton(control_frame, text="强制重新训练模型", variable=retrain_var).pack(side=tk.LEFT)
    # æ·»åŠ å›¾ä¾‹è¯´æ˜Ž
    legend_label = ttk.Label(control_frame, text="图例: ç´«è‰²=青龙港上游数据, è“è‰²=一取水下游数据, çº¢è‰²=预测值, ç»¿è‰²=实际值")
    legend_label.pack(side=tk.LEFT, padx=10)
    # ç»“果标签
    result_label = ttk.Label(result_frame, text="", justify=tk.LEFT)
    result_label.pack(side=tk.RIGHT, fill=tk.Y)
    # ç»˜å›¾åŒºåŸŸ - è®¾ç½®dpi提高清晰度
    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
    # åˆ›å»ºç”»å¸ƒå¹¶æ·»åŠ å·¥å…·æ 
    canvas = FigureCanvasTkAgg(fig, master=result_frame)
    canvas.get_tk_widget().pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
    # æ·»åŠ å·¥å…·æ 
    toolbar_frame = ttk.Frame(result_frame)
    toolbar_frame.pack(side=tk.BOTTOM, fill=tk.X)
    toolbar = NavigationToolbar2Tk(canvas, toolbar_frame)
    toolbar.update()
    # æ·»åŠ è‡ªå®šä¹‰é‡ç½®æŒ‰é’®
    reset_button = ttk.Button(control_frame, text="重置视图", command=reset_view)
    reset_button.pack(side=tk.LEFT, padx=5)
    # è¿žæŽ¥é¼ æ ‡äº‹ä»¶
    canvas.mpl_connect('button_press_event', on_mouse_press)
    canvas.mpl_connect('button_release_event', on_mouse_release)
    canvas.mpl_connect('scroll_event', on_scroll)
    # å…¨å±€å˜é‡ï¼Œç”¨äºŽå­˜å‚¨äº‹ä»¶è¿žæŽ¥ID
    move_cid = [None]
    hover_cid = canvas.mpl_connect('motion_notify_event', update_cursor)
    # é»˜è®¤åŠ è½½åŽ†å²æ•°æ®
    def display_history():
        # æ¸…空之前的图形
        ax.clear()
        # ç¡®ä¿æ˜¾ç¤ºå…¨éƒ¨åŽ†å²æ•°æ®ä½†ä¸è¶…è¿‡60天
        end_date = df['DateTime'].max()
        start_date = max(df['DateTime'].min(), end_date - timedelta(days=60))
        display_data = df[(df['DateTime'] >= start_date) & (df['DateTime'] <= end_date)]
        # ç»˜åˆ¶ä¸€å–æ°´(下游)历史数据
        ax.plot(display_data['DateTime'], display_data['downstream'],
                label='一取水(下游)盐度', color='blue', linewidth=1.5)
        # ç»˜åˆ¶é’龙港(上游)的历史数据 - ä½¿ç”¨å¹³æ»‘后的数据
        ax.plot(display_data['DateTime'], display_data['upstream_smooth'],
                label='青龙港(上游)盐度', color='purple', linewidth=1.5, alpha=0.7)
        # # ç»˜åˆ¶é’龙港历史数据 - ä½¿ç”¨å¹³æ»‘后的数据
        # if 'qinglong_lake_smooth' in display_data.columns:
        #     ax.plot(display_data['DateTime'], display_data['qinglong_lake_smooth'],
        #             label='青龙湖盐度', color='green', linewidth=1.5, alpha=0.7)
        # è®¾ç½®æ ‡é¢˜ã€æ ‡ç­¾å’Œå›¾ä¾‹
        ax.set_xlabel('日期')
        ax.set_ylabel('盐度')
        ax.set_title('历史盐度数据对比')
        ax.legend()
        # è°ƒæ•´å¸ƒå±€ï¼Œé˜²æ­¢æ ‡ç­¾è¢«é®æŒ¡
        fig.tight_layout()
        # æ›´æ–°ç”»å¸ƒæ˜¾ç¤º
        canvas.draw()
    # é»˜è®¤åŠ è½½åŽ†å²æ•°æ®
    display_history()
    # å¯åЍGUI
    root.mainloop()
# è¿è¡Œ
df = load_data('青龙港1.csv', '一取水.csv', )
# å¦‚果数据加载成功,则运行GUI界面
if df is not None:
    run_gui()
else:
    print("数据加载失败,无法运行预测。")