from flask import Flask, request, jsonify, send_file
import numpy as np
import talib
import json
import pandas as pd
import matplotlib
# headless
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import mplfinance as mpf
from mplfinance import make_marketcolors, make_mpf_style
import os, uuid, logging
from matplotlib.lines import Line2D

app = Flask(__name__)
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
app.logger.setLevel(logging.INFO)

# Every TA-Lib candlestick recognizer (the "CDL*" family), keyed by its TA-Lib
# name; iteration order follows talib.get_functions().
CDL_FUNCS = {
    fn_name: getattr(talib, fn_name)
    for fn_name in filter(lambda s: s.startswith("CDL"), talib.get_functions())
}

# Manual-filter helpers: geometric re-checks applied to TA-Lib harami / engulfing hits.
def is_manual_harami(open_arr, close_arr, idx, tol=0.0):
    """Return whether bar ``idx``'s real body sits inside the previous bar's body.

    A real body spans [min(open, close), max(open, close)]. Containment is
    inclusive, and ``tol`` (in price units) widens the previous body on both
    sides. Bar 0 has no predecessor, so any ``idx < 1`` yields False.
    """
    if idx >= 1:
        prev_top = max(open_arr[idx - 1], close_arr[idx - 1])
        prev_bottom = min(open_arr[idx - 1], close_arr[idx - 1])
        cur_top = max(open_arr[idx], close_arr[idx])
        cur_bottom = min(open_arr[idx], close_arr[idx])
        return (cur_top <= prev_top + tol) and (cur_bottom >= prev_bottom - tol)
    return False

def is_manual_engulfing(open_arr, close_arr, idx, tol=0.0):
    """Return whether bar ``idx`` strictly engulfs the previous bar's open/close.

    Bar ``idx`` counts as bullish only when close > open. Then it must open
    below the previous close and close above the previous open. Otherwise,
    including a flat bar, it is treated as bearish and must open above the
    previous close and close below the previous open. ``tol`` (in price units)
    widens the required margin.

    NOTE: the previous bar's direction is not checked here. In this file the
    helper is only applied to bars TA-Lib's CDLENGULFING already flagged.
    """
    if idx < 1:
        return False
    prev_open, prev_close = open_arr[idx - 1], close_arr[idx - 1]
    cur_open, cur_close = open_arr[idx], close_arr[idx]
    if not cur_close > cur_open:
        # bearish (or flat) bar
        return (cur_open > prev_close + tol) and (cur_close < prev_open - tol)
    # bullish bar
    return (cur_open < prev_close - tol) and (cur_close > prev_open + tol)

# Price-unit tolerances passed to the manual harami / engulfing filters.
HARAMI_TOL = 0.0
ENGULF_TOL = 0.0

def _decode_json_body(raw):
    """Decode a raw request body as UTF-8 JSON.

    Anything from the first NUL byte onward is dropped before decoding
    (NOTE(review): presumably some client NUL-terminates its payload; confirm).

    Raises:
        UnicodeDecodeError, json.JSONDecodeError: if the body is not UTF-8 JSON.
    """
    return json.loads(raw.split(b'\x00', 1)[0].decode('utf-8'))


def _detect_patterns(open_, high, low, close, filter_harami, filter_engulfing):
    """Run every TA-Lib candlestick recognizer and keep one pattern per bar.

    Args:
        open_, high, low, close: equal-length 1-D float arrays, in the order
            the bars should be analysed.
        filter_harami: if true, CDLHARAMI hits must also pass is_manual_harami.
        filter_engulfing: if true, CDLENGULFING hits must also pass
            is_manual_engulfing.

    Returns:
        ``(detected, signals)``: two lists of length ``len(open_)``.
        ``detected[i]`` is the chosen recognizer name (e.g. ``"CDLDOJI"``) or
        None. ``signals[i]`` is ``"bullish"`` for a positive recognizer
        output, ``"bearish"`` for a negative one, or None when nothing fired.
    """
    n = len(open_)
    all_hits = [[] for _ in range(n)]
    for name, func in CDL_FUNCS.items():
        try:
            res = func(open_, high, low, close)
        except Exception as exc:
            # Best effort: one failing recognizer must not fail the whole
            # request, but leave a trace rather than swallowing it silently.
            app.logger.warning("Skipping %s: %s", name, exc)
            continue
        for pos in np.flatnonzero(res):
            i = int(pos)
            if (name == "CDLHARAMI" and filter_harami
                    and not is_manual_harami(open_, close, i, tol=HARAMI_TOL)):
                continue
            if (name == "CDLENGULFING" and filter_engulfing
                    and not is_manual_engulfing(open_, close, i, tol=ENGULF_TOL)):
                continue
            all_hits[i].append((name, res[i]))

    # These named patterns win, in this order. Otherwise the largest |output|
    # wins, with ties going to whichever recognizer CDL_FUNCS yielded first.
    priority = ["CDLENGULFING", "CDLHARAMI", "CDLDOJI"]
    detected = [None] * n
    signals = [None] * n
    for i, hits in enumerate(all_hits):
        if not hits:
            continue
        pick = next((hit for pat in priority for hit in hits if hit[0] == pat), None)
        if pick is None:
            pick = max(hits, key=lambda hit: abs(hit[1]))
        name, val = pick
        detected[i] = name
        signals[i] = "bullish" if val > 0 else "bearish"
    return detected, signals


def _marker_addplots(df):
    """Build scatter overlays: a green ^ under bullish bars and a red v over bearish ones.

    ``df`` needs High/Low/Pattern/Signal columns. Bars whose Pattern is not
    "None" get a marker; Signal "bullish" selects the buy arrow, and anything
    else selects the sell arrow. At most two overlays are returned, and an
    overlay is omitted when it would carry no markers.
    """
    pad = (df["High"].max() - df["Low"].min()) * 0.005
    flagged = df["Pattern"] != "None"
    bullish = flagged & (df["Signal"] == "bullish")
    bearish = flagged & ~bullish
    adds = []
    if bullish.any():
        adds.append(mpf.make_addplot(
            (df["Low"] - pad).where(bullish).to_numpy(),
            type="scatter", marker="^", markersize=80, color="green",
        ))
    if bearish.any():
        adds.append(mpf.make_addplot(
            (df["High"] + pad).where(bearish).to_numpy(),
            type="scatter", marker="v", markersize=80, color="red",
        ))
    return adds


def _render_chart(df, symbol):
    """Plot ``df`` as candles with markers and a legend, then save a PNG.

    The PNG is written next to this module as ``pattern_chart_<8 hex>.png``.
    Its file name is returned. The figure is always closed, even if saving
    fails.
    """
    mc = make_marketcolors(up='green', down='red', edge='inherit', wick='inherit')
    style = make_mpf_style(marketcolors=mc, base_mpf_style='default')
    fig, axes = mpf.plot(
        df,
        type="candle",
        style=style,
        title=f"{symbol} Patterns",
        addplot=_marker_addplots(df),
        volume=False,
        returnfig=True,
        tight_layout=True,
    )
    try:
        legend_elements = [
            Line2D([0], [0], color='green', lw=4, label='Buy Candle'),
            Line2D([0], [0], color='red',   lw=4, label='Sell Candle'),
            Line2D([0], [0], marker='^', linestyle='None', color='green', markersize=12, label='Buy Signal'),
            Line2D([0], [0], marker='v', linestyle='None', color='red',   markersize=12, label='Sell Signal'),
        ]
        axes[0].legend(handles=legend_elements, loc='upper left', frameon=True)

        # Rescale to 7.5 in wide while keeping the aspect ratio.
        ow, oh = fig.get_size_inches()
        fig.set_size_inches(7.5, 7.5 * (oh / ow))

        fname = f"pattern_chart_{uuid.uuid4().hex[:8]}.png"
        fig.savefig(os.path.join(os.path.dirname(__file__), fname), dpi=100)
    finally:
        plt.close(fig)
    return fname


@app.route('/patterns', methods=['POST'])
def patterns():
    """Detect candlestick patterns in a posted OHLC series and render a chart.

    Expects a JSON object body with the following fields:
    - ``open``/``high``/``low``/``close``: price lists.
    - ``time``: Unix timestamps in seconds, the same length as the price lists.
    - ``symbol`` (optional): chart title, default "Instrument".
    - ``filter_harami`` / ``filter_engulfing`` (optional): default true.

    On success, responds with per-bar ``patterns`` and ``signals`` lists in
    the request's bar order, an empty ``log`` list and the saved ``chart``
    file name. Otherwise responds with a JSON ``error``: status 400 for bad
    input, or 500 if chart rendering fails.
    """
    app.logger.info("Received /patterns request")
    try:
        data = _decode_json_body(request.data)
    except Exception as e:
        return jsonify(error="Invalid JSON", details=str(e)), 400
    if not isinstance(data, dict):
        return jsonify(error="Invalid JSON", details="expected a JSON object"), 400

    symbol = data.get('symbol', 'Instrument')
    # NOTE: plain bool() coercion, so a JSON string such as "false" is truthy.
    fh = bool(data.get('filter_harami', True))
    fe = bool(data.get('filter_engulfing', True))

    # Every series is reversed for analysis and reversed back in the response
    # (presumably the client sends newest-first; TODO confirm).
    try:
        ts     = data.get('time', [])
        open_  = np.array(data['open'][::-1],  dtype=float)
        high   = np.array(data['high'][::-1],  dtype=float)
        low    = np.array(data['low'][::-1],   dtype=float)
        close  = np.array(data['close'][::-1], dtype=float)
        idx    = pd.to_datetime(np.array(ts[::-1], dtype='int64'), unit='s')
        app.logger.info(f"Loaded {len(open_)} bars for {symbol}")
    except KeyError as ke:
        return jsonify(error=f"Missing field {ke}"), 400
    except Exception as e:
        return jsonify(error="Bad field format", details=str(e)), 400

    # Mismatched lengths (e.g. no "time" list) or an empty series would
    # otherwise blow up later in the DataFrame/plot code as an unhandled 500.
    prices = (open_, high, low, close)
    if any(a.ndim != 1 for a in prices) or len({a.size for a in prices} | {len(idx)}) != 1:
        return jsonify(
            error="Bad field format",
            details="open, high, low, close and time must be flat lists of equal length",
        ), 400
    if open_.size == 0:
        return jsonify(error="Bad field format", details="no bars supplied"), 400

    detected, signals = _detect_patterns(open_, high, low, close, fh, fe)

    df = pd.DataFrame({
        "Open":  open_,
        "High":  high,
        "Low":   low,
        "Close": close
    }, index=idx)
    # Plain lists assign positionally, so duplicate timestamps cannot break
    # index alignment.
    df["Pattern"] = [p or "None" for p in detected]
    df["Signal"]  = [s or "" for s in signals]

    # ensure oldest→newest left-to-right
    df.sort_index(inplace=True)

    try:
        fname = _render_chart(df, symbol)
    except Exception as e:
        app.logger.exception("Chart rendering failed")
        return jsonify(error="Chart rendering failed", details=str(e)), 500
    app.logger.info(f"Chart saved: {fname}")

    return jsonify(
        patterns=[p or "None" for p in detected[::-1]],
        signals=[s or "none" for s in signals[::-1]],
        log=[],
        chart=fname
    )

@app.route('/chart/<filename>')
def get_chart(filename):
    """Serve a chart PNG previously written by ``/patterns``.

    Charts live in this module's own directory, next to the source file, so
    only names shaped like the ones ``/patterns`` generates are served:
    ``pattern_chart_<lowercase hex>.png``. Any other name, such as this
    module's own file, gets the same 404 as a missing chart instead of being
    sent back.
    """
    prefix, suffix = "pattern_chart_", ".png"
    if filename.startswith(prefix) and filename.endswith(suffix):
        token = filename[len(prefix):-len(suffix)]
    else:
        token = ""
    # A hex-only token also rules out path separators and "..".
    if token and all(ch in "0123456789abcdef" for ch in token):
        path = os.path.join(os.path.dirname(__file__), filename)
        if os.path.isfile(path):
            return send_file(path, mimetype="image/png")
    return jsonify(error="Chart not found"), 404

if __name__ == "__main__":
    # Development entry point: Flask's built-in server, loopback interface only.
    host, port = "127.0.0.1", 5000
    app.logger.info(f"Starting server on http://{host}:{port}")
    app.run(host=host, port=port)
