import pandas as pd
import matplotlib.pyplot as plt
import dataframe_image as dfi
import os

# --- Configuration
# All paths are plain relative file names.

# Input: CSV of optimization results read by run_optimized_analysis().
CSV_FILE = 'optimized_results.csv'

# Output: bar chart comparing mean InSample vs OutSample Sortino per indicator.
OUTPUT_IMAGE_HIST = 'WalkForward_Histogram.png'

# Output: styled table image of the same per-indicator means.
OUTPUT_IMAGE_TABLE = 'WalkForward_Table.png'

# Output: text report with one paragraph per matched indicator.
OUTPUT_TEXT = 'WalkForward_Analysis_Report.txt'


# Columns the analysis indexes unconditionally; 'Param_SAMA_LR' is optional.
_REQUIRED_COLUMNS = (
    'Symbol',
    'Indicator_Name',
    'Test_Phase',
    'Filter_Period',
    'Sortino_Ratio',
    'Net_Profit_$',
)


def _match_oos_rows(oos_data, best_period, best_lr):
    """Return the OutSample rows whose parameters equal the InSample optimum.

    Rows must have ``Filter_Period == best_period``. When the optional
    ``Param_SAMA_LR`` column exists, rows with that exact value are preferred;
    rows where the column is NaN are used only as a fallback, so a row with
    no LR value can never shadow an exact parameter match.
    """
    period_rows = oos_data[oos_data['Filter_Period'] == best_period]
    if 'Param_SAMA_LR' not in period_rows.columns:
        # The column is optional: without it, the period is the only parameter.
        return period_rows

    lr_values = period_rows['Param_SAMA_LR']
    exact_rows = period_rows[lr_values == best_lr]
    if not exact_rows.empty:
        return exact_rows
    return period_rows[lr_values.isna()]


def _profit_drop_pct(is_profit, oos_profit):
    """Return the IS->OOS profit change as a percentage of the IS profit.

    Returns ``None`` when the value is undefined (missing profit or a zero
    InSample profit) instead of producing ``inf``/``nan``.
    """
    if pd.isna(is_profit) or pd.isna(oos_profit) or is_profit == 0:
        return None
    return ((is_profit - oos_profit) / is_profit) * 100


def run_optimized_analysis() -> None:
    """Compare InSample optima against their OutSample results and report them.

    Reads ``CSV_FILE``. For every (Symbol, Indicator_Name) pair it takes the
    InSample row with the highest ``Sortino_Ratio`` and looks up the OutSample
    row with the same parameters. Matched pairs are summarised in three
    artifacts: a bar chart (``OUTPUT_IMAGE_HIST``), a styled table image
    (``OUTPUT_IMAGE_TABLE``) and a text report (``OUTPUT_TEXT``).

    Problems with the input (missing file, empty file, missing columns, no
    matched pairs) are reported with ``print`` and the function returns
    without writing any output.
    """
    if not os.path.exists(CSV_FILE):
        print(f"Error: {CSV_FILE} not found.")
        return

    try:
        df = pd.read_csv(CSV_FILE)
    except pd.errors.EmptyDataError:
        print(f"Error: {CSV_FILE} is empty.")
        return

    missing_cols = [col for col in _REQUIRED_COLUMNS if col not in df.columns]
    if missing_cols:
        print(f"Error: {CSV_FILE} is missing required columns: {missing_cols}")
        return

    # Collect report fragments and join once at the end.
    report_parts = ["# Walk-Forward Optimized Analysis\n\n"]
    final_comparison_rows = []

    for symbol in df['Symbol'].unique():
        report_parts.append(f"## Walk-Forward Analysis for {symbol}\n\n")
        symbol_data = df[df['Symbol'] == symbol]

        for indicator in symbol_data['Indicator_Name'].unique():
            ind_data = symbol_data[symbol_data['Indicator_Name'] == indicator]

            # astype(str) keeps .str usable even if the column was parsed as
            # a non-string dtype (e.g. an entirely empty Test_Phase column).
            phase = ind_data['Test_Phase'].astype(str)
            is_data = ind_data[phase.str.contains('InSample', case=False, na=False)]
            oos_data = ind_data[phase.str.contains('OutSample', case=False, na=False)]

            # Rows without a Sortino value cannot be ranked; idxmax() would
            # fail if every InSample value were NaN.
            is_data = is_data.dropna(subset=['Sortino_Ratio'])

            if is_data.empty or oos_data.empty:
                continue

            # --- Locate the highest Sortino row in the InSample set
            best_is_row = is_data.loc[is_data['Sortino_Ratio'].idxmax()]
            best_period = best_is_row['Filter_Period']
            best_lr = best_is_row.get('Param_SAMA_LR', 0)

            # --- Match OutSample row using exact parameter values (no substitution)
            matching_oos = _match_oos_rows(oos_data, best_period, best_lr)
            if matching_oos.empty:
                continue
            best_oos_row = matching_oos.iloc[0]

            final_comparison_rows.append({
                'Symbol': symbol,
                'Indicator': indicator,
                'IS_Sortino': best_is_row['Sortino_Ratio'],
                'OOS_Sortino': best_oos_row['Sortino_Ratio'],
                'IS_Profit': best_is_row['Net_Profit_$'],
                'OOS_Profit': best_oos_row['Net_Profit_$'],
            })

            profit_drop = _profit_drop_pct(
                best_is_row['Net_Profit_$'], best_oos_row['Net_Profit_$']
            )
            variance_text = "n/a" if profit_drop is None else f"{profit_drop:.1f}%"

            report_parts.append(
                f"**{indicator}**: The optimal parameters from the InSample phase "
                f"(Period: {best_period}) "
                f"yielded a Sortino Ratio of {best_is_row['Sortino_Ratio']:.2f}. "
                f"When these exact parameters were applied to the blind OutSample data, "
                f"the Sortino Ratio shifted to {best_oos_row['Sortino_Ratio']:.2f}, "
                f"with a profit variance of {variance_text}.\n"
            )

        report_parts.append("\n")

    if not final_comparison_rows:
        print("No matched IS/OOS pairs found. Check Test_Phase labels in the CSV.")
        return

    results_df = pd.DataFrame(final_comparison_rows)
    plot_data = results_df.groupby('Indicator')[['IS_Sortino', 'OOS_Sortino']].mean()

    # --- Walk-forward degradation histogram (980px layout constraint)
    fig, ax = plt.subplots(figsize=(9.8, 6), dpi=100)
    try:
        plot_data.plot(kind='bar', ax=ax)
        ax.set_title('Walk-Forward Validation: InSample vs OutSample Sortino Ratio Degradation (IS vs OOS)')
        ax.set_ylabel('Sortino Ratio')
        ax.tick_params(axis='x', labelrotation=0)
        fig.tight_layout()
        fig.savefig(OUTPUT_IMAGE_HIST)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    # --- Styled summary table image
    styled_table = (
        plot_data.style
        .background_gradient(cmap='Purples')
        .format("{:.2f}")
    )
    dfi.export(styled_table, OUTPUT_IMAGE_TABLE, max_cols=-1, max_rows=-1)

    # Explicit encoding so the report does not depend on the platform locale.
    with open(OUTPUT_TEXT, 'w', encoding='utf-8') as f:
        f.write("".join(report_parts))

    print(f"Walk-forward analysis complete. Report: {OUTPUT_TEXT}")


# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_optimized_analysis()