rp
2025-04-09 bb3caa0cd41bd56f496125934491759b4d865733
修复备份
已添加1个文件
611 ■■■■■ 文件已修改
xg_优化备份.py 611 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
xg_优化备份.py
对比新文件
@@ -0,0 +1,611 @@
import os
import pickle
import pandas as pd
import numpy as np
import tkinter as tk
import tkinter.font as tkfont
from tkinter import ttk
from datetime import timedelta
from time import time
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from xgboost import XGBRegressor
from lunardate import LunarDate
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib
# Configure matplotlib so Chinese text renders correctly
matplotlib.rcParams.update({
    # Candidate CJK-capable sans-serif families, tried in order.
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'SimSun', 'Arial Unicode MS'],
    # Draw the minus sign with an ASCII hyphen so it does not show as a box.
    'axes.unicode_minus': False,
    'font.family': 'sans-serif',
})
# Module-level cache shared by the training/prediction functions.
# feature_columns is only a placeholder for the feature names.
cached_model = None
last_training_time = None
feature_columns = None
# -------------------------------
# Data loading and preprocessing functions
# -------------------------------
def _read_sensor_csv(path, value_name):
    """Read one raw sensor CSV and return a cleaned ['DateTime', value_name] frame."""
    frame = pd.read_csv(path)
    # The raw files are assumed to hold exactly three columns in this order.
    frame.columns = ['DateTime', 'TagName', 'Value']
    frame['DateTime'] = pd.to_datetime(frame['DateTime'])
    frame['Value'] = pd.to_numeric(frame['Value'], errors='coerce')
    # Keep readings >= 5 only (this also drops unparseable/NaN values).
    # .copy() so the .loc write below targets an independent frame.
    frame = frame[frame['Value'] >= 5].copy()
    # Blank out values outside mean +/- 3 standard deviations; they are
    # re-filled by interpolation later.
    mean_val, std_val = frame['Value'].mean(), frame['Value'].std()
    lower_bound, upper_bound = mean_val - 3 * std_val, mean_val + 3 * std_val
    outliers = (frame['Value'] < lower_bound) | (frame['Value'] > upper_bound)
    frame.loc[outliers, 'Value'] = np.nan
    return frame.rename(columns={'Value': value_name})[['DateTime', value_name]]


def load_data(upstream_file, downstream_file, qinglong_lake_file=None):
    """Load, clean, merge and smooth the salinity series.

    Args:
        upstream_file: CSV path of the upstream series.
        downstream_file: CSV path of the downstream series.
        qinglong_lake_file: Optional CSV path of a third series; it is
            left-joined onto the upstream/downstream rows.

    Returns:
        A DataFrame sorted by 'DateTime' with columns 'DateTime',
        'upstream', 'downstream' (and 'qinglong_lake' when given) plus a
        '<name>_smooth' column for each series, or None when an input file
        does not exist.
    """
    try:
        upstream_df = _read_sensor_csv(upstream_file, 'upstream')
        downstream_df = _read_sensor_csv(downstream_file, 'downstream')
        qinglong_lake_df = None
        if qinglong_lake_file:
            qinglong_lake_df = _read_sensor_csv(qinglong_lake_file, 'qinglong_lake')
    except FileNotFoundError:
        print("文件未找到,请检查路径")
        return None

    value_cols = ['upstream', 'downstream']
    merged_df = pd.merge(upstream_df, downstream_df, on='DateTime', how='inner')
    if qinglong_lake_df is not None:
        merged_df = pd.merge(merged_df, qinglong_lake_df, on='DateTime', how='left')
        value_cols.append('qinglong_lake')
    print(f"合并前数据行数: {len(merged_df)}")

    # Sort BEFORE interpolating/smoothing: interpolation, ffill/bfill and the
    # rolling windows all work on row order, so unsorted input would blend
    # values from unrelated timestamps.
    merged_df = merged_df.sort_values('DateTime').set_index('DateTime')

    # Gap filling: short linear interpolation, then time-weighted
    # interpolation, then forward/backward fill for whatever is left.
    for col in value_cols:
        series = merged_df[col].interpolate(method='linear', limit=4)
        series = series.interpolate(method='time', limit=24)
        merged_df[col] = series.ffill().bfill()

    # Smoothing with a centred 24-row moving average.
    for col in value_cols:
        merged_df[f'{col}_smooth'] = (
            merged_df[col].rolling(window=24, min_periods=1, center=True).mean()
        )

    # Low upstream readings get a wider (48-row) window, computed over the
    # low-salinity rows only.
    low_sal_mask = merged_df['upstream'] < 50
    if low_sal_mask.any():
        merged_df.loc[low_sal_mask, 'upstream_smooth'] = (
            merged_df.loc[low_sal_mask, 'upstream']
            .rolling(window=48, min_periods=1, center=True).mean()
        )

    merged_df = merged_df.dropna()
    # dropna() does not remove +/-inf, so filter those explicitly.
    merged_df = merged_df[np.isfinite(merged_df[value_cols]).all(axis=1)]
    merged_df = merged_df.reset_index()

    print(f"清洗后数据行数: {len(merged_df)}")
    print(f"上游盐度范围: {merged_df['upstream'].min()} - {merged_df['upstream'].max()}")
    print(f"下游盐度范围: {merged_df['downstream'].min()} - {merged_df['downstream'].max()}")
    if qinglong_lake_df is not None:
        print(f"青龙湖盐度范围: {merged_df['qinglong_lake'].min()} - {merged_df['qinglong_lake'].max()}")
    return merged_df
# -------------------------------
# Add lunar-calendar (tidal) features
# -------------------------------
def add_lunar_features(df):
    """Add lunar-calendar columns derived from df['DateTime'] (in place).

    Adds 'lunar_day', 'lunar_phase_sin', 'lunar_phase_cos' (sine/cosine of
    the lunar day over a 15-day period) and 'is_high_tide' (1 when the lunar
    day is <= 5 or between 16 and 20, else 0). Returns the same DataFrame.
    """
    lunar_days = [
        LunarDate.fromSolarDate(stamp.year, stamp.month, stamp.day).day
        for stamp in df['DateTime']
    ]
    df['lunar_day'] = lunar_days
    df['lunar_phase_sin'] = [np.sin(2 * np.pi * day / 15) for day in lunar_days]
    df['lunar_phase_cos'] = [np.cos(2 * np.pi * day / 15) for day in lunar_days]
    df['is_high_tide'] = [
        1 if (day <= 5 or 16 <= day <= 20) else 0 for day in lunar_days
    ]
    return df
# -------------------------------
# Batch-generate delay (lag) features (vectorized, using shift)
# -------------------------------
def batch_create_delay_features(df, delay_hours):
    """Add lagged copies of 'upstream' and 'downstream' to df (in place).

    For every value d in delay_hours the columns 'upstream_delay_{d}h' and
    'downstream_delay_{d}h' are created with Series.shift(d), i.e. the lag is
    counted in rows. Returns the same DataFrame.
    """
    for lag in delay_hours:
        for source in ('upstream', 'downstream'):
            df[f'{source}_delay_{lag}h'] = df[source].shift(lag)
    return df
# -------------------------------
# Vectorized construction of training samples (optimized feature engineering)
# -------------------------------
def create_features_vectorized(df, look_back=96, forecast_horizon=5, downstream_window=24):
    """Build the (X, y) training matrices with sliding windows.

    For sample i the forecast starts at row ``look_back + i`` and the last
    observed row is ``look_back + i - 1``. Each feature row is the
    concatenation of:

    * the ``look_back`` upstream values ending at the last observed row,
    * the ``downstream_window`` downstream values ending at that row,
    * hour/weekday/month and lunar features of the forecast-start row,
    * the pre-computed rolling statistics and delay columns of the last
      observed row.

    The label is the ``forecast_horizon`` downstream values starting at the
    forecast-start row.

    Args:
        df: Frame with 'DateTime', 'upstream', 'downstream', the lunar
            columns, the ten '*_up' / '*_down' statistic columns and any
            '*_delay_*' columns.
        look_back: Number of upstream rows in the history window.
        forecast_horizon: Number of downstream rows to predict.
        downstream_window: Number of downstream rows in the history window.

    Returns:
        ``(X, y)`` as numpy arrays, or two empty arrays when there are too
        few rows or the statistic columns are missing.

    Raises:
        ValueError: if ``look_back`` is smaller than ``downstream_window``
            (the downstream window would start before the upstream one).
    """
    from numpy.lib.stride_tricks import sliding_window_view
    global feature_columns

    if look_back < downstream_window:
        raise ValueError("look_back must be >= downstream_window")

    total_samples = len(df) - look_back - forecast_horizon + 1
    if total_samples <= 0:
        print("数据不足以创建特征")
        return np.array([]), np.array([])

    upstream_array = df['upstream'].values
    downstream_array = df['downstream'].values

    # Upstream history: rows [i, i + look_back).
    window_up = sliding_window_view(upstream_array, window_shape=look_back)[:total_samples, :]
    # Downstream history: rows [i + look_back - downstream_window, i + look_back).
    down_start = look_back - downstream_window
    window_down = sliding_window_view(
        downstream_array, window_shape=downstream_window
    )[down_start:down_start + total_samples, :]

    # Calendar and lunar features describe the moment the forecast starts.
    forecast_rows = df.iloc[look_back:look_back + total_samples]
    hour = forecast_rows['DateTime'].dt.hour.values.reshape(-1, 1) / 24.0
    weekday = forecast_rows['DateTime'].dt.dayofweek.values.reshape(-1, 1) / 7.0
    month = forecast_rows['DateTime'].dt.month.values.reshape(-1, 1) / 12.0
    basic_time_feats = np.hstack([hour, weekday, month])
    lunar_feats = forecast_rows[['lunar_phase_sin', 'lunar_phase_cos', 'is_high_tide']].values

    # Rolling statistics and delay columns are read from the LAST OBSERVED
    # row. The trailing rolling statistics of the forecast-start row include
    # that row's own downstream value, which is the first label (target
    # leakage). The prediction path also reads these from the last
    # available row, so this keeps training and inference aligned.
    observed_rows = df.iloc[look_back - 1:look_back - 1 + total_samples]
    try:
        stats_up = observed_rows[['mean_1d_up', 'mean_3d_up', 'std_1d_up', 'max_1d_up', 'min_1d_up']].values
        stats_down = observed_rows[['mean_1d_down', 'mean_3d_down', 'std_1d_down', 'max_1d_down', 'min_1d_down']].values
    except KeyError as e:
        print(f"统计特征列不存在: {e},请确保先计算统计特征")
        return np.array([]), np.array([])

    delay_cols = [
        col for col in df.columns
        if col.startswith('upstream_delay_') or col.startswith('downstream_delay_')
    ]
    delay_feats = observed_rows[delay_cols].values

    X = np.hstack([window_up, window_down, basic_time_feats, lunar_feats,
                   stats_up, stats_down, delay_feats])

    # Labels: downstream rows [i + look_back, i + look_back + forecast_horizon).
    # NOTE(review): the horizon is counted in rows, not calendar days.
    label_full = sliding_window_view(downstream_array, window_shape=forecast_horizon)
    y = label_full[look_back:look_back + total_samples, :]

    # Placeholder only; real per-column names are not tracked.
    feature_columns = ["combined_vector_features"]
    print(f"向量化特征工程完成,有效样本数: {X.shape[0]}")
    return X, y
# -------------------------------
# Fetch model accuracy metrics
# -------------------------------
def get_model_metrics():
    """Return the accuracy metrics stored in the model cache file.

    Returns a dict with keys 'rmse' and 'mae' (each None when absent from
    the cache), or None when 'salinity_model.pkl' does not exist or cannot
    be read.
    """
    cache_path = 'salinity_model.pkl'
    if not os.path.exists(cache_path):
        return None
    try:
        with open(cache_path, 'rb') as handle:
            cached = pickle.load(handle)
            return {key: cached.get(key, None) for key in ('rmse', 'mae')}
    except Exception as e:
        print(f"获取模型指标失败: {e}")
    return None
# -------------------------------
# Model training and prediction; reports validation accuracy (RMSE, MAE)
# -------------------------------
def train_and_predict(df, start_time, force_retrain=False):
    """Train (or reuse) the salinity model and forecast from ``start_time``.

    Only rows with ``DateTime < start_time`` are used, both for training and
    for building the prediction features, so a forecast never sees data at or
    after its own start.

    Args:
        df: Processed frame with 'DateTime', 'upstream', 'downstream' and the
            engineered feature columns.
        start_time: Forecast start; anything ``pd.Timestamp`` accepts.
        force_retrain: When True, delete the cache file and train again.

    Returns:
        ``(future_dates, predictions, model, metrics)`` on success, where
        ``metrics`` is a dict with 'rmse'/'mae' or None; four ``None`` values
        on any failure (a message is printed).
    """
    global cached_model, last_training_time
    model_cache_file = 'salinity_model.pkl'
    look_back, forecast_horizon, down_window = 96, 5, 24
    failure = (None, None, None, None)
    # .dayofweek below needs a Timestamp, not a plain datetime/str.
    start_time = pd.Timestamp(start_time)
    model_needs_training = True
    metrics = None

    if force_retrain and os.path.exists(model_cache_file):
        try:
            os.remove(model_cache_file)
            print("已删除旧模型缓存(强制重新训练)")
        except OSError as e:
            print("删除缓存异常:", e)

    # History strictly before the forecast start.
    train_df = df[df['DateTime'] < start_time].copy()
    latest_history = train_df['DateTime'].max()

    # A cached model is reusable when it was trained for a start time that
    # already covers all of the current history.
    if not force_retrain and cached_model is not None and last_training_time is not None:
        if last_training_time >= latest_history:
            model_needs_training = False
            print(f"使用缓存模型,训练时间: {last_training_time}")
    elif not force_retrain and os.path.exists(model_cache_file):
        try:
            # NOTE: pickle executes code on load; the cache file must only
            # ever come from this program.
            with open(model_cache_file, 'rb') as f:
                model_data = pickle.load(f)
            cached_model = model_data['model']
            last_training_time = model_data['training_time']
            if last_training_time >= latest_history:
                model_needs_training = False
                print(f"从文件加载模型,训练时间: {last_training_time}")
        except Exception as e:
            print("加载模型失败:", e)

    if model_needs_training:
        print("开始训练新模型...")
        if len(train_df) < 100:
            print("训练数据不足")
            return failure
        start_train = time()
        X, y = create_features_vectorized(train_df, look_back=look_back,
                                          forecast_horizon=forecast_horizon)
        if len(X) == 0 or len(y) == 0:
            print("样本生成不足,训练终止")
            return failure
        print(f"训练样本数量: {X.shape[0]}")
        # NOTE(review): a shuffled split over overlapping sliding windows
        # puts near-duplicate samples in both sets, so the validation
        # RMSE/MAE are optimistic. Kept as-is so the model still trains on
        # the most recent data.
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
        # eval_metric / early_stopping_rounds belong to the constructor:
        # passing them to fit() was deprecated in XGBoost 1.6 and removed
        # in 2.0.
        model = XGBRegressor(
            n_estimators=300,
            learning_rate=0.03,
            max_depth=5,
            min_child_weight=2,
            subsample=0.85,
            colsample_bytree=0.85,
            gamma=0.1,
            reg_alpha=0.2,
            reg_lambda=1.5,
            n_jobs=-1,
            random_state=42,
            eval_metric='rmse',
            early_stopping_rounds=20,
        )
        try:
            model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
            y_val_pred = model.predict(X_val)
            rmse = float(np.sqrt(mean_squared_error(y_val, y_val_pred)))
            mae = float(mean_absolute_error(y_val, y_val_pred))
        except Exception as e:
            print("模型训练异常:", e)
            return failure
        print(f"验证集 RMSE: {rmse:.4f}, MAE: {mae:.4f}")
        metrics = {'rmse': rmse, 'mae': mae}
        last_training_time = start_time
        cached_model = model
        # A failed cache write must not throw away the trained model.
        try:
            with open(model_cache_file, 'wb') as f:
                pickle.dump({
                    'model': model,
                    'training_time': last_training_time,
                    'feature_columns': feature_columns,
                    'rmse': rmse,
                    'mae': mae
                }, f)
        except Exception as e:
            print("保存模型缓存失败:", e)
        print(f"模型训练完成,耗时: {time() - start_train:.2f}秒")
    else:
        model = cached_model
        metrics = get_model_metrics()

    # Build the single prediction sample with the same layout as the
    # training features.
    try:
        if len(train_df) < look_back:
            print("预测数据不足")
            return failure
        window_up = train_df['upstream'].values[-look_back:]
        window_down = train_df['downstream'].values[-down_window:]

        # Calendar and lunar features of the forecast start.
        basic_time_feats = np.array([start_time.hour / 24.0,
                                     start_time.dayofweek / 7.0,
                                     start_time.month / 12.0]).reshape(1, -1)
        ld = LunarDate.fromSolarDate(start_time.year, start_time.month, start_time.day)
        lunar_feats = np.array([np.sin(2 * np.pi * ld.day / 15),
                                np.cos(2 * np.pi * ld.day / 15),
                                1 if (ld.day <= 5 or 16 <= ld.day <= 20) else 0]).reshape(1, -1)

        # Statistics and delay features come from the last observed row.
        last_row = train_df.iloc[-1:]
        try:
            stats_up = last_row[['mean_1d_up', 'mean_3d_up', 'std_1d_up', 'max_1d_up', 'min_1d_up']].values
            stats_down = last_row[['mean_1d_down', 'mean_3d_down', 'std_1d_down', 'max_1d_down', 'min_1d_down']].values
        except KeyError:
            # Columns not pre-computed: derive them from the raw tail.
            # ddof=1 matches pandas' rolling().std().
            recent_up = train_df['upstream'].values[-24:]
            stats_up = np.array([np.mean(recent_up),
                                 np.mean(train_df['upstream'].values[-72:]),
                                 np.std(recent_up, ddof=1),
                                 np.max(recent_up),
                                 np.min(recent_up)]).reshape(1, -1)
            recent_down = train_df['downstream'].values[-24:]
            stats_down = np.array([np.mean(recent_down),
                                   np.mean(train_df['downstream'].values[-72:]),
                                   np.std(recent_down, ddof=1),
                                   np.max(recent_down),
                                   np.min(recent_down)]).reshape(1, -1)

        delay_cols = [col for col in train_df.columns
                      if col.startswith(('upstream_delay_', 'downstream_delay_'))]
        delay_feats = last_row[delay_cols].values

        X_pred = np.hstack([window_up.reshape(1, -1),
                            window_down.reshape(1, -1),
                            basic_time_feats, lunar_feats,
                            stats_up, stats_down, delay_feats]).astype(float)
        if np.isnan(X_pred).any() or np.isinf(X_pred).any():
            X_pred = np.nan_to_num(X_pred, nan=0.0, posinf=1e6, neginf=-1e6)
        predictions = model.predict(X_pred)

        # NOTE(review): the model predicts the next `forecast_horizon` ROWS
        # of the downstream series, but the outputs are labelled one DAY
        # apart here. Confirm the intended sampling interval.
        future_dates = [start_time + timedelta(days=i) for i in range(forecast_horizon)]
        print("预测完成")
        return future_dates, predictions.flatten(), model, metrics
    except Exception as e:
        print("预测过程异常:", e)
        return failure
# -------------------------------
# GUI界面部分
# -------------------------------
def run_gui():
    """Build the Tk window for the salinity forecast and run its main loop.

    Reads the module-level ``df`` for plotting and passes it to
    ``train_and_predict`` when the predict button is pressed. Blocks until
    the window is closed.
    """
    def configure_gui_fonts():
        """Switch Tk's named default fonts to the first installed CJK family.

        Returns True when a family was applied, False when none of the
        candidates is installed (Tk's defaults are then left untouched).
        """
        candidates = ['微软雅黑', 'Microsoft YaHei', 'SimSun', 'SimHei']
        # Tk accepts unknown family names without raising, so check against
        # the installed families rather than relying on an exception.
        installed = set(tkfont.families(root))
        for family in candidates:
            if family not in installed:
                continue
            for named_font in ("TkDefaultFont", "TkTextFont", "TkFixedFont"):
                tkfont.nametofont(named_font).configure(family=family)
            return True
        return False

    def format_metrics(metrics):
        """Return the metrics label text; tolerates missing/None values."""
        if not metrics or metrics.get('rmse') is None or metrics.get('mae') is None:
            return "模型准确度: 未知"
        return f"模型准确度 - RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}"

    def display_history():
        """Plot the most recent 60 days of downstream/upstream history."""
        ax.clear()
        end_date = df['DateTime'].max()
        start_date = max(df['DateTime'].min(), end_date - timedelta(days=60))
        hist_data = df[(df['DateTime'] >= start_date) & (df['DateTime'] <= end_date)]
        ax.plot(hist_data['DateTime'], hist_data['downstream'], label='一取水(下游)盐度', color='blue', linewidth=1.5)
        ax.plot(hist_data['DateTime'], hist_data['upstream_smooth'], label='青龙港(上游)盐度', color='purple', linewidth=1.5, alpha=0.7)
        ax.set_xlabel('日期')
        ax.set_ylabel('盐度')
        ax.set_title('历史盐度数据对比')
        ax.legend()
        fig.tight_layout()
        canvas.draw()

    def on_predict():
        """Run a forecast for the entered start time and redraw the chart."""
        try:
            predict_start = time()
            status_label.config(text="预测中...")
            root.update()
            start_time_dt = pd.to_datetime(entry.get())
            # An empty entry parses to NaT; reject it before training.
            if pd.isna(start_time_dt):
                status_label.config(text="请输入有效的开始时间")
                return
            force_retrain = retrain_var.get()
            future_dates, predictions, model, metrics = train_and_predict(df, start_time_dt, force_retrain)
            if future_dates is None or predictions is None:
                status_label.config(text="预测失败")
                return
            if metrics:
                metrics_label.config(text=format_metrics(metrics))
            ax.clear()
            # History: the 120 days before the forecast start (capped at the
            # end of the data).
            history_end = min(start_time_dt, df['DateTime'].max())
            history_start = history_end - timedelta(days=120)
            hist_data = df[(df['DateTime'] >= history_start) & (df['DateTime'] <= history_end)]
            ax.plot(hist_data['DateTime'], hist_data['downstream'], label='一取水(下游)盐度', color='blue', linewidth=1.5)
            ax.plot(hist_data['DateTime'], hist_data['upstream_smooth'], label='青龙港(上游)盐度', color='purple', linewidth=1.5, alpha=0.7)
            if 'qinglong_lake_smooth' in hist_data.columns:
                ax.plot(hist_data['DateTime'], hist_data['qinglong_lake_smooth'], label='青龙湖盐度', color='green', linewidth=1.5, alpha=0.7)
            ax.plot(future_dates, predictions, marker='o', linestyle='--', label='预测盐度', color='red', linewidth=2)
            # Overlay measured values when the forecast window is already in
            # the data (back-testing).
            actual_data = df[(df['DateTime'] >= start_time_dt) & (df['DateTime'] <= future_dates[-1])]
            if not actual_data.empty:
                ax.plot(actual_data['DateTime'], actual_data['downstream'], marker='s', linestyle='-', label='实际盐度', color='orange', linewidth=2)
            # Shaded band of +/- half the historical standard deviation.
            std_dev = hist_data['downstream'].std() * 0.5
            ax.fill_between(future_dates, predictions - std_dev, predictions + std_dev, color='red', alpha=0.2)
            ax.set_xlabel('日期')
            ax.set_ylabel('盐度')
            ax.set_title(f"从 {start_time_dt.strftime('%Y-%m-%d %H:%M:%S')} 开始的盐度预测")
            ax.legend(loc='upper left')
            fig.tight_layout()
            canvas.draw()
            predict_time = time() - predict_start
            status_label.config(text=f"预测完成 (耗时: {predict_time:.2f}秒)")
            # NOTE(review): results are labelled per day; confirm this
            # matches the row spacing the model was trained on.
            result_text = "预测结果:\n"
            for i, (date, pred) in enumerate(zip(future_dates, predictions)):
                result_text += f"第 {i+1} 天 ({date.strftime('%Y-%m-%d')}): {pred:.2f}\n"
            result_label.config(text=result_text)
        except Exception as e:
            # GUI boundary: surface the error in the status bar.
            status_label.config(text=f"错误: {str(e)}")

    def on_scroll(event):
        """Zoom in/out around the mouse position on scroll."""
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        zoom_factor = 1.1
        x_data = event.xdata if event.xdata is not None else (xlim[0] + xlim[1]) / 2
        y_data = event.ydata if event.ydata is not None else (ylim[0] + ylim[1]) / 2
        width = xlim[1] - xlim[0]
        height = ylim[1] - ylim[0]
        # Relative cursor position, kept fixed so the point under the mouse
        # does not move while zooming.
        x_rel = (x_data - xlim[0]) / width
        y_rel = (y_data - ylim[0]) / height
        scale = 1 / zoom_factor if event.step > 0 else zoom_factor
        new_width = width * scale
        new_height = height * scale
        x0 = x_data - x_rel * new_width
        y0 = y_data - y_rel * new_height
        ax.set_xlim([x0, x0 + new_width])
        ax.set_ylim([y0, y0 + new_height])
        canvas.draw_idle()

    def update_cursor(event):
        """Show a move cursor while the pointer is over the axes."""
        if event.inaxes == ax:
            canvas.get_tk_widget().config(cursor="fleur")
        else:
            canvas.get_tk_widget().config(cursor="")

    def reset_view():
        """Redraw the default history chart."""
        display_history()
        status_label.config(text="图表视图已重置")

    root = tk.Tk()
    root.title("青龙港-陈行盐度预测系统")
    try:
        configure_gui_fonts()
    except Exception as e:
        print("字体配置异常:", e)

    input_frame = ttk.Frame(root, padding="10")
    input_frame.pack(fill=tk.X)
    control_frame = ttk.Frame(root, padding="5")
    control_frame.pack(fill=tk.X)
    result_frame = ttk.Frame(root, padding="10")
    result_frame.pack(fill=tk.BOTH, expand=True)

    ttk.Label(input_frame, text="输入开始时间 (YYYY-MM-DD HH:MM:SS)").pack(side=tk.LEFT)
    entry = ttk.Entry(input_frame, width=25)
    entry.pack(side=tk.LEFT, padx=5)
    predict_button = ttk.Button(input_frame, text="预测", command=on_predict)
    predict_button.pack(side=tk.LEFT)
    status_label = ttk.Label(input_frame, text="提示: 第一次运行请勾选'强制重新训练模型'")
    status_label.pack(side=tk.LEFT, padx=10)

    retrain_var = tk.BooleanVar(value=False)
    ttk.Checkbutton(control_frame, text="强制重新训练模型", variable=retrain_var).pack(side=tk.LEFT)
    # The colours named here match the ones used in on_predict: measured
    # values are drawn in orange.
    legend_label = ttk.Label(control_frame, text="图例: 紫色=青龙港上游数据, 蓝色=一取水下游数据, 红色=预测值, 橙色=实际值")
    legend_label.pack(side=tk.LEFT, padx=10)
    reset_button = ttk.Button(control_frame, text="重置视图", command=reset_view)
    reset_button.pack(side=tk.LEFT, padx=5)

    # Label showing the accuracy of the cached model, if any.
    metrics_frame = ttk.Frame(root, padding="5")
    metrics_frame.pack(fill=tk.X)
    metrics_label = ttk.Label(metrics_frame, text=format_metrics(get_model_metrics()))
    metrics_label.pack(side=tk.LEFT, padx=10)

    result_label = ttk.Label(result_frame, text="", justify=tk.LEFT)
    result_label.pack(side=tk.RIGHT, fill=tk.Y)

    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
    canvas = FigureCanvasTkAgg(fig, master=result_frame)
    canvas.get_tk_widget().pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
    toolbar_frame = ttk.Frame(result_frame)
    toolbar_frame.pack(side=tk.BOTTOM, fill=tk.X)
    toolbar = NavigationToolbar2Tk(canvas, toolbar_frame)
    toolbar.update()
    canvas.mpl_connect('scroll_event', on_scroll)
    canvas.mpl_connect('motion_notify_event', update_cursor)

    display_history()
    root.mainloop()
# -------------------------------
# Main program entry: load data, add features, generate delay features, then launch the GUI
# -------------------------------
def save_processed_data(df, filename='processed_data.pkl'):
    """Pickle the processed DataFrame to ``filename``.

    Returns True on success; on any error a message is printed and False is
    returned.
    """
    try:
        df.to_pickle(filename)
    except Exception as e:
        print(f"保存数据失败: {e}")
        return False
    print(f"已保存处理后的数据到 {filename}")
    return True
def load_processed_data(filename='processed_data.pkl'):
    """Load a previously pickled processed DataFrame.

    Returns the DataFrame, or None when ``filename`` does not exist or
    cannot be read (a message is printed in both cases).
    """
    if not os.path.exists(filename):
        print(f"找不到处理后的数据文件 {filename}")
        return None
    try:
        # NOTE: unpickling executes code; only load files written by
        # save_processed_data.
        df = pd.read_pickle(filename)
    except Exception as e:
        print(f"加载数据失败: {e}")
        return None
    print(f"已从 {filename} 加载处理后的数据")
    return df
# Reuse the cached processed dataset when it exists; otherwise rebuild it
# from the raw CSV files and cache the result.
processed_data = load_processed_data()
df = processed_data
if df is None:
    df = load_data('青龙港1.csv', '一取水.csv')
    if df is not None:
        df = add_lunar_features(df)
        delay_hours = [1,2,3,4,6,12,24,36,48,60,72,84,96,108,120]
        df = batch_create_delay_features(df, delay_hours)
        # Trailing rolling statistics: 24-row ("1d") and 72-row ("3d")
        # windows, first for upstream and then for downstream.
        for source_col, tag in (('upstream', 'up'), ('downstream', 'down')):
            df[f'mean_1d_{tag}'] = df[source_col].rolling(window=24, min_periods=1).mean()
            df[f'mean_3d_{tag}'] = df[source_col].rolling(window=72, min_periods=1).mean()
            df[f'std_1d_{tag}'] = df[source_col].rolling(window=24, min_periods=1).std()
            df[f'max_1d_{tag}'] = df[source_col].rolling(window=24, min_periods=1).max()
            df[f'min_1d_{tag}'] = df[source_col].rolling(window=24, min_periods=1).min()
        # Cache the processed data for the next start.
        save_processed_data(df)
if df is None:
    print("数据加载失败,无法运行预测。")
else:
    run_gui()