#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
LLM-агент с SQL-памятью для MetaTrader 5
==========================================
Model   : grok-3-mini  (xAI API — прямой REST, без SDK)
Port    : 8976
DB      : agent_memory.db  (SQLite, создаётся автоматически)

╔══════════════════════════════════════════════════════════════╗
║         LLM MEMORY AGENT  v1.0  [xAI / Grok]                ║
║  Агент помнит свои решения и их исходы между сессиями.       ║
║                                                              ║
║  МЕХАНИКА:                                                   ║
║   1. Каждое решение пишется в SQLite с рыночным режимом      ║
║   2. При следующем запросе база отдаёт релевантный контекст  ║
║      — похожие условия RSI + последние 3 решения             ║
║   3. Агент читает свою историю перед принятием решения       ║
║   4. После закрытия позиции MT5 отправляет RESULT:           ║
║      база обновляет исход и пересчитывает статистику         ║
║                                                              ║
║  Команды МТ5:                                                ║
║   ANALYZE:SYMBOL:close_csv         → сигнал + decision_id    ║
║   RESULT:SYMBOL:decision_id:pnl    → обновить исход          ║
║   MEM_STATUS:SYMBOL                → статистика агента       ║
║   STOP                             → отключиться             ║
╚══════════════════════════════════════════════════════════════╝
"""

import socket
import json
import struct
import base64
import hashlib
import threading
import sqlite3
import sys
import os
import time
from datetime import datetime
from typing import List, Optional

try:
    import requests as req_lib
except ImportError:
    os.system(f"{sys.executable} -m pip install requests --quiet")
    import requests as req_lib

try:
    import numpy as np
except ImportError:
    os.system(f"{sys.executable} -m pip install numpy --quiet")
    import numpy as np

# ─── НАСТРОЙКИ ───────────────────────────────────────────────────────────────
XAI_API_KEY = "xai-HGyczXQYS38BeK49uRcPic9yfmepswFD2RfTZOjLpOTf94Q9Ig6ZftzMtT4JKXCVUO1sOzKWvtkAh4R9"          # ← вставьте ключ xAI с console.x.ai
XAI_URL     = "https://api.x.ai/v1/chat/completions"
MODEL       = "grok-4-1-fast-reasoning"
HOST        = "127.0.0.1"
PORT        = 8976
MAX_TOKENS  = 1200
DB_PATH     = "agent_memory.db"

HEADERS = {
    "Authorization": f"Bearer {XAI_API_KEY}",
    "Content-Type":  "application/json",
    "HTTP-Referer":  "https://mt5-ai-memory.local",
    "X-Title":       "MT5 Memory Agent v1.0",
}


def log(msg: str):
    ts = datetime.now().strftime("%H:%M:%S")
    print(f"[{ts}] {msg}", flush=True)


# ═══════════════════════════════════════════════════════════════════════════════
# SQL-СЛОЙ ПАМЯТИ
# ═══════════════════════════════════════════════════════════════════════════════
class MemoryLayer:
    """
    Постоянная память агента на базе SQLite.

    Таблица decisions  — все решения агента с исходами.
    Таблица agent_stats — агрегированная статистика по инструменту/ТФ.
    """

    def __init__(self, db_path: str = DB_PATH):
        self._db_path = db_path
        self._lock    = threading.Lock()
        self._conn    = sqlite3.connect(db_path, check_same_thread=False)
        self._init_tables()
        log(f"Memory DB: {db_path}")

    def _init_tables(self):
        self._conn.executescript("""
            CREATE TABLE IF NOT EXISTS decisions (
                id          INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp   TEXT    NOT NULL,
                symbol      TEXT    NOT NULL,
                timeframe   TEXT    NOT NULL,
                signal      TEXT    NOT NULL,
                confidence  REAL    DEFAULT 0.5,
                comment     TEXT    DEFAULT '',
                pnl         REAL,
                outcome     TEXT    DEFAULT 'open',
                rsi         REAL    DEFAULT 50.0,
                ma_align    TEXT    DEFAULT 'mixed',
                volatility  TEXT    DEFAULT 'normal',
                zscore      REAL    DEFAULT 0.0
            );

            CREATE TABLE IF NOT EXISTS agent_stats (
                symbol      TEXT,
                timeframe   TEXT,
                total       INTEGER DEFAULT 0,
                wins        INTEGER DEFAULT 0,
                losses      INTEGER DEFAULT 0,
                avg_pnl     REAL    DEFAULT 0.0,
                best_signal TEXT    DEFAULT 'hold',
                updated_at  TEXT,
                PRIMARY KEY (symbol, timeframe)
            );

            CREATE INDEX IF NOT EXISTS idx_decisions_sym_tf
                ON decisions(symbol, timeframe, outcome);
            CREATE INDEX IF NOT EXISTS idx_decisions_rsi
                ON decisions(rsi);
        """)
        self._conn.commit()

    # ── Запись нового решения ─────────────────────────────────────────────────
    def write_decision(self, symbol: str, timeframe: str,
                       signal: str, confidence: float, comment: str,
                       rsi: float, ma_align: str,
                       volatility: str, zscore: float) -> int:
        """Возвращает decision_id для последующего update_result."""
        with self._lock:
            cur = self._conn.execute(
                """INSERT INTO decisions
                   (timestamp, symbol, timeframe, signal, confidence,
                    comment, rsi, ma_align, volatility, zscore)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (datetime.utcnow().isoformat(), symbol, timeframe,
                 signal, confidence, comment,
                 rsi, ma_align, volatility, zscore)
            )
            self._conn.commit()
            return cur.lastrowid

    # ── Обновление исхода сделки ──────────────────────────────────────────────
    def update_result(self, decision_id: int, pnl: float):
        """Вызывается после получения RESULT от советника."""
        outcome = "win" if pnl > 0 else "loss"
        with self._lock:
            self._conn.execute(
                "UPDATE decisions SET pnl=?, outcome=? WHERE id=?",
                (pnl, outcome, decision_id)
            )
            self._conn.commit()
        # Пересчитываем агрегат в отдельном вызове (без вложенной блокировки)
        self._refresh_stats(decision_id)

    def _refresh_stats(self, decision_id: int):
        with self._lock:
            row = self._conn.execute(
                "SELECT symbol, timeframe FROM decisions WHERE id=?",
                (decision_id,)
            ).fetchone()
        if not row:
            return
        symbol, timeframe = row

        with self._lock:
            stats = self._conn.execute(
                """SELECT COUNT(*),
                          SUM(CASE WHEN outcome='win'  THEN 1 ELSE 0 END),
                          SUM(CASE WHEN outcome='loss' THEN 1 ELSE 0 END),
                          AVG(pnl)
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND outcome != 'open'""",
                (symbol, timeframe)
            ).fetchone()

            total, wins, losses, avg_pnl = stats
            total  = total  or 0
            wins   = wins   or 0
            losses = losses or 0
            avg_pnl = avg_pnl or 0.0

            # Определяем исторически более точный тип сигнала
            buy_wr = self._conn.execute(
                """SELECT AVG(CASE WHEN outcome='win' THEN 1.0 ELSE 0.0 END)
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND signal='buy'
                   AND outcome != 'open'""",
                (symbol, timeframe)
            ).fetchone()[0] or 0.0

            sell_wr = self._conn.execute(
                """SELECT AVG(CASE WHEN outcome='win' THEN 1.0 ELSE 0.0 END)
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND signal='sell'
                   AND outcome != 'open'""",
                (symbol, timeframe)
            ).fetchone()[0] or 0.0

            best_signal = "buy" if buy_wr >= sell_wr else "sell"

            self._conn.execute(
                """INSERT INTO agent_stats
                       (symbol, timeframe, total, wins, losses,
                        avg_pnl, best_signal, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                   ON CONFLICT(symbol, timeframe) DO UPDATE SET
                       total=excluded.total,
                       wins=excluded.wins,
                       losses=excluded.losses,
                       avg_pnl=excluded.avg_pnl,
                       best_signal=excluded.best_signal,
                       updated_at=excluded.updated_at""",
                (symbol, timeframe, total, wins, losses,
                 avg_pnl, best_signal, datetime.utcnow().isoformat())
            )
            self._conn.commit()

    # ── Формирование контекста для промпта ───────────────────────────────────
    def get_context(self, symbol: str, timeframe: str,
                    rsi: float, limit: int = 5) -> str:
        """
        Возвращает текстовый фрагмент для вставки в системный промпт.
        Включает решения при похожем RSI (±8 пунктов) + 3 последних решения.
        """
        with self._lock:
            similar = self._conn.execute(
                """SELECT signal, outcome, pnl, comment, rsi
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND outcome != 'open'
                   AND ABS(rsi - ?) < 8
                   ORDER BY timestamp DESC LIMIT ?""",
                (symbol, timeframe, rsi, limit)
            ).fetchall()

            recent = self._conn.execute(
                """SELECT signal, outcome, pnl, comment
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND outcome != 'open'
                   ORDER BY timestamp DESC LIMIT 3""",
                (symbol, timeframe)
            ).fetchall()

        parts = []

        if similar:
            parts.append(f"Similar RSI conditions (RSI ≈ {rsi:.0f}) in the past:")
            for sig, out, pnl, cmt, r in similar:
                pnl_str = f"{pnl:+.2f}" if pnl is not None else "?"
                parts.append(
                    f"  [{sig.upper():4s}] → {out:4s} ({pnl_str} pts) "
                    f"| RSI≈{r:.0f} | {cmt[:60]}"
                )

        if recent:
            parts.append("Last 3 decisions (any conditions):")
            for sig, out, pnl, cmt in recent:
                pnl_str = f"{pnl:+.2f}" if pnl is not None else "?"
                parts.append(
                    f"  [{sig.upper():4s}] → {out:4s} ({pnl_str} pts) | {cmt[:60]}"
                )

        return "\n".join(parts) if parts else "No historical data yet — acting on current market data only."

    # ── Статистика агента ─────────────────────────────────────────────────────
    def get_stats(self, symbol: str, timeframe: str) -> str:
        """Однострочное описание накопленной статистики."""
        with self._lock:
            row = self._conn.execute(
                "SELECT total, wins, losses, avg_pnl, best_signal "
                "FROM agent_stats WHERE symbol=? AND timeframe=?",
                (symbol, timeframe)
            ).fetchone()

        if not row or row[0] == 0:
            return "No statistics yet — this is a fresh start."

        total, wins, losses, avg_pnl, best_signal = row
        win_rate = 100.0 * wins / total if total else 0.0
        return (
            f"Agent stats on {symbol}/{timeframe}: "
            f"{total} closed trades, win-rate {win_rate:.0f}%, "
            f"avg PnL {avg_pnl:+.2f} pts. "
            f"Historically stronger direction: {best_signal.upper()}."
        )

    # ── Полный статус (для команды MEM_STATUS) ────────────────────────────────
    def full_status(self, symbol: str, timeframe: str) -> dict:
        with self._lock:
            row = self._conn.execute(
                "SELECT total, wins, losses, avg_pnl, best_signal, updated_at "
                "FROM agent_stats WHERE symbol=? AND timeframe=?",
                (symbol, timeframe)
            ).fetchone()
            open_cnt = self._conn.execute(
                "SELECT COUNT(*) FROM decisions "
                "WHERE symbol=? AND timeframe=? AND outcome='open'",
                (symbol, timeframe)
            ).fetchone()[0]

        if not row:
            return {"symbol": symbol, "timeframe": timeframe,
                    "total": 0, "open": open_cnt}

        total, wins, losses, avg_pnl, best_signal, upd = row
        win_rate = round(100.0 * wins / total, 1) if total else 0.0
        return {
            "symbol":      symbol,
            "timeframe":   timeframe,
            "total":       total,
            "wins":        wins,
            "losses":      losses,
            "win_rate":    win_rate,
            "avg_pnl":     round(avg_pnl, 4),
            "best_signal": best_signal,
            "open":        open_cnt,
            "updated_at":  upd,
        }


# ═══════════════════════════════════════════════════════════════════════════════
# ТЕХНИЧЕСКИЕ ИНДИКАТОРЫ  (тот же набор, что в evolution-сервере)
# ═══════════════════════════════════════════════════════════════════════════════
def calc_indicators(prices: List[float]) -> dict:
    arr = np.array(prices, dtype=float)

    def sma(a, p):
        return float(np.mean(a[-p:])) if len(a) >= p else float(a[-1])

    def ema(a, p):
        if len(a) < 2:
            return float(a[-1])
        k = 2.0 / (p + 1)
        e = float(a[0])
        for v in a[1:]:
            e = float(v) * k + e * (1 - k)
        return e

    def rsi(a, p=14):
        if len(a) < p + 1:
            return 50.0
        d  = np.diff(a)
        g  = np.where(d > 0, d, 0.0)
        l  = np.where(d < 0, -d, 0.0)
        ag, al = np.mean(g[-p:]), np.mean(l[-p:])
        return 100.0 if al == 0 else round(100 - 100 / (1 + ag / al), 2)

    def atr_simple(a, p=14):
        if len(a) < 2:
            return 0.0
        tr = np.abs(np.diff(a))
        return round(float(np.mean(tr[-p:])), 6)

    def bb(a, p=20):
        if len(a) < p:
            return float(a[-1]), float(a[-1]) - 0.001, float(a[-1]) + 0.001
        s   = a[-p:]
        mid = float(np.mean(s))
        std = float(np.std(s))
        return mid, round(mid - 2 * std, 6), round(mid + 2 * std, 6)

    def stoch_k(a, p=14):
        if len(a) < p:
            return 50.0
        h, l = np.max(a[-p:]), np.min(a[-p:])
        return round(100 * (a[-1] - l) / (h - l), 2) if h != l else 50.0

    current         = float(arr[-1])
    ma5             = sma(arr, 5)
    ma20            = sma(arr, 20)
    ma50            = sma(arr, 50)
    rsi_val         = rsi(arr)
    atr_val         = atr_simple(arr)
    atr20           = atr_simple(arr, 20)
    bb_mid, bb_lo, bb_hi = bb(arr)
    stoch           = stoch_k(arr)

    # z-score и производные
    if len(arr) >= 20:
        mu     = float(np.mean(arr[-20:]))
        sigma  = float(np.std(arr[-20:]))
        zscore = round((current - mu) / sigma, 2) if sigma > 0 else 0.0
    else:
        zscore = 0.0

    # Режим волатильности
    if atr20 > 0:
        ratio = atr_val / atr20
        volatility = "high" if ratio > 1.3 else ("low" if ratio < 0.7 else "normal")
    else:
        volatility = "normal"

    # Выравнивание тренда
    trend = ("up"   if ma5 > ma20 > ma50
             else ("down" if ma5 < ma20 < ma50 else "mixed"))

    # Выравнивание MA для памяти
    ma_align = "above" if current > ma20 else ("below" if current < ma20 else "mixed")

    return {
        "current":    round(current, 5),
        "ma5":        round(ma5, 5),
        "ma20":       round(ma20, 5),
        "ma50":       round(ma50, 5),
        "rsi":        rsi_val,
        "atr":        atr_val,
        "bb_mid":     round(bb_mid, 5),
        "bb_lo":      round(bb_lo, 5),
        "bb_hi":      round(bb_hi, 5),
        "stoch":      stoch,
        "zscore":     zscore,
        "trend":      trend,
        "ma_align":   ma_align,
        "volatility": volatility,
        "above_ma20": current > ma20,
        "above_ma50": current > ma50,
    }


def build_market_context(symbol: str, ind: dict) -> str:
    return (
        f"Symbol: {symbol}\n"
        f"Price: {ind['current']}  "
        f"MA5: {ind['ma5']}  MA20: {ind['ma20']}  MA50: {ind['ma50']}\n"
        f"RSI(14): {ind['rsi']}   Stoch(14): {ind['stoch']}   Z-score: {ind['zscore']}\n"
        f"ATR(14): {ind['atr']}   Volatility: {ind['volatility']}\n"
        f"BB: mid={ind['bb_mid']} lo={ind['bb_lo']} hi={ind['bb_hi']}\n"
        f"Trend alignment: {ind['trend']}  |  "
        f"Above MA20: {ind['above_ma20']}  |  Above MA50: {ind['above_ma50']}\n"
        f"\n"
        f"Based on your trading philosophy and historical memory above, "
        f"analyze this data and give a trading signal.\n"
        f"Output ONLY a single valid JSON object. No markdown, no extra text.\n"
        f"Example: {{\"signal\": \"buy\", "
        f"\"comment\": \"RSI rising from 38, MA aligned up\", \"confidence\": 0.72}}\n"
        f"signal must be exactly: buy  sell  or  hold\n"
        f"Your JSON:"
    )


# ═══════════════════════════════════════════════════════════════════════════════
# ПРОМПТ С ПАМЯТЬЮ
# ═══════════════════════════════════════════════════════════════════════════════
AGENT_BASE_PROMPT = (
    "You are a disciplined quantitative trading agent with persistent memory. "
    "You trade FOREX and CFD instruments using technical analysis. "
    "Your core philosophy: follow momentum in trending conditions, "
    "trade mean-reversion in ranging conditions. "
    "Risk management comes first: avoid entries in high-volatility chaotic regimes."
)


def build_prompt_with_memory(symbol: str, timeframe: str,
                              memory: MemoryLayer, rsi: float) -> str:
    stats   = memory.get_stats(symbol, timeframe)
    context = memory.get_context(symbol, timeframe, rsi)

    return (
        f"{AGENT_BASE_PROMPT}\n\n"
        f"=== YOUR SELF-KNOWLEDGE ===\n"
        f"{stats}\n\n"
        f"=== HISTORICAL CONTEXT ===\n"
        f"{context}\n\n"
        f"=== BEHAVIORAL RULES ===\n"
        f"1. If your win-rate on this instrument is below 40% — be extra conservative.\n"
        f"2. If the last 3 decisions were all losses — reduce confidence by 0.1.\n"
        f"3. If historical context shows repeated failures on one signal type — downgrade it.\n"
        f"4. If volatility is HIGH — prefer hold unless signal is very strong (confidence > 0.75).\n"
        f"5. Never ignore your own track record. Adapt.\n\n"
        f"OUTPUT RULE: Reply ONLY with valid JSON. No markdown.\n"
        f"Example: {{\"signal\": \"buy\", \"comment\": \"reason\", \"confidence\": 0.72}}\n"
        f"BIAS: 'hold' only when genuinely uncertain. Commit to a direction."
    )


# ═══════════════════════════════════════════════════════════════════════════════
# xAI / GROK ВЫЗОВ
# ═══════════════════════════════════════════════════════════════════════════════
def ask_grok(system_prompt: str, user_message: str, retries: int = 2) -> str:
    payload = {
        "model":      MODEL,
        "max_tokens": MAX_TOKENS,
        "temperature": 0.2,   # детерминированность важна для торговых решений
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user",   "content": user_message},
        ],
    }
    for attempt in range(retries + 1):
        try:
            resp = req_lib.post(XAI_URL, headers=HEADERS,
                                json=payload, timeout=60)
            resp.raise_for_status()
            raw = resp.json()["choices"][0]["message"]["content"] or ""
            raw = raw.strip()
            log(f"  RAW → {raw[:200]}")
            return raw
        except Exception as e:
            log(f"Grok API error (attempt {attempt + 1}): {e}")
            if attempt < retries:
                time.sleep(1.5 * (attempt + 1))
    return '{"signal":"hold","comment":"API error","confidence":0.0}'


def parse_signal(raw: str) -> dict:
    raw = raw.strip()
    # Убираем markdown code fences
    if raw.startswith("```"):
        lines = raw.split("\n")
        raw   = "\n".join(l for l in lines if not l.startswith("```")).strip()
    # Убираем thinking-блоки reasoning-моделей
    if "<think>" in raw and "</think>" in raw:
        raw = raw[raw.find("</think>") + 8:].strip()
    # Ищем JSON
    start = raw.find("{")
    end   = raw.rfind("}") + 1
    if start >= 0 and end > start:
        try:
            obj = json.loads(raw[start:end])
            sig = str(obj.get("signal", "hold")).lower().strip()
            if sig not in ("buy", "sell", "hold"):
                sig = "hold"
            return {
                "signal":     sig,
                "comment":    str(obj.get("comment", ""))[:150],
                "confidence": float(obj.get("confidence", 0.5)),
            }
        except (json.JSONDecodeError, ValueError) as e:
            log(f"  JSON parse error: {e}")
    log(f"  WARN: could not parse signal from: {raw[:100]}")
    return {"signal": "hold", "comment": "parse_error", "confidence": 0.0}


# ═══════════════════════════════════════════════════════════════════════════════
# ГЛОБАЛЬНАЯ ПАМЯТЬ  (один экземпляр на весь процесс)
# ═══════════════════════════════════════════════════════════════════════════════
memory = MemoryLayer(DB_PATH)
db_lock = threading.Lock()   # дополнительная блокировка при параллельных символах


def safe_write_decision(**kwargs) -> int:
    """Thread-safe запись решения."""
    with db_lock:
        return memory.write_decision(**kwargs)


def safe_update_result(decision_id: int, pnl: float):
    """Thread-safe обновление исхода."""
    memory.update_result(decision_id, pnl)


# ═══════════════════════════════════════════════════════════════════════════════
# ОСНОВНАЯ ФУНКЦИЯ АНАЛИЗА
# ═══════════════════════════════════════════════════════════════════════════════
def analyze_with_memory(symbol: str, prices: List[float],
                        timeframe: str = "M15") -> dict:
    """
    Полный цикл: индикаторы → контекст из базы → запрос Grok → запись решения.
    Возвращает словарь с signal, comment, confidence, decision_id.
    """
    if not prices:
        return {"signal": "hold", "comment": "no_data",
                "confidence": 0.0, "decision_id": -1}

    ind = calc_indicators(prices)
    rsi = ind["rsi"]

    # Строим промпт с памятью
    system = build_prompt_with_memory(symbol, timeframe, memory, rsi)
    user   = build_market_context(symbol, ind)

    log(f"[{symbol}] Stats: {memory.get_stats(symbol, timeframe)}")
    log(f"[{symbol}] Context loaded (RSI={rsi:.1f})")

    t0  = time.time()
    raw = ask_grok(system, user)
    dt  = round(time.time() - t0, 2)

    result = parse_signal(raw)

    # Сохраняем решение в базу
    decision_id = safe_write_decision(
        symbol     = symbol,
        timeframe  = timeframe,
        signal     = result["signal"],
        confidence = result["confidence"],
        comment    = result["comment"],
        rsi        = rsi,
        ma_align   = ind["ma_align"],
        volatility = ind["volatility"],
        zscore     = ind["zscore"],
    )

    result["decision_id"] = decision_id
    result["latency_s"]   = dt
    result["indicators"]  = ind

    log(
        f"[{symbol}] → {result['signal']} | "
        f"conf={result['confidence']:.2f} | "
        f"decision_id={decision_id} | {dt}s"
    )
    return result


# ═══════════════════════════════════════════════════════════════════════════════
# WebSocket helpers  (идентичны evolution-серверу)
# ═══════════════════════════════════════════════════════════════════════════════
def ws_handshake_response(request: str) -> str:
    key = ""
    for line in request.split("\r\n"):
        if "Sec-WebSocket-Key" in line:
            key = line.split(": ", 1)[1].strip()
            break
    magic  = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    accept = base64.b64encode(
        hashlib.sha1((key + magic).encode()).digest()
    ).decode()
    return (
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        f"Sec-WebSocket-Accept: {accept}\r\n\r\n"
    )


def ws_decode(data: bytes) -> str:
    if len(data) < 2:
        return ""
    masked = bool(data[1] & 0x80)
    plen   = data[1] & 0x7F
    off    = 2 + (2 if plen == 126 else 8 if plen == 127 else 0)
    if masked:
        if len(data) < off + 4:
            return ""
        mask    = data[off: off + 4]
        payload = data[off + 4:]
        return bytearray(
            b ^ mask[i % 4] for i, b in enumerate(payload)
        ).decode("utf-8", errors="replace")
    return data[off:].decode("utf-8", errors="replace")


def ws_encode(message: str) -> bytes:
    payload = message.encode("utf-8")
    n       = len(payload)
    header  = bytearray([0x81])
    if n <= 125:
        header.append(n)
    elif n <= 65535:
        header.append(126)
        header += struct.pack(">H", n)
    else:
        header.append(127)
        header += struct.pack(">Q", n)
    return bytes(header) + payload


# ═══════════════════════════════════════════════════════════════════════════════
# ОБРАБОТЧИК КЛИЕНТА
# ═══════════════════════════════════════════════════════════════════════════════
def handle_client(conn: socket.socket, addr):
    log(f"Connected: {addr}")
    is_ws          = False
    buffer         = b""
    send_lock      = threading.Lock()
    last_status_ts = 0.0

    def safe_send(data: bytes):
        with send_lock:
            try:
                conn.sendall(data)
            except Exception as e:
                log(f"send error: {e}")

    try:
        conn.settimeout(600.0)
        while True:
            try:
                chunk = conn.recv(65536)
            except socket.timeout:
                log("Connection timeout")
                break
            if not chunk:
                break
            buffer += chunk

            # ── WebSocket handshake ──────────────────────────────────────────
            if not is_ws:
                text = buffer.decode("utf-8", errors="replace")
                if "\r\n\r\n" in text:
                    conn.sendall(ws_handshake_response(text).encode())
                    is_ws  = True
                    buffer = b""
                    log("WebSocket handshake OK")
                continue

            if len(buffer) < 2:
                continue
            message = ws_decode(buffer).strip()
            buffer  = b""
            if not message:
                continue
            log(f"Recv: {message[:120]}")

            cmd = message.upper()

            # ── STOP ────────────────────────────────────────────────────────
            if cmd == "STOP":
                break

            # ── MEM_STATUS:SYMBOL ────────────────────────────────────────────
            if cmd.startswith("MEM_STATUS:"):
                now = time.time()
                if now - last_status_ts < 2.0:
                    continue
                last_status_ts = now
                parts  = message[11:].split(":")
                symbol = parts[0].strip()
                tf     = parts[1].strip() if len(parts) > 1 else "M15"
                resp   = memory.full_status(symbol, tf)
                safe_send(ws_encode(json.dumps(resp, ensure_ascii=False)))
                continue

            # ── RESULT:SYM:decision_id:pnl ───────────────────────────────────
            if cmd.startswith("RESULT:"):
                parts = message[7:].split(":")
                if len(parts) >= 3:
                    try:
                        symbol      = parts[0]
                        decision_id = int(parts[1])
                        pnl         = float(parts[2])
                        safe_update_result(decision_id, pnl)
                        log(
                            f"[RESULT] decision_id={decision_id}  "
                            f"pnl={pnl:+.3f}  "
                            f"outcome={'win' if pnl > 0 else 'loss'}"
                        )
                        resp = {
                            "signal":      "result_ack",
                            "decision_id": decision_id,
                            "pnl":         pnl,
                            "outcome":     "win" if pnl > 0 else "loss",
                        }
                    except Exception as e:
                        log(f"RESULT error: {e}")
                        resp = {"signal": "error", "comment": str(e)}
                else:
                    resp = {"signal": "error",
                            "comment": "RESULT format: RESULT:SYM:decision_id:pnl"}
                safe_send(ws_encode(json.dumps(resp, ensure_ascii=False)))
                continue

            # ── ANALYZE:SYM:csv  — синхронно, без потока ────────────────────
            if cmd.startswith("ANALYZE:"):
                parts = message[8:].split(":", 1)
                sym   = parts[0].strip()
                csv   = parts[1].strip() if len(parts) > 1 else ""
                # Таймфрейм можно передать как ANALYZE:SYM:TF:csv
                # но для совместимости по умолчанию M15
                tf = "M15"
                # Проверяем, является ли второй сегмент ТФ
                sub_parts = csv.split(":", 1)
                if len(sub_parts) == 2 and sub_parts[0] in (
                        "M1","M5","M15","M30","H1","H4","D1"):
                    tf  = sub_parts[0]
                    csv = sub_parts[1]
                try:
                    prices = [float(x) for x in csv.split(",") if x.strip()]
                except ValueError:
                    prices = []

                if not prices:
                    safe_send(ws_encode(json.dumps(
                        {"signal": "hold", "comment": "parse_error",
                         "confidence": 0.0, "decision_id": -1},
                        ensure_ascii=False
                    )))
                    continue

                result = analyze_with_memory(sym, prices, tf)
                safe_send(ws_encode(json.dumps(result, ensure_ascii=False)))
                continue

            # ── Unknown ───────────────────────────────────────────────────────
            safe_send(ws_encode(json.dumps(
                {"signal": "error",
                 "comment": f"Unknown command: {message[:60]}"},
                ensure_ascii=False
            )))

    except Exception as e:
        log(f"Handler error: {e}")
    finally:
        conn.close()
        log(f"Disconnected: {addr}")


# ─── Main ─────────────────────────────────────────────────────────────────────
def main():
    bar = "═" * 64
    print(bar)
    print("  LLM Memory Agent — xAI / Grok Edition  v1.0")
    print(f"  Address  : ws://{HOST}:{PORT}")
    print(f"  Model    : {MODEL}  (xAI API)")
    print(f"  Database : {DB_PATH}")
    print()
    print("  Commands:")
    print("    ANALYZE:SYM:csv              — get signal with memory context")
    print("    ANALYZE:SYM:TF:csv           — same, explicit timeframe")
    print("    RESULT:SYM:decision_id:pnl   — report trade result")
    print("    MEM_STATUS:SYM:TF            — agent statistics")
    print("    STOP                         — disconnect")
    print(bar)

    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    srv.bind((HOST, PORT))
    srv.listen(5)
    log("Server started. Waiting for MT5 connection...")

    try:
        while True:
            conn, addr = srv.accept()
            threading.Thread(
                target=handle_client, args=(conn, addr), daemon=True
            ).start()
    except KeyboardInterrupt:
        log("Server stopped.")
    finally:
        srv.close()


if __name__ == "__main__":
    main()
