#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI WebSocket Server for MetaTrader 5
=====================================
Model:  stepfun/step-3.5-flash  ← PAID, NO LIMITS
Run:    python deepseek_server.py
Port:   8989
"""

import socket
import json
import struct
import base64
import hashlib
import threading
import sys
import os
from datetime import datetime

try:
    import requests as req_lib
except ImportError:
    os.system(f"{sys.executable} -m pip install requests --quiet")
    import requests as req_lib

try:
    import numpy as np
except ImportError:
    os.system(f"{sys.executable} -m pip install numpy --quiet")
    import numpy as np

# ─── SETTINGS ────────────────────────────────────────────────────────────────
# The API key is taken from the OPENROUTER_API_KEY environment variable so it
# does not need to be hard-coded in source; the "" fallback keeps the previous
# behaviour when the variable is not set.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_URL     = "https://openrouter.ai/api/v1/chat/completions"
MODEL              = "stepfun/step-3.5-flash"   # ★ PAID — no limits

HOST = "127.0.0.1"
PORT = 8989

# Conversation context shared by every client thread; always access it while
# holding history_lock.
chat_history: list[dict] = []
history_lock = threading.Lock()
MAX_HISTORY = 20  # number of most recent messages sent as context per request

# Request headers for OpenRouter (built once at import time from the key above).
HEADERS = {
    "Authorization": f"Bearer {OPENROUTER_API_KEY}",
    "Content-Type": "application/json",
    "HTTP-Referer": "https://mt5-ai-advisor.local",
    "X-Title": "MT5 Step35Flash Advisor",
}

# ─── Helper functions ─────────────────────────────────────────────────────────
def log(msg: str):
    """Print *msg* to stdout prefixed with the local wall-clock time [HH:MM:SS]."""
    stamp = f"{datetime.now():%H:%M:%S}"
    # flush so lines appear immediately even when stdout is piped
    print(f"[{stamp}] {msg}", flush=True)


# ═══════════════════════════════════════════════════════════════════════════════
#                         TECHNICAL INDICATORS
# ═══════════════════════════════════════════════════════════════════════════════

def calc_ma(arr: np.ndarray, period: int) -> float:
    """
    Simple Moving Average of the final *period* values.

    When fewer than *period* values are available, the most recent value is
    returned as-is instead of averaging a partial window.
    """
    has_full_window = len(arr) >= period
    return float(np.mean(arr[-period:])) if has_full_window else float(arr[-1])


def calc_ema(arr: np.ndarray, period: int) -> float:
    """
    Exponential Moving Average of *arr*, returning only the final value.

    The average is seeded with the first element and then updated with the
    smoothing factor 2 / (period + 1) for each following element. With fewer
    than two values the last value is returned.
    """
    if len(arr) >= 2:
        weight = 2.0 / (period + 1)
        keep = 1 - weight
        value = float(arr[0])
        for idx in range(1, len(arr)):
            value = float(arr[idx]) * weight + value * keep
        return value
    return float(arr[-1])


def calc_rsi(arr: np.ndarray, period: int = 14) -> float:
    """
    RSI — Relative Strength Index over the last *period* price changes.

    Gains and losses are averaged with a simple mean over the window (not
    Wilder's smoothing).

    Returns:
        A plain float in [0, 100], rounded to 2 decimals:
        50.0 when there are fewer than period + 1 prices, or when the window
        contains no price change at all (flat market is neutral);
        100.0 when the window has only gains; 0.0 when it has only losses.
    """
    if len(arr) < period + 1:
        return 50.0
    deltas   = np.diff(arr)
    gains    = np.where(deltas > 0, deltas,  0.0)
    losses   = np.where(deltas < 0, -deltas, 0.0)
    avg_gain = float(np.mean(gains[-period:]))
    avg_loss = float(np.mean(losses[-period:]))
    if avg_loss == 0:
        # No losses: overbought only if there were gains; a flat window is neutral.
        return 100.0 if avg_gain > 0 else 50.0
    rs = avg_gain / avg_loss
    return round(100.0 - 100.0 / (1.0 + rs), 2)


def calc_stoch(high: np.ndarray, low: np.ndarray, close: np.ndarray,
               k_period: int = 14, d_period: int = 3):
    """
    Stochastic oscillator: latest %K and %D (mean of the last *d_period* %K).

    %K at bar j = 100 * (close[j] - lowest low) / (highest high - lowest low)
    over the *k_period* bars ending at j. A bar without a full look-back
    window, or whose range is zero, contributes 50.0. Returns (50.0, 50.0)
    when fewer than *k_period* bars exist. Both values are rounded to 2 decimals.
    """
    total = len(close)
    if total < k_period:
        return 50.0, 50.0

    def percent_k(end: int):
        # Not enough history behind this bar for a full window.
        if end < k_period - 1:
            return 50.0
        first = end - k_period + 1
        highest = np.max(high[first: end + 1])
        lowest = np.min(low[first: end + 1])
        spread = highest - lowest
        if spread > 0:
            return 100.0 * (close[end] - lowest) / spread
        return 50.0

    # Newest bar first, walking back d_period bars.
    series = [percent_k(total - 1 - shift) for shift in range(d_period)]
    return round(series[0], 2), round(float(np.mean(series)), 2)


def calc_atr(high: np.ndarray, low: np.ndarray, close: np.ndarray,
             period: int = 14) -> float:
    """
    Average True Range: simple mean of the last *period* true ranges.

    The true range of bar j is max(high - low, |high - prev close|,
    |low - prev close|). Returns 0.0 when fewer than two bars are given;
    otherwise the result is rounded to 6 decimals.
    """
    bars = len(close)
    if bars >= 2:
        true_ranges = [
            max(high[j] - low[j],
                abs(high[j] - close[j - 1]),
                abs(low[j]  - close[j - 1]))
            for j in range(1, bars)
        ]
        recent = np.array(true_ranges)[-period:]
        return round(float(np.mean(recent)), 6)
    return 0.0


def calc_stddev(arr: np.ndarray, period: int = 20) -> float:
    """
    Population standard deviation of the last *period* values, rounded to 6 decimals.

    When fewer than *period* values exist, all of them are used: numpy
    clamps the slice start, so arr[-period:] is then simply the whole array.
    """
    window = arr[-period:]
    return round(float(np.std(window)), 6)


def calc_bb(arr: np.ndarray, period: int = 20, mult: float = 2.0):
    """
    Bollinger Bands as (upper, middle, lower), each rounded to 5 decimals.

    middle = calc_ma(arr, period); the bands sit mult * calc_stddev(arr, period)
    above and below it. With fewer than *period* values the middle band is the
    last value (calc_ma's fallback) while the width uses all available values.
    """
    middle = calc_ma(arr, period)
    offset = mult * calc_stddev(arr, period)
    upper, lower = middle + offset, middle - offset
    return round(upper, 5), round(middle, 5), round(lower, 5)


def build_indicators(close: np.ndarray,
                     high:  np.ndarray = None,
                     low:   np.ndarray = None,
                     vol:   np.ndarray = None,
                     open_: np.ndarray = None) -> dict:
    """
    Compute the full set of technical indicators as prompt-ready strings.

    Args:
        close: close prices, oldest first (the last element is the latest bar);
            must be non-empty.
        high / low: optional highs / lows; ``close`` is used when omitted.
        vol: optional volumes; volume fields report 0 when omitted.
        open_: optional opens; when omitted, the previous close stands in for
            the last candle's open (or the last close for a single bar).

    Returns:
        dict with ``n`` (bar count), ``last`` (last close, 5 decimals) and
        formatted strings for candle, volume, MA/EMA, RSI, stochastic, ATR,
        standard deviation, Bollinger Bands and momentum.

    A moving-average trend is reported as "N/A" when there are fewer bars than
    its period: calc_ma then falls back to the last price, and comparing the
    price with itself would always produce "DOWN".
    """
    if high is None: high = close.copy()
    if low  is None: low  = close.copy()

    n    = len(close)
    last = float(close[-1])

    def trend(ma: float, period: int) -> str:
        # Without a full window the MA equals the last price — no real trend.
        if n < period:
            return "N/A"
        return "UP" if last > ma else "DOWN"

    # ── Moving Averages ───────────────────────────────────────────────────────
    ma5   = calc_ma(close,  5)
    ma10  = calc_ma(close, 10)
    ma20  = calc_ma(close, 20)
    ma50  = calc_ma(close, 50)
    ma100 = calc_ma(close, 100)
    ma200 = calc_ma(close, 200)
    ema9  = calc_ema(close,  9)
    ema21 = calc_ema(close, 21)
    ema55 = calc_ema(close, 55)

    trend_ma20  = trend(ma20,  20)
    trend_ma50  = trend(ma50,  50)
    trend_ma200 = trend(ma200, 200)

    # ── RSI across three look-back periods ────────────────────────────────────
    rsi7  = calc_rsi(close,  7)
    rsi14 = calc_rsi(close, 14)
    rsi21 = calc_rsi(close, 21)

    # ── Stochastic ────────────────────────────────────────────────────────────
    stoch_k, stoch_d = calc_stoch(high, low, close, 14, 3)

    # ── Volatility ────────────────────────────────────────────────────────────
    atr14    = calc_atr(high, low, close, 14)
    atr21    = calc_atr(high, low, close, 21)
    stddev10 = calc_stddev(close, 10)
    stddev20 = calc_stddev(close, 20)

    # ── Bollinger Bands ───────────────────────────────────────────────────────
    bb_up, bb_mid, bb_lo = calc_bb(close, 20, 2.0)
    bb_pos = "UPPER" if last > bb_up else ("LOWER" if last < bb_lo else "MID")

    # ── Momentum (price change over 5 / 10 / 20 bars) ─────────────────────────
    mom5  = round(last - float(close[-6]),  5) if n >= 6  else 0.0
    mom10 = round(last - float(close[-11]), 5) if n >= 11 else 0.0
    mom20 = round(last - float(close[-21]), 5) if n >= 21 else 0.0

    # ── Last candle ───────────────────────────────────────────────────────────
    if open_ is not None and len(open_) > 0:
        last_open = round(float(open_[-1]), 5)
    elif n >= 2:
        last_open = round(float(close[-2]), 5)   # no opens given: previous close as proxy
    else:
        last_open = last
    last_high  = round(float(high[-1]),  5)
    last_low   = round(float(low[-1]),   5)
    last_close = round(last, 5)
    candle_body = round(abs(last_close - last_open), 5)
    candle_dir  = "BULL" if last_close >= last_open else "BEAR"

    # ── Volume ────────────────────────────────────────────────────────────────
    vol_last  = int(vol[-1])              if vol is not None and len(vol) > 0  else 0
    vol_avg   = int(np.mean(vol[-20:]))   if vol is not None and len(vol) >= 20 else 0
    vol_ratio = round(vol_last / vol_avg, 2) if vol_avg > 0 else 1.0

    return {
        "n":        n,
        "last":     last_close,
        "candle":   f"{candle_dir} O={last_open} H={last_high} L={last_low} C={last_close} Body={candle_body}",
        "volume":   f"Last={vol_last} Avg20={vol_avg} Ratio={vol_ratio}x",
        "ma":       (f"MA5={ma5:.5f} MA10={ma10:.5f} "
                     f"MA20={ma20:.5f}({trend_ma20}) "
                     f"MA50={ma50:.5f}({trend_ma50}) "
                     f"MA100={ma100:.5f} MA200={ma200:.5f}({trend_ma200})"),
        "ema":      f"EMA9={ema9:.5f} EMA21={ema21:.5f} EMA55={ema55:.5f}",
        "rsi":      f"RSI7={rsi7} RSI14={rsi14} RSI21={rsi21}",
        "stoch":    f"K={stoch_k} D={stoch_d}",
        "atr":      f"ATR14={atr14} ATR21={atr21}",
        "stddev":   f"StdDev10={stddev10} StdDev20={stddev20}",
        "bb":       f"BB_up={bb_up} BB_mid={bb_mid} BB_lo={bb_lo} Pos={bb_pos}",
        "momentum": f"Mom5={mom5:+.5f} Mom10={mom10:+.5f} Mom20={mom20:+.5f}",
    }


# ─── API call ─────────────────────────────────────────────────────────────────
def ask_deepseek(user_message: str, system_prompt: str = None) -> str:
    """
    Send *user_message* to the configured OpenRouter model and return its reply.

    The request contains *system_prompt* (a default trading-analyst prompt
    when None), the last MAX_HISTORY messages of the shared ``chat_history``,
    and the new user message. On success the user/assistant pair is appended
    to ``chat_history``. The history is then trimmed to its last MAX_HISTORY
    entries, so memory stays bounded on a long-running server.

    Returns:
        The stripped answer text, or a string starting with "ERROR" that
        describes the failure. Exceptions are logged and converted to such a
        string; they are never raised to the caller.
    """
    if system_prompt is None:
        system_prompt = (
            "You are a professional trader and financial market analyst. "
            "Analyze price series data and provide clear trading signals. "
            "Be concise and to the point. Reply in English."
        )

    messages = [{"role": "system", "content": system_prompt}]
    with history_lock:
        messages += chat_history[-MAX_HISTORY:]
    messages.append({"role": "user", "content": user_message})

    payload = {
        "model": MODEL,
        "messages": messages,
        "max_tokens": 1024,
        "temperature": 0.3,
    }

    try:
        resp = req_lib.post(
            OPENROUTER_URL,
            headers=HEADERS,
            data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
            timeout=60
        )

        if resp.status_code != 200:
            log(f"HTTP {resp.status_code}: {resp.text[:300]}")
            return f"ERROR HTTP {resp.status_code}: {resp.text[:200]}"

        data = resp.json()

        choices = data.get("choices") if isinstance(data, dict) else None
        if not choices:
            # Missing or empty "choices": report the API's error message if any.
            log(f"No choices in response: {json.dumps(data, ensure_ascii=False)[:300]}")
            error = data.get("error") if isinstance(data, dict) else None
            if isinstance(error, dict):
                detail = error.get("message", str(data))
            else:
                detail = str(error) if error else str(data)
            return f"ERROR: {str(detail)[:200]}"

        message = choices[0].get("message") or {}
        content = message.get("content")
        # content may be null/empty (NOTE(review): presumably when the model
        # spends its token budget without producing an answer — confirm).
        if not isinstance(content, str) or not content.strip():
            log(f"Empty content in response: {json.dumps(data, ensure_ascii=False)[:300]}")
            return "ERROR: empty response from model"

        answer = content.strip()

        with history_lock:
            chat_history.append({"role": "user",      "content": user_message})
            chat_history.append({"role": "assistant", "content": answer})
            # Only the tail is ever sent back; drop older entries so the list
            # does not grow by two items on every call forever.
            del chat_history[:-MAX_HISTORY]

        return answer

    except Exception as e:
        # Top-level boundary for network / JSON / unexpected-shape failures.
        log(f"API error: {e}")
        return f"ERROR: {e}"


def analyze_prices(prices: list[float], symbol: str = "UNKNOWN") -> dict:
    """
    Full technical analysis via AI, from close prices only (PRICES: command).

    Args:
        prices: close prices, oldest first.
        symbol: instrument name shown to the model.

    Returns:
        {"signal": "buy" | "sell" | "hold", "comment": str}. "hold" is returned
        when there are fewer than 10 prices, when any price is NaN/Inf, when
        the API call fails (reply starting with "ERROR"), or when the reply
        cannot be interpreted unambiguously.
    """
    import re  # local: only needed for the keyword fallback below

    n = len(prices)
    if n < 10:
        return {"signal": "hold", "comment": "Insufficient data"}

    close = np.array(prices, dtype=float)
    # float() accepts "nan"/"inf", which would poison every indicator.
    if not np.all(np.isfinite(close)):
        return {"signal": "hold", "comment": "Invalid data (NaN/Inf in prices)"}
    ind   = build_indicators(close)   # high/low/vol = None → use close

    prompt = (
        f"Symbol: {symbol} | Bars: {ind['n']}\n\n"
        f"Last candle: {ind['candle']}\n"
        f"Volume: {ind['volume']}\n\n"
        f"Moving averages:\n  {ind['ma']}\n  {ind['ema']}\n\n"
        f"Oscillators:\n  {ind['rsi']}\n  Stoch {ind['stoch']}\n\n"
        f"Volatility:\n  {ind['atr']}\n  {ind['stddev']}\n\n"
        f"Bollinger Bands:  {ind['bb']}\n"
        f"Momentum:  {ind['momentum']}\n\n"
        "Provide a trading signal. Always try to give a specific signal — buy or sell. "
        "Reply ONLY with JSON, no markdown:\n"
        '{"signal":"buy"|"sell"|"hold","comment":"analysis up to 150 characters"}'
    )

    raw = ask_deepseek(prompt)
    log(f"RAW [{symbol}]: {raw[:150]}")

    # An API failure comes back as "ERROR..." text — never derive a trade from it.
    if raw.startswith("ERROR"):
        return {"signal": "hold", "comment": raw[:200]}

    # ── Attempt 1: parse JSON ─────────────────────────────────────────────────
    start = raw.find("{")
    end   = raw.rfind("}") + 1
    if start != -1 and end > start:
        try:
            data = json.loads(raw[start:end])
        except ValueError:   # json.JSONDecodeError is a ValueError
            data = None
        if isinstance(data, dict):
            signal = str(data.get("signal", "hold")).strip().lower()
            if signal not in ("buy", "sell", "hold"):
                signal = "hold"
            # Force str: the caller slices the comment and sends it as JSON text.
            comment = str(data.get("comment", raw[:150]))
            return {"signal": signal, "comment": comment}

    # ── Attempt 2: keyword search (fallback) ──────────────────────────────────
    # Only trust an unambiguous answer: exactly one of the whole words
    # "buy" / "sell" must appear; anything else is "hold".
    lower    = raw.lower()
    has_buy  = re.search(r"\bbuy\b", lower) is not None
    has_sell = re.search(r"\bsell\b", lower) is not None
    if has_buy and not has_sell:
        signal = "buy"
    elif has_sell and not has_buy:
        signal = "sell"
    else:
        signal = "hold"

    return {"signal": signal, "comment": raw[:200]}


# ─── WebSocket helpers ────────────────────────────────────────────────────────
WS_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def ws_handshake_response(http_request: str, host: str = None, port: int = None) -> str:
    """
    Build the HTTP 101 response that completes a WebSocket opening handshake.

    Args:
        http_request: the client's HTTP upgrade request (CRLF-separated lines).
        host, port: address reported in the WebSocket-Location header;
            default to the module-level HOST / PORT.

    Returns:
        The response header block terminated by an empty line, ready to send.

    The Sec-WebSocket-Key header is located case-insensitively (HTTP field
    names are case-insensitive), and its value may follow the colon with or
    without whitespace. If the header is absent, an empty key is used.
    """
    if host is None:
        host = HOST
    if port is None:
        port = PORT

    key = ""
    for line in http_request.split("\r\n"):
        name, sep, value = line.partition(":")
        if sep and name.strip().lower() == "sec-websocket-key":
            key = value.strip()
            break
    # Sec-WebSocket-Accept = base64(SHA-1(key + protocol GUID))
    accept = base64.b64encode(
        hashlib.sha1((key + WS_GUID).encode("utf-8")).digest()
    ).decode("utf-8")
    return (
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        f"Sec-WebSocket-Accept: {accept}\r\n"
        f"WebSocket-Location: ws://{host}:{port}/\r\n\r\n"
    )

def ws_decode(data: bytes) -> str:
    """
    Decode the payload of the first WebSocket frame in *data* as UTF-8 text.

    Only the payload length declared in the frame header is decoded, so bytes
    of a following frame never leak into the result. If the frame is still
    incomplete, the payload bytes received so far are decoded. Control frames
    (opcode >= 0x8: close/ping/pong) and headers too short to parse yield "".
    Invalid UTF-8 is replaced rather than raising.
    """
    if len(data) < 2:
        return ""
    opcode = data[0] & 0x0F
    if opcode >= 0x8:
        return ""   # control frame: its payload is not a text command
    is_masked = bool(data[1] & 0x80)
    length    = data[1] & 0x7F

    offset = 2
    if length == 126:           # 16-bit extended payload length
        if len(data) < 4:
            return ""
        (length,) = struct.unpack(">H", data[2:4])
        offset = 4
    elif length == 127:         # 64-bit extended payload length
        if len(data) < 10:
            return ""
        (length,) = struct.unpack(">Q", data[2:10])
        offset = 10

    if is_masked:
        mask = data[offset:offset + 4]
        if len(mask) < 4:
            return ""
        offset += 4
        raw = data[offset:offset + length]
        payload = bytes(b ^ mask[i % 4] for i, b in enumerate(raw))
    else:
        payload = data[offset:offset + length]
    return payload.decode("utf-8", errors="replace")

def ws_encode(message: str) -> bytes:
    """
    Wrap *message* in a single unmasked WebSocket text frame (FIN set, opcode 0x1).

    The payload length uses the 7-bit form up to 125 bytes, the 16-bit
    extended form (marker 126) up to 65535 bytes, and the 64-bit form
    (marker 127) beyond that.
    """
    body = message.encode("utf-8")
    size = len(body)
    if size < 126:
        prefix = struct.pack(">BB", 0x81, size)
    elif size < 0x10000:
        prefix = struct.pack(">BBH", 0x81, 126, size)
    else:
        prefix = struct.pack(">BBQ", 0x81, 127, size)
    return prefix + body


# ─── Client handler ───────────────────────────────────────────────────────────
def _ws_frame_length(data: bytes):
    """
    Return the total byte length of the first WebSocket frame in *data*.

    Returns None while the header, optional 4-byte mask key and payload have
    not all been received yet, so the caller can wait for more bytes.
    """
    if len(data) < 2:
        return None
    payload_len = data[1] & 0x7F
    header_len = 2
    if payload_len == 126:
        if len(data) < 4:
            return None
        (payload_len,) = struct.unpack(">H", data[2:4])
        header_len = 4
    elif payload_len == 127:
        if len(data) < 10:
            return None
        (payload_len,) = struct.unpack(">Q", data[2:10])
        header_len = 10
    if data[1] & 0x80:          # masked frames carry a 4-byte masking key
        header_len += 4
    total = header_len + payload_len
    return total if len(data) >= total else None


def _handle_command(conn: socket.socket, message: str) -> bool:
    """
    Process one decoded text message (STOP / CHAT: / CLEAR / PRICES:) and reply.

    Returns False when the client sent STOP and the connection should close,
    True otherwise.
    """
    if not message:
        return True

    message = message.strip()
    log(f"Received: {message[:100]}")

    if message.lower() == "stop":
        log("Expert Advisor sent STOP. Closing connection.")
        return False

    if message.upper().startswith("CHAT:"):
        question = message[5:].strip()
        log(f"Chat: {question}")
        answer = ask_deepseek(question)
        resp   = json.dumps({"signal": "chat", "comment": answer}, ensure_ascii=False)
        conn.sendall(ws_encode(resp))
        return True

    if message.upper() == "CLEAR":
        with history_lock:
            chat_history.clear()
        conn.sendall(ws_encode(json.dumps(
            {"signal": "info", "comment": "History cleared"}, ensure_ascii=False)))
        return True

    if message.upper().startswith("PRICES:"):
        # Accepted forms: "PRICES:<symbol>:<csv>" or "PRICES:<csv>".
        parts  = message[7:].split(":", 1)
        symbol = parts[0] if len(parts) > 1 else "SYM"
        csv    = parts[1] if len(parts) > 1 else parts[0]
        try:
            prices = [float(x) for x in csv.split(",") if x.strip()]
        except ValueError:
            prices = []

        result = (analyze_prices(prices, symbol)
                  if prices else
                  {"signal": "hold", "comment": "Parse error"})
        # str(): never let a non-string comment crash the connection here.
        comment = str(result.get("comment", ""))
        log(f"Signal [{symbol}]: {result['signal']} | {comment[:80]}")
        conn.sendall(ws_encode(json.dumps(result, ensure_ascii=False)))
        return True

    conn.sendall(ws_encode(json.dumps(
        {"signal": "info", "comment": f"Unknown command: {message[:50]}"}, ensure_ascii=False)))
    return True


def handle_client(conn: socket.socket, addr):
    """
    Serve one TCP client: do the WebSocket handshake, then process frames until close.

    Received bytes are accumulated, and only complete frames are decoded. This
    handles a long message split across several recv() calls (e.g. a big
    PRICES: list) as well as several messages arriving in one recv().
    Fragmented messages (continuation frames) are not reassembled, and
    ping/pong frames are ignored. The socket is always closed on exit.
    """
    log(f"Connected: {addr}")
    is_websocket = False
    buffer = b""
    keep_open = True

    try:
        conn.settimeout(120.0)

        while keep_open:
            try:
                chunk = conn.recv(8192)
            except socket.timeout:
                log("Connection timeout")
                break
            if not chunk:
                break

            buffer += chunk

            if not is_websocket:
                head, sep, rest = buffer.partition(b"\r\n\r\n")
                if not sep:
                    continue   # request headers not complete yet
                request = head.decode("utf-8", errors="replace") + "\r\n\r\n"
                conn.sendall(ws_handshake_response(request).encode("utf-8"))
                is_websocket = True
                log("WebSocket handshake complete")
                buffer = rest   # keep any frame bytes that followed the request

            # Drain every complete frame currently buffered.
            while keep_open:
                frame_len = _ws_frame_length(buffer)
                if frame_len is None:
                    break      # wait for the rest of the frame
                frame, buffer = buffer[:frame_len], buffer[frame_len:]
                opcode = frame[0] & 0x0F
                if opcode == 0x8:
                    log("Client sent close frame")
                    try:
                        conn.sendall(b"\x88\x00")   # echo an empty close frame
                    except OSError:
                        pass   # best effort: the peer may already be gone
                    keep_open = False
                elif opcode < 0x8:
                    keep_open = _handle_command(conn, ws_decode(frame))
                # other control frames (ping/pong) are ignored

    except Exception as e:
        log(f"Error: {e}")
    finally:
        conn.close()
        log(f"Disconnected: {addr}")


# ─── Main loop ────────────────────────────────────────────────────────────────
def main():
    """
    Run the server: listen on HOST:PORT and handle each client on a daemon thread.

    Runs until interrupted with Ctrl+C. The listening socket is managed by a
    ``with`` block, so it is closed even when bind()/listen() fail.
    """
    print("=" * 60)
    print("  AI WebSocket Server for MetaTrader 5")
    print(f"  Address: ws://{HOST}:{PORT}")
    print(f"  Model:   {MODEL}")
    print(f"  API:     {OPENROUTER_URL}")
    print("=" * 60)

    if not OPENROUTER_API_KEY:
        # HEADERS then carries an empty bearer token.
        log("WARNING: OPENROUTER_API_KEY is empty — API requests will likely fail.")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
        server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_sock.bind((HOST, PORT))
        server_sock.listen(5)
        log("Server started. Waiting for MT5 connection...")

        try:
            while True:
                conn, addr = server_sock.accept()
                # Daemon threads let Ctrl+C end the process without joining clients.
                t = threading.Thread(target=handle_client, args=(conn, addr), daemon=True)
                t.start()
        except KeyboardInterrupt:
            log("Server stopped.")


if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    main()
