rp
2025-04-16 714cd801bd3831d775fc807d56d6a6dbc97ce73a
bug修复1
已添加1个文件
1419 ■■■■■ 文件已修改
demo.py 1419 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
demo.py
对比新文件
@@ -0,0 +1,1419 @@
# xgboost修改版本
import os
import pickle
import pandas as pd
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import tkinter as tk
import tkinter.font as tkfont
from tkinter import ttk
from datetime import timedelta
from time import time
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from xgboost import XGBRegressor
from lunardate import LunarDate
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib
# Configure matplotlib so that Chinese labels render: prefer CJK-capable
# sans-serif fonts and keep the minus sign drawable with those fonts.
matplotlib.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'SimSun', 'Arial Unicode MS'],
    'axes.unicode_minus': False,
    'font.family': 'sans-serif',
})
# Module-level cache variables and feature names.
cached_model = None
last_training_time = None
feature_columns = None
# Stores the current chart view limits.
current_view = dict.fromkeys(('xlim', 'ylim', 'ylim2'))
# Stores the ids of the connected matplotlib events.
event_cids = []
# Data loading and preprocessing functions
# -------------------------------
def _clip_outliers_iqr(values):
    """Clamp values outside the 1.5 * IQR fences to the nearest fence.

    NaN entries are left untouched; if the quantiles are NaN (no valid data)
    nothing is clipped.
    """
    q1 = values.quantile(0.25)
    q3 = values.quantile(0.75)
    iqr = q3 - q1
    return values.clip(lower=q1 - 1.5 * iqr, upper=q3 + 1.5 * iqr)


def _fill_remaining_gaps(values):
    """Forward-fill and then backward-fill whatever NaNs are still present."""
    return values.ffill().bfill()


def _clean_salinity_series(raw_df):
    """Turn one raw salinity CSV frame into a cleaned, time-indexed Series.

    The frame must have exactly three columns, interpreted as
    ``DateTime, TagName, Value``.
    """
    raw_df.columns = ['DateTime', 'TagName', 'Value']
    raw_df['DateTime'] = pd.to_datetime(raw_df['DateTime'])
    raw_df = raw_df.set_index('DateTime')

    # Cast to float up front so clipping / NaN masking never has to upcast
    # an integer column.
    values = pd.to_numeric(raw_df['Value'], errors='coerce').astype(float)
    values = _clip_outliers_iqr(values)

    # Readings below 5 are not dropped: they are blanked out and
    # reconstructed from their neighbours.
    low_salinity_mask = values < 5
    if low_salinity_mask.any():
        print(f"发现{low_salinity_mask.sum()}个低盐度值(<5),将使用插值处理")
        values = values.mask(low_salinity_mask)
        # Short gaps: linear interpolation; longer gaps: time-weighted.
        values = values.interpolate(method='linear', limit=4)
        values = values.interpolate(method='time', limit=24)
        values = _fill_remaining_gaps(values)
        # Rolling median to smooth the reconstructed values.
        values = values.rolling(window=12, center=True, min_periods=1).median()
    return values


def _merge_external_series(merged_df, csv_path, column, label):
    """Read an auxiliary CSV and left-join it onto ``merged_df`` as ``column``.

    Adds ``column``, ``{column}_smooth``, ``{column}_trend_1h`` and
    ``{column}_trend_24h``. Returns a new frame; ``merged_df`` itself is not
    modified, so a failure here leaves the caller's data intact.
    """
    ext_df = pd.read_csv(csv_path)
    print(f"成功读取{label}数据文件: {csv_path}")

    if len(ext_df.columns) >= 3:
        ext_df.columns = ['DateTime', 'TagName', 'Value']
    elif len(ext_df.columns) == 2:
        ext_df.columns = ['DateTime', 'Value']

    ext_df['DateTime'] = pd.to_datetime(ext_df['DateTime'])
    ext_df = ext_df.set_index('DateTime')
    values = pd.to_numeric(ext_df['Value'], errors='coerce').astype(float)
    values = _clip_outliers_iqr(values)

    merged = pd.merge(merged_df, values.rename(column).to_frame(),
                      left_index=True, right_index=True, how='left')
    merged[column] = merged[column].interpolate(method='time', limit=24)
    merged[column] = _fill_remaining_gaps(merged[column])

    smooth_col = f'{column}_smooth'
    merged[smooth_col] = merged[column].rolling(window=24, min_periods=1, center=True).mean()
    merged[f'{column}_trend_1h'] = merged[smooth_col].diff(1)
    merged[f'{column}_trend_24h'] = merged[smooth_col].diff(24)
    return merged


def load_data(upstream_file, downstream_file, river_level_file=None, flow_file=None):
    """Load all related data sets and apply data-quality processing.

    Args:
        upstream_file: CSV with three columns (DateTime, TagName, Value)
            holding the upstream salinity readings.
        downstream_file: CSV with the same layout for the downstream station.
        river_level_file: optional CSV (two or three columns) with water
            level readings; failures while loading it are reported and
            ignored.
        flow_file: optional CSV (two or three columns) with flow readings;
            failures while loading it are reported and ignored.

    Returns:
        A DataFrame with a ``DateTime`` column plus the cleaned series and the
        derived smooth/trend columns, or ``None`` when one of the two salinity
        files does not exist.
    """
    try:
        upstream_raw = pd.read_csv(upstream_file)
        downstream_raw = pd.read_csv(downstream_file)
    except FileNotFoundError:
        print("文件未找到,请检查路径")
        return None

    upstream_df = _clean_salinity_series(upstream_raw).rename('upstream').to_frame()
    downstream_df = _clean_salinity_series(downstream_raw).rename('downstream').to_frame()

    # Keep only timestamps present at both stations.
    merged_df = pd.merge(upstream_df, downstream_df, left_index=True, right_index=True, how='inner')

    # Optional water-level data.
    if river_level_file:
        try:
            merged_df = _merge_external_series(merged_df, river_level_file, 'water_level', '水位')
            print(f"水位数据加载成功,范围: {merged_df['water_level'].min()} - {merged_df['water_level'].max()}")
        except Exception as e:
            print(f"水位数据加载失败: {str(e)}")

    # Optional flow data.
    if flow_file:
        try:
            with_flow = _merge_external_series(merged_df, flow_file, 'flow', '流量')
            flow_smooth = with_flow['flow_smooth']
            # Rolling statistics of the smoothed flow.
            with_flow['mean_1d_flow'] = flow_smooth.rolling(window=24, min_periods=1).mean()
            with_flow['mean_3d_flow'] = flow_smooth.rolling(window=72, min_periods=1).mean()
            with_flow['std_1d_flow'] = flow_smooth.rolling(window=24, min_periods=1).std()
            # Flow change columns (same definition as the trend columns).
            with_flow['flow_change_1h'] = flow_smooth.diff(1)
            with_flow['flow_change_24h'] = flow_smooth.diff(24)
            # NOTE: a flow / downstream-salinity ratio feature used to be
            # sketched here but was disabled in the original code.
            merged_df = with_flow
            print(f"流量数据加载成功,范围: {merged_df['flow'].min()} - {merged_df['flow'].max()} m³/s")
        except Exception as e:
            print(f"流量数据加载失败: {str(e)}")

    stations = ('upstream', 'downstream')

    # Interpolate and then fill the remaining gaps of the salinity series.
    for station in stations:
        merged_df[station] = merged_df[station].interpolate(method='time', limit=24)
    for station in stations:
        merged_df[station] = _fill_remaining_gaps(merged_df[station])

    # Smoothed salinity (centered 24-sample rolling mean).
    for station in stations:
        merged_df[f'{station}_smooth'] = merged_df[station].rolling(
            window=24, min_periods=1, center=True).mean()

    # Trend features; the leading NaNs produced by diff() are set to 0.
    for station in stations:
        smooth = merged_df[f'{station}_smooth']
        merged_df[f'{station}_trend_1h'] = smooth.diff(1).fillna(0)
        merged_df[f'{station}_trend_24h'] = smooth.diff(24).fillna(0)

    # Use a wider window for the low-salinity part of the upstream series.
    # NOTE(review): the rolling mean runs over the masked rows only, so rows
    # that are far apart in time can end up in the same window - confirm
    # this is intended.
    low_sal_mask = merged_df['upstream'] < 50
    if low_sal_mask.any():
        merged_df.loc[low_sal_mask, 'upstream_smooth'] = (
            merged_df.loc[low_sal_mask, 'upstream']
            .rolling(window=48, min_periods=1, center=True)
            .mean()
        )

    # Data validation and statistics.
    print("\n数据质量统计:")
    print(f"总数据量: {len(merged_df)}")
    print(f"上游盐度范围: {merged_df['upstream'].min():.2f} - {merged_df['upstream'].max():.2f}")
    print(f"下游盐度范围: {merged_df['downstream'].min():.2f} - {merged_df['downstream'].max():.2f}")
    if 'water_level' in merged_df.columns:
        print(f"水位范围: {merged_df['water_level'].min():.2f} - {merged_df['water_level'].max():.2f}")
        print(f"水位缺失比例: {merged_df['water_level'].isna().mean()*100:.2f}%")
    if 'flow' in merged_df.columns:
        print(f"流量范围: {merged_df['flow'].min():.2f} - {merged_df['flow'].max():.2f} m³/s")
        print(f"流量缺失比例: {merged_df['flow'].isna().mean()*100:.2f}%")

    # Expose DateTime as a regular column again.
    return merged_df.reset_index()
# df = load_data('青龙港1.csv', '一取水.csv')
# Test
# df = load_data('青龙港1.csv', '一取水.csv')
# df.to_csv('merged_data.csv', index=False)
# print(f"Merged data saved to 'merged_data.csv' successfully")
# # Plot salinity over time
# plt.figure(figsize=(12, 6))
# plt.plot(df['DateTime'], df['upstream_smooth'], label='上游盐度', color='blue')
# plt.plot(df['DateTime'], df['downstream_smooth'], label='下游盐度', color='red')
# plt.xlabel('时间')
# plt.ylabel('盐度')
# plt.title('盐度随时间变化图')
# plt.legend()
# plt.grid(True)
# plt.tight_layout()
# plt.savefig('salinity_time_series.png', dpi=300)
# plt.show()
# ----------------------特征工程部分
# -------------------------------
# Add lunar-calendar (tidal) features
# -------------------------------
def add_lunar_features(df):
    """Append lunar-calendar columns derived from ``df['DateTime']``.

    Each timestamp is converted with ``LunarDate.fromSolarDate``; the lunar
    day of month then yields four columns, added to ``df`` in place:

    * ``lunar_day``       - lunar day of month
    * ``lunar_phase_sin`` - sin(2*pi*day/15)
    * ``lunar_phase_cos`` - cos(2*pi*day/15)
    * ``is_high_tide``    - 1 for lunar days <= 5 or 16..20, otherwise 0

    Returns the same DataFrame object.
    """
    days = [
        LunarDate.fromSolarDate(stamp.year, stamp.month, stamp.day).day
        for stamp in df['DateTime']
    ]
    df['lunar_day'] = days
    df['lunar_phase_sin'] = [np.sin(2 * np.pi * day / 15) for day in days]
    df['lunar_phase_cos'] = [np.cos(2 * np.pi * day / 15) for day in days]
    df['is_high_tide'] = [
        1 if (day <= 5 or 16 <= day <= 20) else 0
        for day in days
    ]
    return df
# -------------------------------
# Generate delay features (vectorized, using shift)
# -------------------------------
def batch_create_delay_features(df, delay_hours, target_columns=None):
    """Create lagged copies of selected columns of ``df`` (in place).

    For every target column that exists and every ``delay`` in
    ``delay_hours`` a column named ``{prefix}_delay_{delay}h`` is added,
    where ``prefix`` is the part of the column name before the first
    underscore and the values are ``df[column].shift(delay)``.

    Args:
        df: DataFrame to extend.
        delay_hours: iterable of row offsets passed to ``shift``.
        target_columns: columns to lag; defaults to ``['upstream_smooth']``.

    Returns:
        The same DataFrame object. Missing target columns are reported with
        a warning and skipped.
    """
    if target_columns is None:
        target_columns = ['upstream_smooth']

    for column in target_columns:
        if column not in df.columns:
            # The original message was mojibake; restored to readable text.
            print(f"警告: 列 {column} 不存在,跳过创建延迟特征")
            continue
        prefix = column.split("_")[0]
        for delay in delay_hours:
            df[f'{prefix}_delay_{delay}h'] = df[column].shift(delay)
    return df
# Generate the remaining features
def generate_features(df):
    """Add smoothing, calendar, rolling-statistic and trend columns to ``df``.

    Works in place on ``df`` (which must provide ``DateTime``, ``upstream``
    and ``downstream``) and returns it. Water-level and flow features are
    added only when ``water_level_smooth`` / ``flow_smooth`` exist. Any
    exception is reported and the frame is returned as far as it was built.
    """
    try:
        salinity_sources = (('up', 'upstream_smooth'), ('down', 'downstream_smooth'))

        # Smoothed salinity: centered 24-sample rolling mean.
        for station in ('upstream', 'downstream'):
            df[f'{station}_smooth'] = df[station].rolling(
                window=24, min_periods=1, center=True).mean()

        # Calendar features.
        stamps = df['DateTime'].dt
        df['hour'] = stamps.hour
        df['weekday'] = stamps.dayofweek
        df['month'] = stamps.month

        # Cyclic (sin/cos) encoding of the calendar features.
        for unit, period in (('hour', 24), ('weekday', 7), ('month', 12)):
            angle = 2 * np.pi * df[unit] / period
            df[f'{unit}_sin'] = np.sin(angle)
            df[f'{unit}_cos'] = np.cos(angle)

        # Rolling statistics of the smoothed salinity.
        for tag, source in salinity_sources:
            base = df[source]
            df[f'mean_1d_{tag}'] = base.rolling(window=24, min_periods=1).mean()
            df[f'mean_3d_{tag}'] = base.rolling(window=72, min_periods=1).mean()
            df[f'std_1d_{tag}'] = base.rolling(window=24, min_periods=1).std()

        # Differences of the smoothed salinity over several lags.
        for tag, source in salinity_sources:
            for lag in (1, 3, 6, 12, 24):
                df[f'trend_{lag}h_{tag}'] = df[source].diff(lag)

        # External drivers (water level and flow), when available.
        for name in ('water_level', 'flow'):
            smooth_col = f'{name}_smooth'
            if smooth_col not in df.columns:
                continue
            smooth = df[smooth_col]
            df[f'{name}_trend_1h'] = smooth.diff(1)
            df[f'{name}_trend_24h'] = smooth.diff(24)
            df[f'mean_1d_{name}'] = smooth.rolling(window=24, min_periods=1).mean()
            df[f'mean_3d_{name}'] = smooth.rolling(window=72, min_periods=1).mean()
            df[f'std_1d_{name}'] = smooth.rolling(window=24, min_periods=1).std()
    except Exception as e:
        print(f"特征生成异常: {e}")
    return df
# -------------------------------
# Vectorized construction of training samples (optimized feature engineering)
# -------------------------------
def create_features_vectorized(df, look_back=168, forecast_horizon=1):
    """Build sliding-window training samples from ``df``.

    Every sample is the concatenation, column by column, of the last
    ``look_back`` rows of all numeric columns, followed by the month, day of
    month and weekday of the window's last timestamp. The target is the mean
    of ``downstream_smooth`` over the (up to) 24 rows after the window.
    NaNs are forward- then backward-filled inside each window; samples that
    still contain NaN (or a non-finite target) are dropped.

    Args:
        df: DataFrame with a ``DateTime`` column and ``downstream_smooth``.
        look_back: number of rows per feature window.
        forecast_horizon: minimum number of rows that must follow a window.
            NOTE(review): the target slice is always 24 rows long, so with
            the default of 1 the last windows average fewer than 24 rows -
            confirm whether this parameter was meant to be in days.

    Returns:
        ``(X, y)`` as numpy arrays; two empty arrays when no sample could be
        built or an error occurred.
    """
    try:
        # Samples must follow chronological order.
        df = df.sort_values('DateTime')
        numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()

        numeric_frame = df[numeric_columns]
        stamps = df['DateTime']
        target_series = df['downstream_smooth']

        features = []  # model inputs
        targets = []   # model outputs
        for start in range(len(df) - look_back - forecast_horizon + 1):
            stop = start + look_back

            # Fill gaps within the window only, then flatten column-major so
            # each column contributes look_back consecutive values.
            filled = numeric_frame.iloc[start:stop].ffill().bfill()
            window_features = filled.to_numpy(dtype=float).ravel(order='F').tolist()

            # Calendar features of the last timestamp in the window.
            current_date = stamps.iloc[stop - 1]
            window_features.extend([
                current_date.month,
                current_date.day,
                current_date.weekday(),
            ])

            # Target: mean downstream salinity over the following 24 rows.
            future = target_series.iloc[stop:stop + 24].ffill().bfill()
            target = np.mean(future.values)

            if not np.any(np.isnan(window_features)) and not np.isnan(target) and not np.isinf(target):
                features.append(window_features)
                targets.append(target)

        if not features:
            # The original message was mojibake; restored to readable text.
            print("警告: 未能生成任何有效特征")
            return np.array([]), np.array([])

        X = np.array(features)
        y = np.array(targets)
        print(f"成功生成特征矩阵,形状: {X.shape}")
        return X, y
    except Exception as e:
        print(f"特征创建异常: {e}")
        return np.array([]), np.array([])
def generate_prediction_features(df, current_date, look_back=168):
    """Build one feature vector for prediction at ``current_date``.

    Uses the same layout as ``create_features_vectorized``: the last
    ``look_back`` rows (up to and including ``current_date``) of every
    numeric column, concatenated column by column, followed by the month,
    day of month and weekday of ``current_date``. NaNs inside the window are
    forward- then backward-filled, as in the training samples.

    Args:
        df: DataFrame with a ``DateTime`` column.
        current_date: timestamp the prediction is made for.
        look_back: number of history rows required.

    Returns:
        A 1-D float numpy array, or ``None`` when fewer than ``look_back``
        rows are available or an error occurred.
    """
    try:
        df = df.sort_values('DateTime')
        numeric_columns = df.select_dtypes(include=[np.number]).columns.tolist()

        # Count rows up to current_date. After sorting this is a *position*,
        # independent of whatever labels the frame's index carries.
        n_available = int((df['DateTime'] <= current_date).sum())
        if n_available < look_back:
            print(f"数据不足,需要{look_back}小时的数据,但只有{n_available}小时")
            return None

        window = df.iloc[n_available - look_back:n_available]
        filled = window[numeric_columns].ffill().bfill()
        # Column-major flattening: look_back values of the first column,
        # then look_back values of the second column, and so on.
        history = filled.to_numpy(dtype=float).ravel(order='F')

        calendar = [
            current_date.month,
            current_date.day,
            current_date.weekday(),
        ]
        return np.concatenate([history, calendar])
    except Exception as e:
        print(f"预测特征生成异常: {e}")
        return None
# -------------------------------
# Get model accuracy metrics
# -------------------------------
def get_model_metrics():
    """Return the accuracy metrics stored in the model cache file.

    Reads ``salinity_model.pkl`` from the current working directory and
    returns ``{'rmse': ..., 'mae': ...}`` (a value is ``None`` when the key
    is absent). Returns ``None`` when the file does not exist or cannot be
    read.
    """
    cache_path = 'salinity_model.pkl'
    if not os.path.exists(cache_path):
        return None
    try:
        # NOTE: pickle must only be used on cache files this program wrote.
        with open(cache_path, 'rb') as handle:
            cached = pickle.load(handle)
        return {metric: cached.get(metric, None) for metric in ('rmse', 'mae')}
    except Exception as e:
        print(f"获取模型指标失败: {e}")
    return None
# -------------------------------
# Model training and prediction, reporting validation accuracy (RMSE, MAE)
# -------------------------------
def train_and_predict(df, start_time, force_retrain=False):
    """Train (or reuse) the XGBoost model and forecast five daily values.

    The model is trained on all rows with ``DateTime < start_time``. A
    previously trained model (kept in the module globals or in
    ``salinity_model.pkl``) is reused when its feature dimension matches and
    its training time is not older than the newest training row.

    Args:
        df: prepared DataFrame containing ``DateTime`` and the feature columns.
        start_time: timestamp of the first forecast; later forecasts follow
            in one-day steps.
        force_retrain: when true, delete the cache file and always retrain.

    Returns:
        ``(future_dates, predictions, model, prediction_intervals)`` where
        ``prediction_intervals`` is a 2 x 5 array (lower, upper) built as
        prediction +/- 1.96 * validation RMSE. On any failure all four
        elements are ``None``.
    """
    global cached_model, last_training_time
    model_cache_file = 'salinity_model.pkl'
    look_back = 168
    no_result = (None, None, None, None)
    model_needs_training = True

    if force_retrain and os.path.exists(model_cache_file):
        try:
            os.remove(model_cache_file)
            print("已删除旧模型缓存(强制重新训练)")
        except Exception as e:
            print("删除缓存异常:", e)

    # Training data: everything strictly before the forecast start.
    train_df = df[df['DateTime'] < start_time].copy()
    if len(train_df) < 100:
        print(f"训练数据不足,需要至少100个样本,当前只有{len(train_df)}个样本")
        return no_result
    latest_train_stamp = train_df['DateTime'].max()
    print(f"使用 {train_df['DateTime'].min()} 到 {latest_train_stamp} 的数据进行训练")
    print(f"训练数据总量: {len(train_df)} 个样本")

    # Build the samples once: they give the current feature dimension for
    # the cache check and are reused for training if that becomes necessary.
    X, y = create_features_vectorized(train_df, look_back=look_back, forecast_horizon=1)
    if X is None or y is None:
        print("特征生成失败")
        return no_result
    current_feature_dim = X.shape[1] if len(X) > 0 else 0
    print(f"当前特征维度: {current_feature_dim}")

    if not force_retrain and cached_model is not None and last_training_time is not None:
        # In-memory cache.
        if last_training_time >= latest_train_stamp:
            try:
                cached_feature_dim = cached_model.n_features_in_
                print(f"缓存模型特征维度: {cached_feature_dim}")
                if cached_feature_dim == current_feature_dim:
                    model_needs_training = False
                    print(f"使用缓存模型,训练时间: {last_training_time}")
                else:
                    print(f"特征维度不匹配(缓存模型: {cached_feature_dim},当前: {current_feature_dim}),需要重新训练")
            except Exception as e:
                print(f"检查模型特征维度失败: {e}")
    elif not force_retrain and os.path.exists(model_cache_file):
        # On-disk cache.
        # NOTE: pickle must only be used on cache files this program wrote.
        try:
            with open(model_cache_file, 'rb') as f:
                model_data = pickle.load(f)
            cached_model = model_data['model']
            last_training_time = model_data['training_time']
            try:
                cached_feature_dim = cached_model.n_features_in_
                print(f"文件缓存模型特征维度: {cached_feature_dim}")
                if cached_feature_dim != current_feature_dim:
                    # Only report a mismatch when the dimensions really differ.
                    print(f"特征维度不匹配(文件模型: {cached_feature_dim},当前: {current_feature_dim}),需要重新训练")
                elif last_training_time >= latest_train_stamp:
                    model_needs_training = False
                    print(f"从文件加载模型,训练时间: {last_training_time}")
            except Exception as e:
                print(f"检查模型特征维度失败: {e}")
        except Exception as e:
            print("加载模型失败:", e)

    rmse = None
    if model_needs_training:
        print("开始训练新模型...")
        start_train = time()
        if len(X) == 0 or len(y) == 0:
            print("样本生成不足,训练终止")
            return no_result
        print(f"训练样本数量: {X.shape[0]}, 特征维度: {X.shape[1]}")

        # Chronological split: the last 10% of the samples validate.
        split_idx = int(len(X) * 0.9)
        X_train, X_val = X[:split_idx], X[split_idx:]
        y_train, y_val = y[:split_idx], y[split_idx:]
        print(f"训练集大小: {len(X_train)}, 验证集大小: {len(X_val)}")

        # eval_metric is set on the estimator, next to early_stopping_rounds:
        # recent XGBoost releases no longer accept it as a fit() argument.
        model = XGBRegressor(
            n_estimators=200,
            learning_rate=0.1,
            max_depth=6,
            min_child_weight=2,
            subsample=0.8,
            colsample_bytree=0.8,
            gamma=0.1,
            reg_alpha=0.1,
            reg_lambda=1.0,
            n_jobs=-1,
            random_state=42,
            early_stopping_rounds=10,
            eval_metric='rmse',
        )
        try:
            model.fit(X_train, y_train,
                      eval_set=[(X_val, y_val)],
                      verbose=False)

            # Validation accuracy.
            y_val_pred = model.predict(X_val)
            rmse = np.sqrt(mean_squared_error(y_val, y_val_pred))
            mae = mean_absolute_error(y_val, y_val_pred)
            print(f"验证集 RMSE: {rmse:.4f}, MAE: {mae:.4f}")

            # Feature importance report.
            feature_importance = model.feature_importances_
            sorted_idx = np.argsort(feature_importance)[::-1]

            # Each window is stored oldest value first, so position i of a
            # column block is look_back - 1 - i steps before the newest row.
            feature_names = []
            numeric_columns = train_df.select_dtypes(include=[np.number]).columns.tolist()
            for col in numeric_columns:
                feature_names.extend(
                    f'{col}_t-{look_back - 1 - i}' for i in range(look_back))
            feature_names.extend(['month', 'day', 'weekday'])

            if len(feature_names) != len(feature_importance):
                print(f"警告: 特征名称数量({len(feature_names)})与重要性数组长度({len(feature_importance)})不匹配")
                feature_names = feature_names[:len(feature_importance)]

            print("\nTop 10 重要特征:")
            for rank, feat_idx in enumerate(sorted_idx[:10], start=1):
                # Guard against a name list shorter than the importances.
                if feat_idx < len(feature_names):
                    label = feature_names[feat_idx]
                else:
                    label = f'feature_{feat_idx}'
                print(f"{rank}. {label}: {feature_importance[feat_idx]:.6f}")

            last_training_time = start_time
            cached_model = model
            with open(model_cache_file, 'wb') as f:
                pickle.dump({
                    'model': model,
                    'training_time': last_training_time,
                    'feature_columns': feature_names,
                    'rmse': rmse,
                    'mae': mae,
                    'feature_dim': current_feature_dim
                }, f)
            print(f"模型训练完成,耗时: {time() - start_train:.2f}秒,特征维度: {current_feature_dim}")
        except Exception as e:
            print("模型训练异常:", e)
            return no_result
    else:
        model = cached_model

    # Prediction.
    try:
        future_dates = [start_time + timedelta(days=i) for i in range(5)]

        X_pred = []
        for current_date in future_dates:
            features = generate_prediction_features(df, current_date, look_back=look_back)
            if features is None:
                print(f"生成预测特征失败: {current_date}")
                return no_result
            X_pred.append(features)

        predictions = model.predict(np.array(X_pred))

        # Error estimate for the confidence band: the validation RMSE, either
        # just computed or stored with the cached model, so that a fresh and
        # a cached run of the same model give the same band.
        if model_needs_training:
            error_std = rmse
        else:
            try:
                with open(model_cache_file, 'rb') as f:
                    error_std = pickle.load(f).get('rmse', 1.0)
            except Exception:
                error_std = 1.0
        if error_std is None:
            error_std = 1.0

        prediction_intervals = np.array([
            predictions - 1.96 * error_std,
            predictions + 1.96 * error_std
        ])
        return future_dates, predictions, model, prediction_intervals
    except Exception as e:
        print("预测过程异常:", e)
        return no_result
# -------------------------------
# GUI界面部分
# -------------------------------
def run_gui():
    def configure_gui_fonts():
        font_names = ['微软雅黑', 'Microsoft YaHei', 'SimSun', 'SimHei']
        for font_name in font_names:
            try:
                default_font = tkfont.nametofont("TkDefaultFont")
                default_font.configure(family=font_name)
                text_font = tkfont.nametofont("TkTextFont")
                text_font.configure(family=font_name)
                fixed_font = tkfont.nametofont("TkFixedFont")
                fixed_font.configure(family=font_name)
                return True
            except Exception as e:
                continue
        return False
    def on_predict():
        try:
            predict_start = time()
            status_label.config(text="预测中...")
            root.update()
            # æ¸…理之前的图表和事件连接
            ax.clear()
            # ç§»é™¤æ‰€æœ‰äº‹ä»¶è¿žæŽ¥
            for cid in event_cids:
                canvas.mpl_disconnect(cid)
            event_cids.clear()
            # æ£€æŸ¥å¹¶æ¸…理第二个轴
            for ax_in_fig in fig.get_axes():
                if ax_in_fig != ax:
                    ax_in_fig.remove()
            start_time_dt = pd.to_datetime(entry.get())
            force_retrain = retrain_var.get()
            future_dates, predictions, model, prediction_intervals = train_and_predict(df, start_time_dt, force_retrain)
            if future_dates is None or predictions is None:
                status_label.config(text="预测失败")
                return
            # èŽ·å–å¹¶æ˜¾ç¤ºæ¨¡åž‹å‡†ç¡®åº¦æŒ‡æ ‡
            model_metrics = get_model_metrics()
            if model_metrics:
                metrics_text = f"模型准确度 - RMSE: {model_metrics['rmse']:.4f}, MAE: {model_metrics['mae']:.4f}"
                metrics_label.config(text=metrics_text)
            # åˆ›å»ºåŒy轴图表
            ax2 = None
            has_water_level = 'water_level' in df.columns and 'water_level_smooth' in df.columns
            if has_water_level:
                try:
                    ax2 = ax.twinx()
                except Exception as e:
                    print(f"创建双y轴失败: {e}")
                    ax2 = None
            # ç»˜åˆ¶åŽ†å²æ•°æ®ï¼ˆæœ€è¿‘ 120 å¤©ï¼‰
            history_end = min(start_time_dt, df['DateTime'].max())
            history_start = history_end - timedelta(days=120)
            hist_data = df[(df['DateTime'] >= history_start) & (df['DateTime'] <= history_end)]
            # ç¡®ä¿æ•°æ®ä¸ä¸ºç©º
            if len(hist_data) == 0:
                status_label.config(text="错误: æ‰€é€‰æ—¶é—´èŒƒå›´å†…没有历史数据")
                return
            # ç»˜åˆ¶åŸºæœ¬æ•°æ®
            ax.plot(hist_data['DateTime'], hist_data['downstream_smooth'],
                    label='一取水(下游)盐度', color='blue', linewidth=1.5)
            ax.plot(hist_data['DateTime'], hist_data['upstream_smooth'],
                    label='青龙港(上游)盐度', color='purple', linewidth=1.5, alpha=0.7)
            # ç»˜åˆ¶æ°´ä½æ•°æ®ï¼ˆå¦‚果有)
            if ax2 is not None and has_water_level:
                try:
                    # æ£€æŸ¥æ°´ä½æ•°æ®æ˜¯å¦æœ‰è¶³å¤Ÿçš„非NaN值
                    valid_water_level = hist_data['water_level_smooth'].dropna()
                    if len(valid_water_level) > 10:  # è‡³å°‘有10个有效值
                        # åªæ˜¾ç¤ºåˆ°è¾“入时刻的水位数据
                        water_level_data = hist_data[hist_data['DateTime'] <= start_time_dt]
                        ax2.plot(water_level_data['DateTime'], water_level_data['water_level_smooth'],
                                label='长江水位', color='green', linewidth=1.5, linestyle='--')
                        ax2.set_ylabel('水位', color='green')
                        ax2.tick_params(axis='y', labelcolor='green')
                    else:
                        print("水位数据有效值不足,跳过水位图")
                except Exception as e:
                    print(f"绘制水位数据时出错: {e}")
            # ç»˜åˆ¶é¢„测数据
            if len(future_dates) > 0 and len(predictions) > 0:
                ax.plot(future_dates, predictions, marker='o', linestyle='--',
                        label='递归预测盐度', color='red', linewidth=2)
                # æ·»åŠ é¢„æµ‹çš„ç½®ä¿¡åŒºé—´
                if prediction_intervals is not None:
                    ax.fill_between(future_dates, prediction_intervals[0], prediction_intervals[1],
                                   color='red', alpha=0.2, label='95% ç½®ä¿¡åŒºé—´')
            # ç»˜åˆ¶å®žé™…数据(如果有)
            actual_data = df[(df['DateTime'] >= start_time_dt) & (df['DateTime'] <= future_dates[-1])]
            actual_values = None
            if not actual_data.empty:
                actual_values = []
                # èŽ·å–ä¸Žé¢„æµ‹æ—¥æœŸæœ€æŽ¥è¿‘çš„å®žé™…æ•°æ®
                for pred_date in future_dates:
                    # åªèŽ·å–åœ¨æ•°æ®èŒƒå›´å†…çš„å®žé™…å€¼
                    if pred_date <= df['DateTime'].max():
                        closest_idx = np.argmin(np.abs(actual_data['DateTime'] - pred_date))
                        actual_values.append(actual_data['downstream_smooth'].iloc[closest_idx])
                    else:
                        # å¯¹äºŽè¶…出数据范围的日期,使用None表示无实际值
                        actual_values.append(None)
                # ç»˜åˆ¶å®žé™…盐度曲线(只绘制有实际值的点)
                valid_dates = [date for date, val in zip(future_dates, actual_values) if val is not None]
                valid_values = [val for val in actual_values if val is not None]
                if valid_dates and valid_values:
                    ax.plot(valid_dates, valid_values, marker='s', linestyle='-',
                        label='实际盐度', color='orange', linewidth=2)
            # è®¾ç½®å›¾è¡¨æ ‡é¢˜å’Œæ ‡ç­¾
            ax.set_xlabel('日期')
            ax.set_ylabel('盐度')
            ax.set_title(f"从 {start_time_dt.strftime('%Y-%m-%d %H:%M:%S')} å¼€å§‹çš„递归单步盐度预测")
            # è®¾ç½®å›¾ä¾‹å¹¶åº”用紧凑布局
            if ax2 is not None:
                try:
                    lines1, labels1 = ax.get_legend_handles_labels()
                    lines2, labels2 = ax2.get_legend_handles_labels()
                    if lines2:  # ç¡®ä¿æ°´ä½æ•°æ®å·²ç»˜åˆ¶
                        # åˆå¹¶å›¾ä¾‹ï¼Œé¿å…é‡å¤
                        all_lines = lines1 + lines2
                        all_labels = labels1 + labels2
                        # ç§»é™¤é‡å¤çš„æ ‡ç­¾
                        unique_labels = []
                        unique_lines = []
                        for label, line in zip(all_labels, all_lines):
                            if label not in unique_labels:
                                unique_labels.append(label)
                                unique_lines.append(line)
                        ax.legend(unique_lines, unique_labels, loc='best')
                    else:
                        ax.legend(loc='best')
                except Exception as e:
                    print(f"创建图例时出错: {e}")
                    ax.legend(loc='best')
            else:
                ax.legend(loc='best')
            fig.tight_layout()
            # å¼ºåˆ¶é‡ç»˜
            plt.close(fig)
            fig.canvas.draw()
            fig.canvas.flush_events()
            plt.draw()
            # ä¿å­˜é»˜è®¤è§†å›¾èŒƒå›´
            current_view['xlim'] = ax.get_xlim()
            current_view['ylim'] = ax.get_ylim()
            if ax2 is not None:
                current_view['ylim2'] = ax2.get_ylim()
            # åˆå§‹åŒ–拖动属性
            ax._pan_start = None
            # é‡æ–°è¿žæŽ¥æ‰€æœ‰äº‹ä»¶
            event_cids.append(canvas.mpl_connect('resize_event', on_resize))
            event_cids.append(canvas.mpl_connect('scroll_event', on_scroll))
            event_cids.append(canvas.mpl_connect('button_press_event', on_press))
            event_cids.append(canvas.mpl_connect('button_release_event', on_release))
            event_cids.append(canvas.mpl_connect('motion_notify_event', on_motion))
            # æ›´æ–°é¢„测结果文本
            predict_time = time() - predict_start
            status_label.config(text=f"递归预测完成 (耗时: {predict_time:.2f}秒)")
            # æ˜¾ç¤ºé¢„测结果
            result_text = "递归单步预测结果:\n\n"
            # å¦‚果有实际值,计算差值和百分比误差
            if actual_values is not None:
                result_text += "日期         é¢„测值      å®žé™…值     å·®å€¼\n"
                result_text += "--------------------------------------\n"
                for i, (date, pred, actual) in enumerate(zip(future_dates, predictions, actual_values)):
                    if actual is not None:  # åªåœ¨æœ‰å®žé™…值时显示差值
                        diff = pred - actual
                        result_text += f"{date.strftime('%Y-%m-%d')}  {pred:6.2f}    {actual:6.2f}    {diff:6.2f}\n"
                    else:
                        result_text += f"{date.strftime('%Y-%m-%d')}  {pred:6.2f}    --         --\n"
            else:
                result_text += "日期         é¢„测值\n"
                result_text += "-------------------\n"
                for i, (date, pred) in enumerate(zip(future_dates, predictions)):
                    result_text += f"{date.strftime('%Y-%m-%d')}  {pred:6.2f}\n"
                result_text += "\n无实际值进行对比"
            update_result_text(result_text)
        except Exception as e:
            status_label.config(text=f"错误: {str(e)}")
            import traceback
            traceback.print_exc()
    def reset_view():
        # èŽ·å–å½“å‰æ´»åŠ¨çš„è½´
        current_ax2 = None
        for ax_in_fig in fig.get_axes():
            if ax_in_fig != ax and hasattr(ax_in_fig, 'get_shared_x_axes'):
                current_ax2 = ax_in_fig
                break
        # æ¢å¤ä¿å­˜çš„视图范围
        if current_view['xlim'] is not None:
            ax.set_xlim(current_view['xlim'])
            ax.set_ylim(current_view['ylim'])
            if current_ax2 is not None and current_view['ylim2'] is not None:
                current_ax2.set_ylim(current_view['ylim2'])
        # åº”用紧凑布局并重绘
        fig.tight_layout()
        canvas.draw_idle()
        status_label.config(text="图表视图已重置")
    def on_scroll(event):
        # èŽ·å–å½“å‰æ´»åŠ¨çš„è½´
        current_ax2 = None
        for ax_in_fig in fig.get_axes():
            if ax_in_fig != ax and hasattr(ax_in_fig, 'get_shared_x_axes'):
                current_ax2 = ax_in_fig
                break
        # æ£€æŸ¥é¼ æ ‡æ˜¯å¦åœ¨ä»»ä¸€è½´åŒºåŸŸå†…
        if event.inaxes != ax and (current_ax2 is None or event.inaxes != current_ax2):
            return
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        zoom_factor = 1.1
        x_data = event.xdata if event.xdata is not None else (xlim[0]+xlim[1])/2
        y_data = event.ydata if event.ydata is not None else (ylim[0]+ylim[1])/2
        x_rel = (x_data - xlim[0]) / (xlim[1] - xlim[0])
        y_rel = (y_data - ylim[0]) / (ylim[1] - ylim[0])
        if event.step > 0:
            new_width = (xlim[1]-xlim[0]) / zoom_factor
            new_height = (ylim[1]-ylim[0]) / zoom_factor
            x0 = x_data - x_rel * new_width
            y0 = y_data - y_rel * new_height
            ax.set_xlim([x0, x0+new_width])
            ax.set_ylim([y0, y0+new_height])
            # å¦‚果有第二个轴,同步更新
            if current_ax2 is not None:
                ylim2 = current_ax2.get_ylim()
                # è®¡ç®—第二个轴的缩放比例
                y_scale2 = (ylim2[1] - ylim2[0]) / (ylim[1] - ylim[0])
                # è®¡ç®—第二个轴的新高度
                new_height2 = new_height * y_scale2
                # è®¡ç®—第二个轴的新y0
                y02 = ylim2[0] + (y0 - ylim[0]) * y_scale2
                current_ax2.set_xlim([x0, x0+new_width])
                current_ax2.set_ylim([y02, y02+new_height2])
        else:
            new_width = (xlim[1]-xlim[0]) * zoom_factor
            new_height = (ylim[1]-ylim[0]) * zoom_factor
            x0 = x_data - x_rel * new_width
            y0 = y_data - y_rel * new_height
            ax.set_xlim([x0, x0+new_width])
            ax.set_ylim([y0, y0+new_height])
            # å¦‚果有第二个轴,同步更新
            if current_ax2 is not None:
                ylim2 = current_ax2.get_ylim()
                # è®¡ç®—第二个轴的缩放比例
                y_scale2 = (ylim2[1] - ylim2[0]) / (ylim[1] - ylim[0])
                # è®¡ç®—第二个轴的新高度
                new_height2 = new_height * y_scale2
                # è®¡ç®—第二个轴的新y0
                y02 = ylim2[0] + (y0 - ylim[0]) * y_scale2
                current_ax2.set_xlim([x0, x0+new_width])
                current_ax2.set_ylim([y02, y02+new_height2])
        canvas.draw_idle()
    def on_motion(event):
        if not hasattr(ax, '_pan_start') or ax._pan_start is None:
            # èŽ·å–å½“å‰æ´»åŠ¨çš„è½´
            current_ax2 = None
            for ax_in_fig in fig.get_axes():
                if ax_in_fig != ax and hasattr(ax_in_fig, 'get_shared_x_axes'):
                    current_ax2 = ax_in_fig
                    break
            # æ£€æŸ¥é¼ æ ‡æ˜¯å¦åœ¨ä»»ä¸€è½´åŒºåŸŸå†…
            if event.inaxes == ax or (current_ax2 is not None and event.inaxes == current_ax2):
                canvas.get_tk_widget().config(cursor="fleur")
            else:
                canvas.get_tk_widget().config(cursor="")
            return
        # èŽ·å–å½“å‰æ´»åŠ¨çš„è½´
        current_ax2 = None
        for ax_in_fig in fig.get_axes():
            if ax_in_fig != ax and hasattr(ax_in_fig, 'get_shared_x_axes'):
                current_ax2 = ax_in_fig
                break
        # æ£€æŸ¥é¼ æ ‡æ˜¯å¦åœ¨ä»»ä¸€è½´åŒºåŸŸå†…
        if event.inaxes != ax and (current_ax2 is None or event.inaxes != current_ax2):
            return
        start_x, start_y, x_data, y_data = ax._pan_start
        dx = event.x - start_x
        dy = event.y - start_y
        # èŽ·å–å½“å‰è§†å›¾
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        # è®¡ç®—图表坐标系中的移动
        x_scale = (xlim[1] - xlim[0]) / canvas.get_tk_widget().winfo_width()
        y_scale = (ylim[1] - ylim[0]) / canvas.get_tk_widget().winfo_height()
        # æ›´æ–°è§†å›¾
        new_xlim = (xlim[0] - dx * x_scale, xlim[1] - dx * x_scale)
        new_ylim = (ylim[0] + dy * y_scale, ylim[1] + dy * y_scale)
        ax.set_xlim(new_xlim)
        ax.set_ylim(new_ylim)
        # å¦‚果有第二个轴,同步更新
        if current_ax2 is not None:
            # èŽ·å–ç¬¬äºŒä¸ªè½´çš„å½“å‰èŒƒå›´
            ylim2 = current_ax2.get_ylim()
            # è®¡ç®—第二个轴的移动比例
            y_scale2 = (ylim2[1] - ylim2[0]) / (ylim[1] - ylim[0])
            # æ›´æ–°ç¬¬äºŒä¸ªè½´çš„范围
            new_ylim2 = (ylim2[0] + dy * y_scale * y_scale2, ylim2[1] + dy * y_scale * y_scale2)
            current_ax2.set_xlim(new_xlim)
            current_ax2.set_ylim(new_ylim2)
        # æ›´æ–°æ‹–动起点
        ax._pan_start = (event.x, event.y, event.xdata, event.ydata)
        canvas.draw_idle()
    root = tk.Tk()
    root.title("青龙港-陈行盐度预测系统")
    try:
        configure_gui_fonts()
    except Exception as e:
        print("字体配置异常:", e)
    # æ¢å¤è¾“入框和控制按钮
    input_frame = ttk.Frame(root, padding="10")
    input_frame.pack(fill=tk.X)
    ttk.Label(input_frame, text="输入开始时间 (YYYY-MM-DD HH:MM:SS)").pack(side=tk.LEFT)
    entry = ttk.Entry(input_frame, width=25)
    entry.pack(side=tk.LEFT, padx=5)
    predict_button = ttk.Button(input_frame, text="预测", command=on_predict)
    predict_button.pack(side=tk.LEFT)
    status_label = ttk.Label(input_frame, text="提示: ç¬¬ä¸€æ¬¡è¿è¡Œè¯·å‹¾é€‰'强制重新训练模型'")
    status_label.pack(side=tk.LEFT, padx=10)
    control_frame = ttk.Frame(root, padding="5")
    control_frame.pack(fill=tk.X)
    retrain_var = tk.BooleanVar(value=False)
    ttk.Checkbutton(control_frame, text="强制重新训练模型", variable=retrain_var).pack(side=tk.LEFT)
    # æ›´æ–°å›¾ä¾‹è¯´æ˜Žï¼ŒåŠ å…¥æ°´ä½æ•°æ®ä¿¡æ¯
    if 'water_level' in df.columns:
        legend_label = ttk.Label(control_frame, text="图例: ç´«è‰²=青龙港上游数据, è“è‰²=一取水下游数据, çº¢è‰²=预测值, ç»¿è‰²=长江水位")
    else:
        legend_label = ttk.Label(control_frame, text="图例: ç´«è‰²=青龙港上游数据, è“è‰²=一取水下游数据, çº¢è‰²=预测值, æ©™è‰²=实际值")
    legend_label.pack(side=tk.LEFT, padx=10)
    reset_button = ttk.Button(control_frame, text="重置视图", command=reset_view)
    reset_button.pack(side=tk.LEFT, padx=5)
    # æ·»åŠ æ˜¾ç¤ºæ¨¡åž‹å‡†ç¡®åº¦çš„æ ‡ç­¾
    metrics_frame = ttk.Frame(root, padding="5")
    metrics_frame.pack(fill=tk.X)
    model_metrics = get_model_metrics()
    metrics_text = "模型准确度: æœªçŸ¥" if not model_metrics else f"模型准确度 - RMSE: {model_metrics['rmse']:.4f}, MAE: {model_metrics['mae']:.4f}"
    metrics_label = ttk.Label(metrics_frame, text=metrics_text)
    metrics_label.pack(side=tk.LEFT, padx=10)
    # ç»“果显示区域
    result_frame = ttk.Frame(root, padding="10")
    result_frame.pack(fill=tk.BOTH, expand=True)
    # å·¦ä¾§æ”¾ç½®å›¾è¡¨
    plot_frame = ttk.Frame(result_frame, width=800, height=600)
    plot_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
    plot_frame.pack_propagate(False)  # ä¸å…è®¸æ¡†æž¶æ ¹æ®å†…容调整大小
    # å³ä¾§æ”¾ç½®æ–‡æœ¬ç»“æžœ
    text_frame = ttk.Frame(result_frame)
    text_frame.pack(side=tk.RIGHT, fill=tk.Y)
    # ä½¿ç”¨ç­‰å®½å­—体显示结果
    result_font = tkfont.Font(family="Courier New", size=10, weight="normal")
    # æ·»åŠ æ–‡æœ¬æ¡†å’Œæ»šåŠ¨æ¡
    result_text = tk.Text(text_frame, width=50, height=25, font=result_font, wrap=tk.NONE)
    result_text.pack(side=tk.LEFT, fill=tk.BOTH)
    result_scroll = ttk.Scrollbar(text_frame, orient="vertical", command=result_text.yview)
    result_scroll.pack(side=tk.RIGHT, fill=tk.Y)
    result_text.configure(yscrollcommand=result_scroll.set)
    result_text.configure(state=tk.DISABLED)  # åˆå§‹è®¾ä¸ºåªè¯»
    # æ›´æ–°ç»“果文本的函数
    def update_result_text(text):
        result_text.configure(state=tk.NORMAL)
        result_text.delete(1.0, tk.END)
        result_text.insert(tk.END, text)
        result_text.configure(state=tk.DISABLED)
    # Create the figure and its primary axes (10 x 6 inches at 100 dpi)
    fig, ax = plt.subplots(figsize=(10, 6), dpi=100)
    fig.tight_layout(pad=3.0)  # extra padding so that labels are not cut off
    # Create the canvas and place it in the fixed-size plot frame
    canvas = FigureCanvasTkAgg(fig, master=plot_frame)
    canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
    # Add the matplotlib navigation toolbar in its own frame at the bottom
    toolbar_frame = ttk.Frame(plot_frame)
    toolbar_frame.pack(side=tk.BOTTOM, fill=tk.X)
    toolbar = NavigationToolbar2Tk(canvas, toolbar_frame)
    toolbar.update()
    # Re-apply the tight layout on every resize so that the chart stays fully visible
    def on_resize(event):
        """Re-apply the tight layout and schedule a redraw; ``event`` is not used."""
        fig.tight_layout()
        canvas.draw_idle()
    # Chart interaction: connect the resize handler
    canvas.mpl_connect('resize_event', on_resize)
    # Mouse-drag (pan) support
    def on_press(event):
        # èŽ·å–å½“å‰æ´»åŠ¨çš„è½´
        current_ax2 = None
        for ax_in_fig in fig.get_axes():
            if ax_in_fig != ax and hasattr(ax_in_fig, 'get_shared_x_axes'):
                current_ax2 = ax_in_fig
                break
        # æ£€æŸ¥é¼ æ ‡æ˜¯å¦åœ¨ä»»ä¸€è½´åŒºåŸŸå†…
        if event.inaxes == ax or (current_ax2 is not None and event.inaxes == current_ax2):
            canvas.get_tk_widget().config(cursor="fleur")
            ax._pan_start = (event.x, event.y, event.xdata, event.ydata)
        else:
            ax._pan_start = None
    def on_release(event):
        ax._pan_start = None
        canvas.get_tk_widget().config(cursor="")
        canvas.draw_idle()
    def on_motion(event):
        if not hasattr(ax, '_pan_start') or ax._pan_start is None:
            # èŽ·å–å½“å‰æ´»åŠ¨çš„è½´
            current_ax2 = None
            for ax_in_fig in fig.get_axes():
                if ax_in_fig != ax and hasattr(ax_in_fig, 'get_shared_x_axes'):
                    current_ax2 = ax_in_fig
                    break
            # æ£€æŸ¥é¼ æ ‡æ˜¯å¦åœ¨ä»»ä¸€è½´åŒºåŸŸå†…
            if event.inaxes == ax or (current_ax2 is not None and event.inaxes == current_ax2):
                canvas.get_tk_widget().config(cursor="fleur")
            else:
                canvas.get_tk_widget().config(cursor="")
            return
        # èŽ·å–å½“å‰æ´»åŠ¨çš„è½´
        current_ax2 = None
        for ax_in_fig in fig.get_axes():
            if ax_in_fig != ax and hasattr(ax_in_fig, 'get_shared_x_axes'):
                current_ax2 = ax_in_fig
                break
        # æ£€æŸ¥é¼ æ ‡æ˜¯å¦åœ¨ä»»ä¸€è½´åŒºåŸŸå†…
        if event.inaxes != ax and (current_ax2 is None or event.inaxes != current_ax2):
            return
        start_x, start_y, x_data, y_data = ax._pan_start
        dx = event.x - start_x
        dy = event.y - start_y
        # èŽ·å–å½“å‰è§†å›¾
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        # è®¡ç®—图表坐标系中的移动
        x_scale = (xlim[1] - xlim[0]) / canvas.get_tk_widget().winfo_width()
        y_scale = (ylim[1] - ylim[0]) / canvas.get_tk_widget().winfo_height()
        # æ›´æ–°è§†å›¾
        new_xlim = (xlim[0] - dx * x_scale, xlim[1] - dx * x_scale)
        new_ylim = (ylim[0] + dy * y_scale, ylim[1] + dy * y_scale)
        ax.set_xlim(new_xlim)
        ax.set_ylim(new_ylim)
        # å¦‚果有第二个轴,同步更新
        if current_ax2 is not None:
            # èŽ·å–ç¬¬äºŒä¸ªè½´çš„å½“å‰èŒƒå›´
            ylim2 = current_ax2.get_ylim()
            # è®¡ç®—第二个轴的移动比例
            y_scale2 = (ylim2[1] - ylim2[0]) / (ylim[1] - ylim[0])
            # æ›´æ–°ç¬¬äºŒä¸ªè½´çš„范围
            new_ylim2 = (ylim2[0] + dy * y_scale * y_scale2, ylim2[1] + dy * y_scale * y_scale2)
            current_ax2.set_xlim(new_xlim)
            current_ax2.set_ylim(new_ylim2)
        # æ›´æ–°æ‹–动起点
        ax._pan_start = (event.x, event.y, event.xdata, event.ydata)
        canvas.draw_idle()
    # Connect the mouse handlers used for panning.
    # NOTE(review): these connection ids are not stored in event_cids, while
    # on_predict appends its own connections for the same events to that list -
    # confirm the handlers cannot end up connected twice.
    canvas.mpl_connect('button_press_event', on_press)
    canvas.mpl_connect('button_release_event', on_release)
    canvas.mpl_connect('motion_notify_event', on_motion)
    # Initial (empty) chart
    def init_plot():
        try:
            ax.clear()
            ax.set_xlabel('日期')
            ax.set_ylabel('盐度')
            ax.set_title('盐度预测')
            fig.tight_layout()
            canvas.draw()
        except Exception as e:
            status_label.config(text=f"初始化图表时出错: {str(e)}")
    # Draw the empty chart once, then enter the Tk event loop
    init_plot()
    root.mainloop()
def resample_to_hourly(df):
    """Resample the data to hourly means.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with a ``DateTime`` column, or one that is already indexed by
        its timestamps.

    Returns
    -------
    pandas.DataFrame
        The hourly means of all numeric columns, with the timestamps moved
        back into a regular column.  Non-numeric columns are dropped.  If
        resampling fails for any reason the input frame is returned unchanged.
    """
    try:
        # Work on a separate name so that the caller's frame is still
        # available, untouched, for the failure path below.
        indexed = df.set_index('DateTime') if 'DateTime' in df.columns else df
        numeric_columns = indexed.select_dtypes(include=[np.number]).columns.tolist()
        # Lower-case 'h' is used because the upper-case 'H' alias is
        # deprecated in recent pandas releases.
        hourly_df = indexed[numeric_columns].resample('h').mean()
        # Turn the timestamp index back into a column
        hourly_df = hourly_df.reset_index()
        print(f"数据已从分钟级重采样为小时级,原始数据行数: {len(indexed)},重采样后行数: {len(hourly_df)}")
        return hourly_df
    except Exception as e:
        print(f"重采样数据异常: {e}")
        return df
# -------------------------------
# Main program entry: load the data, add features, generate delay features, then start the GUI
# -------------------------------
def save_processed_data(df, filename='processed_data.pkl'):
    """Pickle ``df`` to ``filename``.

    Returns True when the file was written, False when any exception occurred
    (the error is printed, not raised).
    """
    saved = False
    try:
        df.to_pickle(filename)
        print(f"已保存处理后的数据到 {filename}")
        saved = True
    except Exception as e:
        print(f"保存数据失败: {e}")
    return saved
def load_processed_data(filename='processed_data.pkl'):
    """Load a previously pickled data frame.

    Returns the unpickled object, or None when ``filename`` does not exist or
    cannot be read (the reason is printed, not raised).
    """
    try:
        if not os.path.exists(filename):
            print(f"找不到处理后的数据文件 {filename}")
            return None
        # NOTE: unpickling can execute arbitrary code - only point this at
        # cache files written by this program, never at untrusted files.
        df = pd.read_pickle(filename)
        # The message below was mis-encoded in the original source and is restored here.
        print(f"已从 {filename} 加载处理后的数据")
        return df
    except Exception as e:
        print(f"加载数据失败: {e}")
        return None
# # Delete the old processed-data cache (if it exists) so that the data is reprocessed with the fixed code
# if os.path.exists('processed_data.pkl'):
#     try:
#         os.remove('processed_data.pkl')
#         print("已删除旧的处理数据缓存,将使用修复后的代码重新处理数据")
#     except Exception as e:
#         print(f"删除缓存文件失败: {e}")
# # Delete the old model file (if it exists)
# if os.path.exists('salinity_model.pkl'):
#     try:
#         os.remove('salinity_model.pkl')
#         print("已删除旧的模型文件,将重新训练模型")
#     except Exception as e:
#         print(f"删除模型文件失败: {e}")
# Try to load the cached processed data; rebuild it from the raw files when no cache is available
processed_data = load_processed_data()
if processed_data is not None:
    df = processed_data
else:
    # Positional arguments: upstream file, downstream file, river level file, flow file
    df = load_data('青龙港1.csv', '一取水.csv', '长江液位.csv', '大通流量.csv')
    if df is not None:
        # Time features
        df['hour'] = df['DateTime'].dt.hour
        df['weekday'] = df['DateTime'].dt.dayofweek
        df['month'] = df['DateTime'].dt.month
        # Lunar-calendar features
        df = add_lunar_features(df)
        # Delay features for the upstream-to-downstream lag of 3-5 days; one node every
        # 12 hours for now, to be adjusted later depending on the results
        # delay_hours = [1,2,3,4,6,12,24,36,48,60,72,84,96,108,120]
        delay_hours = [72,84,96,108,120]
        df = batch_create_delay_features(df, delay_hours)
        # Rolling statistics. The windows (24 and 72) are counted in rows.
        # NOTE(review): the 1d/3d names assume one row per hour, but the hourly
        # resampling only happens further down - confirm the row spacing here.
        df['mean_1d_up'] = df['upstream_smooth'].rolling(window=24, min_periods=1).mean()
        df['mean_3d_up'] = df['upstream_smooth'].rolling(window=72, min_periods=1).mean()
        df['std_1d_up'] = df['upstream_smooth'].rolling(window=24, min_periods=1).std()
        df['max_1d_up'] = df['upstream_smooth'].rolling(window=24, min_periods=1).max()
        df['min_1d_up'] = df['upstream_smooth'].rolling(window=24, min_periods=1).min()
        df['mean_1d_down'] = df['downstream_smooth'].rolling(window=24, min_periods=1).mean()
        df['mean_3d_down'] = df['downstream_smooth'].rolling(window=72, min_periods=1).mean()
        df['std_1d_down'] = df['downstream_smooth'].rolling(window=24, min_periods=1).std()
        df['max_1d_down'] = df['downstream_smooth'].rolling(window=24, min_periods=1).max()
        df['min_1d_down'] = df['downstream_smooth'].rolling(window=24, min_periods=1).min()
        # Water-level statistics (only when water-level data is present)
        if 'water_level' in df.columns:
            # Create the smoothed water level first: centred 24-row rolling mean,
            # falling back to the raw value where the mean is missing
            if 'water_level_smooth' not in df.columns:
                df['water_level_smooth'] = df['water_level'].rolling(window=24, min_periods=1, center=True).mean()
                df['water_level_smooth'] = df['water_level_smooth'].fillna(df['water_level'])
            # Rolling statistics of the smoothed water level
            df['mean_1d_water_level'] = df['water_level_smooth'].rolling(window=24, min_periods=1).mean()
            df['mean_3d_water_level'] = df['water_level_smooth'].rolling(window=72, min_periods=1).mean()
            df['std_1d_water_level'] = df['water_level_smooth'].rolling(window=24, min_periods=1).std()
            df['max_1d_water_level'] = df['water_level_smooth'].rolling(window=24, min_periods=1).max()
            df['min_1d_water_level'] = df['water_level_smooth'].rolling(window=24, min_periods=1).min()
            # Water-level change over 1 and 24 rows
            df['water_level_change_1h'] = df['water_level_smooth'].diff()
            df['water_level_change_24h'] = df['water_level_smooth'].diff(24)
            # Ratio between water level and downstream salinity.
            # NOTE(review): a downstream value of 0 yields inf here - confirm
            # that cannot occur or that later stages tolerate it.
            df['water_level_sal_ratio'] = df['water_level_smooth'] / df['downstream_smooth']
            print("水位特征已添加")
        # Remaining features
        df = generate_features(df)
        # Resample the data to hourly resolution
        df = resample_to_hourly(df)
        # Save the processed data as CSV and as a pickle cache
        df.to_csv('merged_data_hour.csv', index=False)
        print(f"Merged data saved to 'merged_data_hour.csv' successfully")
        save_processed_data(df)
# Start the GUI only when data is available
if df is not None:
    run_gui()
else:
    print("数据加载失败,无法运行预测。")