#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
langchain_memory_server.py  — LangChain Memory Agent Server
=============================================================
Статья: "От сигнала к сделке через цепочку агентов:
         LangChain-архитектура с персистентной SQL-памятью поверх MQL5"
Автор:   Evgeniy Koshtenko, 2025

Архитектура:
  WebSocket-сервер (порт 8977)
  + SQLite-память (решения + исходы между сессиями)
  + LangChain цепочка: Signal Agent → News Agent → Risk Agent

╔══════════════════════════════════════════════════════════════╗
║   LANGCHAIN MEMORY AGENT  v1.0                               ║
║   Агент помнит решения и учится на своих ошибках.            ║
║                                                              ║
║  Команды МТ5:                                                ║
║   ANALYZE:SYMBOL:TF:close_csv   → сигнал + decision_id      ║
║   RESULT:SYMBOL:decision_id:pnl → обновить исход             ║
║   MEM_STATUS:SYMBOL:TF          → статистика агента          ║
║   STOP                          → отключиться                ║
╚══════════════════════════════════════════════════════════════╝

Установка:
    pip install langchain langchain-xai langchain-core python-dotenv numpy

Запуск:
    python langchain_memory_server.py

.env:
    XAI_API_KEY=xai-xxxxxxxxxxxxxxxx
"""

import os
import json
import socket
import struct
import base64
import hashlib
import threading
import sqlite3
import time
from datetime import datetime, timezone
from typing import List

from dotenv import load_dotenv
load_dotenv()

try:
    import numpy as np
except ImportError:
    os.system(f"pip install numpy --quiet")
    import numpy as np

from langchain_xai import ChatXAI
from langchain_core.prompts import ChatPromptTemplate

# ─── SETTINGS ─────────────────────────────────────────────────────────────────
# WebSocket endpoint the MT5 client connects to (loopback address only).
HOST      = "127.0.0.1"
PORT      = 8977
# SQLite file holding the `decisions` / `agent_stats` tables (relative path,
# so it is created in the current working directory).
DB_PATH   = "langchain_memory.db"
# Model name passed to ChatXAI below and shown in the startup banner.
MODEL     = "grok-3"

# LLM provider — change it here if needed. Built once at import time and shared
# by the signal, news and risk chains; a low temperature is used for more
# repeatable JSON answers.
# NOTE(review): credentials presumably come from XAI_API_KEY (loaded from .env by
# load_dotenv() above) — confirm against the langchain-xai documentation.
llm = ChatXAI(model=MODEL, temperature=0.1)

# ─── LOGGING ──────────────────────────────────────────────────────────────────
import logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        # Persistent log file; UTF-8 because the LLM comments/reasons that get
        # logged are requested in Russian.
        logging.FileHandler("langchain_memory_agent.log", encoding="utf-8"),
        # Mirror every record to the console as well.
        logging.StreamHandler()
    ]
)
log = logging.getLogger(__name__)


# ═══════════════════════════════════════════════════════════════════════════════
# SQL-СЛОЙ ПАМЯТИ  (идентичен llm_server_memory_mql5.py)
# ═══════════════════════════════════════════════════════════════════════════════
class MemoryLayer:
    """Persistent agent memory backed by SQLite.

    Tables:
      * ``decisions``   — one row per call to :meth:`write_decision`, holding
                          the indicator snapshot and chain log; ``pnl`` and
                          ``outcome`` are filled in later by :meth:`update_result`
                          (``outcome`` stays ``'open'`` until then).
      * ``agent_stats`` — per (symbol, timeframe) aggregates over closed
                          decisions, recomputed on every :meth:`update_result`.

    A single connection is shared across threads (``check_same_thread=False``);
    every statement issued after construction runs under ``self._lock``.
    """

    def __init__(self, db_path: str = DB_PATH):
        self._db_path = db_path
        self._lock    = threading.Lock()
        self._conn    = sqlite3.connect(db_path, check_same_thread=False)
        self._init_tables()
        log.info(f"Memory DB: {db_path}")

    def _init_tables(self):
        """Create the tables and indexes if they do not exist yet (idempotent)."""
        self._conn.executescript("""
            CREATE TABLE IF NOT EXISTS decisions (
                id          INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp   TEXT    NOT NULL,
                symbol      TEXT    NOT NULL,
                timeframe   TEXT    NOT NULL,
                signal      TEXT    NOT NULL,
                confidence  REAL    DEFAULT 0.5,
                comment     TEXT    DEFAULT '',
                pnl         REAL,
                outcome     TEXT    DEFAULT 'open',
                rsi         REAL    DEFAULT 50.0,
                ma_align    TEXT    DEFAULT 'mixed',
                volatility  TEXT    DEFAULT 'normal',
                zscore      REAL    DEFAULT 0.0,
                news_risk   TEXT    DEFAULT 'LOW',
                chain_log   TEXT    DEFAULT ''
            );

            CREATE TABLE IF NOT EXISTS agent_stats (
                symbol      TEXT,
                timeframe   TEXT,
                total       INTEGER DEFAULT 0,
                wins        INTEGER DEFAULT 0,
                losses      INTEGER DEFAULT 0,
                avg_pnl     REAL    DEFAULT 0.0,
                best_signal TEXT    DEFAULT 'hold',
                updated_at  TEXT,
                PRIMARY KEY (symbol, timeframe)
            );

            CREATE INDEX IF NOT EXISTS idx_decisions_sym_tf
                ON decisions(symbol, timeframe, outcome);
            CREATE INDEX IF NOT EXISTS idx_decisions_rsi
                ON decisions(rsi);
        """)
        self._conn.commit()

    def write_decision(self, symbol: str, timeframe: str,
                       signal: str, confidence: float, comment: str,
                       rsi: float, ma_align: str, volatility: str,
                       zscore: float, news_risk: str = "LOW",
                       chain_log: str = "") -> int:
        """Insert a new decision (``outcome='open'``) and return its row id.

        The timestamp is the current UTC time in ISO-8601 format, which keeps
        ``ORDER BY timestamp`` chronological.
        """
        with self._lock:
            cur = self._conn.execute(
                """INSERT INTO decisions
                   (timestamp, symbol, timeframe, signal, confidence,
                    comment, rsi, ma_align, volatility, zscore,
                    news_risk, chain_log)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
                (datetime.now(timezone.utc).isoformat(), symbol, timeframe,
                 signal, confidence, comment,
                 rsi, ma_align, volatility, zscore,
                 news_risk, chain_log)
            )
            self._conn.commit()
            return cur.lastrowid

    def update_result(self, decision_id: int, pnl: float):
        """Close *decision_id* with its realized *pnl* and refresh the stats.

        ``outcome`` is ``'win'`` only for ``pnl > 0``; a breakeven trade
        (``pnl == 0``) is counted as a ``'loss'``. An unknown id is a no-op.
        """
        outcome = "win" if pnl > 0 else "loss"
        with self._lock:
            self._conn.execute(
                "UPDATE decisions SET pnl=?, outcome=? WHERE id=?",
                (pnl, outcome, decision_id)
            )
            self._conn.commit()
        self._refresh_stats(decision_id)

    def _refresh_stats(self, decision_id: int):
        """Recompute the agent_stats row for the symbol/timeframe of *decision_id*.

        ``best_signal`` is whichever of buy/sell has the higher win rate over
        closed decisions; ties, including "no closed buy/sell trades", go to buy.
        """
        # One locked unit: look up the key and aggregate from the same snapshot.
        with self._lock:
            row = self._conn.execute(
                "SELECT symbol, timeframe FROM decisions WHERE id=?",
                (decision_id,)
            ).fetchone()
            if not row:
                return
            symbol, timeframe = row

            stats = self._conn.execute(
                """SELECT COUNT(*),
                          SUM(CASE WHEN outcome='win'  THEN 1 ELSE 0 END),
                          SUM(CASE WHEN outcome='loss' THEN 1 ELSE 0 END),
                          AVG(pnl)
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND outcome != 'open'""",
                (symbol, timeframe)
            ).fetchone()

            # SUM/AVG return NULL when no rows match.
            total, wins, losses, avg_pnl = stats
            total   = total   or 0
            wins    = wins    or 0
            losses  = losses  or 0
            avg_pnl = avg_pnl or 0.0

            buy_wr = self._conn.execute(
                """SELECT AVG(CASE WHEN outcome='win' THEN 1.0 ELSE 0.0 END)
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND signal='buy'
                   AND outcome != 'open'""",
                (symbol, timeframe)
            ).fetchone()[0] or 0.0

            sell_wr = self._conn.execute(
                """SELECT AVG(CASE WHEN outcome='win' THEN 1.0 ELSE 0.0 END)
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND signal='sell'
                   AND outcome != 'open'""",
                (symbol, timeframe)
            ).fetchone()[0] or 0.0

            best_signal = "buy" if buy_wr >= sell_wr else "sell"

            self._conn.execute(
                """INSERT INTO agent_stats
                       (symbol, timeframe, total, wins, losses,
                        avg_pnl, best_signal, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                   ON CONFLICT(symbol, timeframe) DO UPDATE SET
                       total=excluded.total, wins=excluded.wins,
                       losses=excluded.losses, avg_pnl=excluded.avg_pnl,
                       best_signal=excluded.best_signal,
                       updated_at=excluded.updated_at""",
                (symbol, timeframe, total, wins, losses,
                 avg_pnl, best_signal, datetime.now(timezone.utc).isoformat())
            )
            self._conn.commit()

    def get_context(self, symbol: str, timeframe: str,
                    rsi: float, limit: int = 5) -> str:
        """Return a text block of relevant past decisions for the LLM prompt.

        Two sections (each omitted when empty):
          * up to *limit* closed decisions whose RSI was within 8 points of *rsi*,
            newest first;
          * the last 3 closed decisions regardless of conditions.
        """
        with self._lock:
            similar = self._conn.execute(
                """SELECT signal, outcome, pnl, comment, rsi
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND outcome != 'open'
                   AND ABS(rsi - ?) < 8
                   ORDER BY timestamp DESC LIMIT ?""",
                (symbol, timeframe, rsi, limit)
            ).fetchall()

            recent = self._conn.execute(
                """SELECT signal, outcome, pnl, comment, news_risk
                   FROM decisions
                   WHERE symbol=? AND timeframe=? AND outcome != 'open'
                   ORDER BY timestamp DESC LIMIT 3""",
                (symbol, timeframe)
            ).fetchall()

        parts = []
        if similar:
            parts.append(f"Similar RSI conditions (RSI ≈ {rsi:.0f}) in the past:")
            for sig, out, pnl, cmt, r in similar:
                pnl_str = f"{pnl:+.2f}" if pnl is not None else "?"
                # The comment comes from the LLM and may have been stored as NULL
                # or a non-string; slicing those used to raise TypeError here.
                parts.append(
                    f"  [{sig.upper():4s}] → {out:4s} ({pnl_str} pts) "
                    f"| RSI≈{r:.0f} | {str(cmt or '')[:60]}"
                )

        if recent:
            parts.append("Last 3 decisions (any conditions):")
            for sig, out, pnl, cmt, nr in recent:
                pnl_str = f"{pnl:+.2f}" if pnl is not None else "?"
                parts.append(
                    f"  [{sig.upper():4s}] → {out:4s} ({pnl_str} pts) "
                    f"| news_risk={nr} | {str(cmt or '')[:60]}"
                )

        return "\n".join(parts) if parts else \
            "No historical data yet — acting on current market data only."

    def get_stats(self, symbol: str, timeframe: str) -> str:
        """Return a one-line summary of agent_stats for the LLM prompts."""
        with self._lock:
            row = self._conn.execute(
                "SELECT total, wins, losses, avg_pnl, best_signal "
                "FROM agent_stats WHERE symbol=? AND timeframe=?",
                (symbol, timeframe)
            ).fetchone()

        if not row or row[0] == 0:
            return "No statistics yet — this is a fresh start."

        total, wins, losses, avg_pnl, best_signal = row
        win_rate = 100.0 * wins / total if total else 0.0
        return (
            f"Agent stats on {symbol}/{timeframe}: "
            f"{total} closed trades, win-rate {win_rate:.0f}%, "
            f"avg PnL {avg_pnl:+.2f} pts. "
            f"Historically stronger direction: {best_signal.upper()}."
        )

    def full_status(self, symbol: str, timeframe: str) -> dict:
        """Return agent_stats plus the count of still-open decisions as a dict.

        When no stats row exists yet only ``symbol``, ``timeframe``,
        ``total`` (0) and ``open`` are present.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT total, wins, losses, avg_pnl, best_signal, updated_at "
                "FROM agent_stats WHERE symbol=? AND timeframe=?",
                (symbol, timeframe)
            ).fetchone()
            open_cnt = self._conn.execute(
                "SELECT COUNT(*) FROM decisions "
                "WHERE symbol=? AND timeframe=? AND outcome='open'",
                (symbol, timeframe)
            ).fetchone()[0]

        if not row:
            return {"symbol": symbol, "timeframe": timeframe,
                    "total": 0, "open": open_cnt}

        total, wins, losses, avg_pnl, best_signal, upd = row
        win_rate = round(100.0 * wins / total, 1) if total else 0.0
        return {
            "symbol":      symbol,
            "timeframe":   timeframe,
            "total":       total,
            "wins":        wins,
            "losses":      losses,
            "win_rate":    win_rate,
            "avg_pnl":     round(avg_pnl, 4),
            "best_signal": best_signal,
            "open":        open_cnt,
            "updated_at":  upd,
        }


# ═══════════════════════════════════════════════════════════════════════════════
# ТЕХНИЧЕСКИЕ ИНДИКАТОРЫ
# ═══════════════════════════════════════════════════════════════════════════════
def calc_indicators(prices: List[float]) -> dict:
    """Compute the indicator snapshot sent to the Signal agent and stored in memory.

    Args:
        prices: close prices; ``prices[-1]`` is treated as the current price
            (presumably oldest first — confirm with the MT5 side). Must be
            non-empty: an empty list raises ``IndexError``.

    Returns:
        dict with SMA(5/20/50), EMA(20/50), RSI(14), a close-to-close "ATR",
        Bollinger bands (20, 2σ), stochastic %K(14) computed on closes only,
        the 20-bar z-score and the categorical ``trend`` / ``ma_align`` /
        ``volatility`` labels. Series too short for an indicator fall back
        to the last price or a neutral value instead of raising.
    """
    arr = np.array(prices, dtype=float)

    def sma(a, p):
        # Not enough bars: fall back to the last price.
        return float(np.mean(a[-p:])) if len(a) >= p else float(a[-1])

    def ema(a, p):
        # Seeded with the first price and run over the whole series.
        if len(a) < 2:
            return float(a[-1])
        k = 2.0 / (p + 1)
        e = float(a[0])
        for v in a[1:]:
            e = float(v) * k + e * (1 - k)
        return e

    def rsi_fn(a, p=14):
        # Simple-average (Cutler) RSI over the last p price changes.
        if len(a) < p + 1:
            return 50.0
        d  = np.diff(a)
        g  = np.where(d > 0, d, 0.0)
        l  = np.where(d < 0, -d, 0.0)
        ag, al = np.mean(g[-p:]), np.mean(l[-p:])
        if al == 0:
            # No losses: RSI saturates at 100 — unless there were no gains
            # either (flat market). There RS = 0/0 is undefined and the neutral
            # 50 is used instead of falsely reporting "extremely overbought".
            return 50.0 if ag == 0 else 100.0
        return round(100 - 100 / (1 + ag / al), 2)

    def atr_simple(a, p=14):
        # Mean absolute close-to-close change; not a true range (no highs/lows).
        if len(a) < 2:
            return 0.0
        tr = np.abs(np.diff(a))
        return round(float(np.mean(tr[-p:])), 6)

    def bb(a, p=20):
        # Returns (mid, lower, upper) with 2 population std-devs.
        if len(a) < p:
            return float(a[-1]), float(a[-1]) - 0.001, float(a[-1]) + 0.001
        s   = a[-p:]
        mid = float(np.mean(s))
        std = float(np.std(s))
        return mid, round(mid - 2 * std, 6), round(mid + 2 * std, 6)

    def stoch_k(a, p=14):
        if len(a) < p:
            return 50.0
        h, l = np.max(a[-p:]), np.min(a[-p:])
        return round(100 * (a[-1] - l) / (h - l), 2) if h != l else 50.0

    current           = float(arr[-1])
    ma5               = sma(arr, 5)
    ma20              = sma(arr, 20)
    ma50              = sma(arr, 50)
    ema20             = ema(arr, 20)
    ema50             = ema(arr, 50)
    rsi_val           = rsi_fn(arr)
    atr_val           = atr_simple(arr)
    atr20             = atr_simple(arr, 20)
    bb_mid, bb_lo, bb_hi = bb(arr)
    stoch             = stoch_k(arr)

    if len(arr) >= 20:
        mu     = float(np.mean(arr[-20:]))
        sigma  = float(np.std(arr[-20:]))
        zscore = round((current - mu) / sigma, 2) if sigma > 0 else 0.0
    else:
        zscore = 0.0

    # Volatility regime: short (14) vs longer (20) average move.
    if atr20 > 0:
        ratio = atr_val / atr20
        volatility = "high" if ratio > 1.3 else ("low" if ratio < 0.7 else "normal")
    else:
        volatility = "normal"

    trend    = ("up"   if ma5 > ma20 > ma50
                else ("down" if ma5 < ma20 < ma50 else "mixed"))
    ma_align = "above" if current > ma20 else ("below" if current < ma20 else "mixed")

    return {
        "current":    round(current, 5),
        "ma5":        round(ma5, 5),
        "ma20":       round(ma20, 5),
        "ma50":       round(ma50, 5),
        "ema20":      round(ema20, 5),
        "ema50":      round(ema50, 5),
        "rsi":        rsi_val,
        "atr":        atr_val,
        "bb_mid":     round(bb_mid, 5),
        "bb_lo":      round(bb_lo, 5),
        "bb_hi":      round(bb_hi, 5),
        "stoch":      stoch,
        "zscore":     zscore,
        "trend":      trend,
        "ma_align":   ma_align,
        "volatility": volatility,
        "above_ma20": current > ma20,
        "above_ma50": current > ma50,
    }


# ═══════════════════════════════════════════════════════════════════════════════
# LANGCHAIN ПРОМПТЫ
# ═══════════════════════════════════════════════════════════════════════════════

# Signal agent prompt. Template variables {agent_stats}, {memory_context} and
# {market_data} are supplied by signal_chain.invoke(); the doubled braces {{ }}
# render as literal braces, so the JSON schema reaches the model unchanged.
# NOTE(review): the prompt asks the model to recompute EMA/RSI even though the
# server already sends them in market_data — confirm this is intentional.
SIGNAL_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a disciplined quantitative trading agent with persistent memory.
You trade FOREX and CFD instruments using technical analysis.

=== YOUR SELF-KNOWLEDGE ===
{agent_stats}

=== HISTORICAL CONTEXT (from your SQLite memory) ===
{memory_context}

=== BEHAVIORAL RULES ===
1. If your win-rate on this instrument is below 40% — be extra conservative.
2. If the last 3 decisions were all losses — reduce confidence by 0.1.
3. If historical context shows repeated failures on one signal type — downgrade it.
4. If volatility is HIGH — prefer hold unless confidence > 0.75.
5. Never ignore your own track record. Adapt.

Calculate EMA(20), EMA(50), RSI(14) from the provided close prices.
Determine trend direction and entry quality.
Return ONLY valid JSON (no markdown, no explanation):
{{
  "signal":     "buy" | "sell" | "hold",
  "confidence": <0.0–1.0>,
  "ema20":      <number>,
  "ema50":      <number>,
  "rsi14":      <number>,
  "trend":      "up" | "down" | "flat",
  "comment":    "<one sentence in Russian>"
}}"""),
    ("human", "{market_data}")
])

# News filter prompt; filled with {symbol} and {utc_time}. The model has no
# news feed in this code, so its answer rests on its own knowledge.
NEWS_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a news filter for an algorithmic trading system.
Evaluate whether there is a high-volatility market risk right now or in the next 30 minutes:
  — key macro data releases (NFP, CPI, GDP, PMI, central bank decisions)
  — geopolitical events that could move the market sharply
  — major session open with anomalous gap

Return ONLY valid JSON (no markdown, no explanation):
{{
  "news_risk": "HIGH" | "LOW",
  "reason":    "<one sentence in Russian>"
}}"""),
    ("human", "Instrument: {symbol}\nUTC time: {utc_time}")
])

# Risk manager prompt; filled with {agent_stats}, {signal_data} (the Signal
# agent's JSON) and {account_data} (account snapshot JSON). The "hard rules" are
# only instructions to the model — nothing in this template enforces them.
RISK_PROMPT = ChatPromptTemplate.from_messages([
    ("system", """You are a risk manager for an algorithmic trading system.

Hard rules (never break):
  — Risk per trade: max 1% of balance
  — Max open positions: 3
  — Min stop-loss: 15 pips
  — TP/SL ratio: at least 1.5
  — Max volume: 1.0 lot
  — If spread > 30 points: reject

Calculate volume from 1% risk:
  volume = (balance * 0.01) / (sl_pips * 10)
  (pip_value = 10 USD per lot for majors)

=== AGENT MEMORY HINT ===
{agent_stats}

Return ONLY valid JSON (no markdown, no explanation):
{{
  "approved":  true | false,
  "volume":    <0.01–1.0>,
  "sl_pips":   <integer>,
  "tp_pips":   <integer>,
  "reason":    "<one sentence in Russian>"
}}"""),
    ("human", "Signal:\n{signal_data}\n\nAccount:\n{account_data}")
])

# LCEL pipelines: prompt | llm. Each .invoke(dict) returns a chat message whose
# .content is parsed with parse_llm() by the caller.
signal_chain = SIGNAL_PROMPT | llm
news_chain   = NEWS_PROMPT   | llm
risk_chain   = RISK_PROMPT   | llm


# ═══════════════════════════════════════════════════════════════════════════════
# ПАРСИНГ LLM-ОТВЕТОВ
# ═══════════════════════════════════════════════════════════════════════════════
def parse_llm(raw: str) -> dict:
    """Extract the JSON object from a raw LLM reply.

    Tolerates, in this order:
      * a Markdown code fence around the whole reply (the closing fence may be missing);
      * a ``<think>...</think>`` block emitted by reasoning models;
      * prose before or after the object, including stray ``{``/``}`` characters.

    Raises:
        ValueError: if no JSON object can be decoded (``json.JSONDecodeError``
            is never raised directly, but it is a ``ValueError`` subclass anyway).
    """
    text = raw.strip()
    # Strip a Markdown fence.
    if text.startswith("```"):
        lines = text.splitlines()
        text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
    # Drop the reasoning block of "thinking" models.
    if "<think>" in text and "</think>" in text:
        text = text[text.find("</think>") + len("</think>"):].strip()

    # Fast path: everything from the first "{" to the last "}" is the object.
    start = text.find("{")
    end   = text.rfind("}") + 1
    if start >= 0 and end > start:
        try:
            return json.loads(text[start:end])
        except json.JSONDecodeError:
            pass  # braces in surrounding prose spoiled the span; scan below

    # Slow path: try every "{" and decode only the object that starts there,
    # ignoring whatever text follows it.
    decoder = json.JSONDecoder()
    pos = start
    while pos >= 0:
        try:
            return decoder.raw_decode(text[pos:])[0]
        except json.JSONDecodeError:
            pos = text.find("{", pos + 1)
    raise ValueError(f"No JSON found in: {text[:100]}")


# ═══════════════════════════════════════════════════════════════════════════════
# ГЛОБАЛЬНАЯ ПАМЯТЬ
# ═══════════════════════════════════════════════════════════════════════════════
# Process-wide memory instance shared by every client thread.
memory  = MemoryLayer(DB_PATH)
# Extra module-level lock, taken only around decision inserts.
db_lock = threading.Lock()


def safe_write_decision(**decision_fields) -> int:
    """Insert one decision row while holding ``db_lock`` and return its id.

    All keyword arguments are forwarded unchanged to ``memory.write_decision``.
    """
    with db_lock:
        new_id = memory.write_decision(**decision_fields)
    return new_id


def safe_update_result(decision_id: int, pnl: float):
    """Store the realized *pnl* for *decision_id*.

    Does not take ``db_lock``: ``memory.update_result`` serializes on the
    memory layer's own internal lock.
    """
    memory.update_result(decision_id, pnl)


# ═══════════════════════════════════════════════════════════════════════════════
# ОСНОВНАЯ ФУНКЦИЯ АНАЛИЗА — LangChain цепочка + память
# ═══════════════════════════════════════════════════════════════════════════════
# Lot-size bounds taken from RISK_PROMPT's hard rules ("Max volume: 1.0 lot",
# "volume": <0.01–1.0>). The LLM's answer is not trusted, so the bounds are
# enforced again here before anything is sent to MT5.
_MIN_VOLUME = 0.01
_MAX_VOLUME = 1.0


def _as_float(value, default: float) -> float:
    """Coerce an LLM-supplied value to float; return *default* for None/junk/NaN."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if number == number else default  # NaN is the only value != itself


def _as_int(value, default: int) -> int:
    """Coerce an LLM-supplied value to int (accepts 15, 15.0, "15"); *default* on junk."""
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return default


def _as_bool(value) -> bool:
    """Interpret an LLM boolean; a *string* "false" must not count as True."""
    if isinstance(value, str):
        return value.strip().lower() in ("true", "yes", "1")
    return bool(value)


def analyze_with_langchain_memory(symbol: str, prices: List[float],
                                  timeframe: str = "M15",
                                  account: dict = None) -> dict:
    """Run the Signal → News → Risk chain for one request and persist the decision.

    Steps:
      1. compute indicators from *prices*;
      2. load stats and similar past decisions for symbol/timeframe from SQLite;
      3. Signal agent (memory injected into its prompt);
      4. News agent (fails open: an error means LOW risk);
      5. Risk agent (stats injected into its prompt);
      6. write the decision to the database;
      7. return the response dict that is sent to MT5.

    Filter exits (low confidence / not buy-sell, HIGH news risk, risk rejection)
    are recorded as ``hold`` decisions. Failures of the Signal or Risk agent are
    not recorded and return ``decision_id == -1``.

    Args:
        symbol: instrument name as sent by MT5.
        prices: close prices; ``prices[-20:]`` is sent to the LLM as the last
            20 closes (so presumably oldest first).
        timeframe: timeframe label, part of the memory key.
        account: account snapshot forwarded to the Risk agent; a fixed demo
            account with balance 10000 is used when omitted.

    Returns:
        Always ``signal``, ``comment``, ``confidence``, ``decision_id``. For an
        approved buy/sell also ``volume`` (clamped to [0.01, 1.0]), ``sl_pips``,
        ``tp_pips``, ``news_risk``, ``indicators`` and ``chain_log``.
    """
    if not prices:
        return {"signal": "hold", "comment": "no_data",
                "confidence": 0.0, "decision_id": -1}

    if account is None:
        account = {"balance": 10000, "equity": 10000,
                   "free_margin": 10000, "open_positions": 0}

    # 1. Indicators
    ind      = calc_indicators(prices)
    rsi_val  = ind["rsi"]
    utc_now  = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")

    # 2. Context from memory
    agent_stats    = memory.get_stats(symbol, timeframe)
    memory_context = memory.get_context(symbol, timeframe, rsi_val)

    market_data = json.dumps({
        "symbol":    symbol,
        "timeframe": timeframe,
        "close_last_20": prices[-20:],
        "current":   ind["current"],
        "ma5":       ind["ma5"],
        "ma20":      ind["ma20"],
        "ma50":      ind["ma50"],
        "ema20":     ind["ema20"],
        "ema50":     ind["ema50"],
        "rsi":       rsi_val,
        "stoch":     ind["stoch"],
        "atr":       ind["atr"],
        "bb_lo":     ind["bb_lo"],
        "bb_hi":     ind["bb_hi"],
        "zscore":    ind["zscore"],
        "trend":     ind["trend"],
        "volatility":ind["volatility"],
    }, ensure_ascii=False)

    chain_log = []

    # 3. Signal agent
    try:
        t0      = time.time()
        sig_raw = signal_chain.invoke({
            "market_data":    market_data,
            "agent_stats":    agent_stats,
            "memory_context": memory_context,
        })
        signal  = parse_llm(sig_raw.content)
        # Coerce once, so that a string ("0.8") or null confidence no longer
        # aborts the whole request through the f-string formatting below.
        conf    = _as_float(signal.get("confidence"), 0.0)
        chain_log.append(f"signal={signal.get('signal')} conf={conf:.2f} [{round(time.time()-t0,1)}s]")
        log.info(f"[{symbol}] Signal: {signal.get('signal')} conf={conf:.2f}")
    except Exception as e:
        log.error(f"Signal agent error: {e}")
        return {"signal": "hold", "comment": "signal_agent_error",
                "confidence": 0.0, "decision_id": -1}

    sig_name = str(signal.get("signal", "hold")).strip().lower()

    def record(final_signal: str, comment: str, news_level: str) -> int:
        """Persist this request's decision together with the shared indicator snapshot."""
        return safe_write_decision(
            symbol=symbol, timeframe=timeframe,
            signal=final_signal, confidence=conf, comment=comment,
            rsi=rsi_val, ma_align=ind["ma_align"],
            volatility=ind["volatility"], zscore=ind["zscore"],
            news_risk=news_level, chain_log=" | ".join(chain_log),
        )

    # Confidence / direction filter. The threshold only rejects (near-)zero
    # confidence. Anything that is not an explicit buy/sell ("hold", but also
    # "neutral", "flat", typos) becomes a hold instead of being sent to MT5.
    if conf < 0.0001 or sig_name not in ("buy", "sell"):
        decision_id = record("hold", str(signal.get("comment") or "low_confidence"), "LOW")
        return {"signal": "hold",
                "comment": f"low_confidence:{conf:.2f}",
                "confidence": conf,
                "decision_id": decision_id}

    # 4. News agent — fails open: on any error the filter is skipped (LOW).
    news_risk   = "LOW"
    news_reason = ""
    try:
        t0          = time.time()
        news_raw    = news_chain.invoke({"symbol": symbol, "utc_time": utc_now})
        news        = parse_llm(news_raw.content)
        # Case-insensitive, so that "high" blocks too; anything else maps to the
        # prompt's only other allowed value, LOW.
        news_risk   = "HIGH" if str(news.get("news_risk", "LOW")).strip().upper() == "HIGH" else "LOW"
        news_reason = str(news.get("reason") or "")
        chain_log.append(f"news={news_risk} [{round(time.time()-t0,1)}s]")
        log.info(f"[{symbol}] News: {news_risk} — {news_reason}")
    except Exception as e:
        log.warning(f"News agent error: {e} — skipping filter")

    if news_risk == "HIGH":
        decision_id = record("hold", f"news_block: {news_reason}", "HIGH")
        return {"signal": "hold",
                "comment": f"news_risk_HIGH: {news_reason}",
                "confidence": conf,
                "decision_id": decision_id}

    # 5. Risk agent
    try:
        t0       = time.time()
        risk_raw = risk_chain.invoke({
            "signal_data":  json.dumps(signal, ensure_ascii=False),
            "account_data": json.dumps(account, ensure_ascii=False),
            "agent_stats":  agent_stats,
        })
        risk     = parse_llm(risk_raw.content)
        # A JSON *string* "false" is truthy in Python; interpret it explicitly.
        approved = _as_bool(risk.get("approved", False))
        chain_log.append(f"risk={'OK' if approved else 'REJECT'} vol={risk.get('volume',0)} [{round(time.time()-t0,1)}s]")
        log.info(f"[{symbol}] Risk: approved={approved} vol={risk.get('volume',0)}")
    except Exception as e:
        log.error(f"Risk agent error: {e}")
        return {"signal": "hold", "comment": "risk_agent_error",
                "confidence": conf, "decision_id": -1}

    if not approved:
        reason = str(risk.get("reason") or "")
        decision_id = record("hold", f"risk_reject: {reason}", news_risk)
        return {"signal": "hold",
                "comment": f"risk_rejected: {reason}",
                "confidence": conf,
                "decision_id": decision_id}

    # 6. Final signal. Sanitize the risk numbers *before* writing to the DB, so a
    # malformed value can no longer raise after the decision row already exists.
    final_signal = sig_name  # "buy" or "sell", guaranteed by the filter above
    volume  = round(min(max(_as_float(risk.get("volume"), _MIN_VOLUME), _MIN_VOLUME), _MAX_VOLUME), 2)
    sl_pips = _as_int(risk.get("sl_pips"), 20)
    tp_pips = _as_int(risk.get("tp_pips"), 40)
    comment = str(signal.get("comment") or "")
    decision_id = record(final_signal, comment, news_risk)

    result = {
        "signal":      final_signal,
        "comment":     comment,
        "confidence":  conf,
        "decision_id": decision_id,
        "volume":      volume,
        "sl_pips":     sl_pips,
        "tp_pips":     tp_pips,
        "news_risk":   news_risk,
        "indicators":  ind,
        "chain_log":   " | ".join(chain_log),
    }

    log.info(
        f"[{symbol}] → {final_signal.upper()} | "
        f"conf={conf:.2f} | decision_id={decision_id} | "
        f"vol={result['volume']} sl={result['sl_pips']} tp={result['tp_pips']}"
    )
    return result


# ═══════════════════════════════════════════════════════════════════════════════
# WebSocket helpers
# ═══════════════════════════════════════════════════════════════════════════════
def ws_handshake_response(request: str) -> str:
    """Build the RFC 6455 "101 Switching Protocols" reply for an upgrade request.

    ``Sec-WebSocket-Accept`` is base64(SHA-1(key + GUID)). HTTP header names
    are case-insensitive and the space after ":" is optional, so the key is
    found by splitting each line on its first colon. If the header is absent
    an empty key is used (the client will then reject the handshake).
    """
    key = ""
    for line in request.split("\r\n"):
        name, sep, value = line.partition(":")
        if sep and name.strip().lower() == "sec-websocket-key":
            key = value.strip()
            break
    magic  = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    accept = base64.b64encode(
        hashlib.sha1((key + magic).encode()).digest()
    ).decode()
    return (
        "HTTP/1.1 101 Switching Protocols\r\n"
        "Upgrade: websocket\r\n"
        "Connection: Upgrade\r\n"
        f"Sec-WebSocket-Accept: {accept}\r\n\r\n"
    )


def ws_decode(data: bytes):
    """Decode one WebSocket frame from the start of *data*.

    Returns ``(message, consumed)``:
      * ``consumed == 0`` — the frame is still incomplete; wait for more bytes;
      * otherwise the first ``consumed`` bytes (header + mask + payload) form
        this frame, and ``message`` is its payload decoded as UTF-8 (invalid
        bytes replaced).

    Control frames (close / ping / pong, opcode >= 0x8) are consumed but yield
    an empty message, so their binary payload (e.g. a close status code) is
    never mistaken for a text command. Fragmentation (FIN bit / continuation
    frames) is not handled: each data frame is returned on its own.
    """
    if len(data) < 2:
        return "", 0
    opcode = data[0] & 0x0F
    masked = bool(data[1] & 0x80)
    plen   = data[1] & 0x7F
    if plen == 126:
        if len(data) < 4:
            return "", 0
        payload_len = struct.unpack(">H", data[2:4])[0]
        off = 4
    elif plen == 127:
        if len(data) < 10:
            return "", 0
        payload_len = struct.unpack(">Q", data[2:10])[0]
        off = 10
    else:
        payload_len = plen
        off = 2
    if masked:
        off += 4  # the 4-byte masking key sits between the header and the payload
    end = off + payload_len
    if len(data) < end:
        return "", 0
    if opcode >= 0x8:
        return "", end
    payload = data[off:end]
    if masked:
        mask    = data[off - 4: off]
        payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload))
    return payload.decode("utf-8", errors="replace"), end


def ws_encode(message: str) -> bytes:
    """Wrap *message* in a single unmasked, final WebSocket text frame.

    The payload length uses the 7-bit, 16-bit or 64-bit big-endian form,
    depending on its size, as required for server-to-client frames.
    """
    body = message.encode("utf-8")
    size = len(body)
    if size > 65535:
        prefix = bytes([0x81, 127]) + struct.pack(">Q", size)
    elif size > 125:
        prefix = bytes([0x81, 126]) + struct.pack(">H", size)
    else:
        prefix = bytes([0x81, size])
    return prefix + body


# ═══════════════════════════════════════════════════════════════════════════════
# ОБРАБОТЧИК КЛИЕНТА
# ═══════════════════════════════════════════════════════════════════════════════
def _parse_analyze_payload(payload: str):
    """Parse the text that follows "ANALYZE:".

    Accepted shapes: ``SYM:TF:csv`` and ``SYM:csv`` (TF defaults to M15), each
    optionally followed by ``|{account_json}``. The account part is split off
    *before* splitting on ":" because a JSON object itself contains colons.

    Returns ``(symbol, timeframe, prices, account)``, or ``None`` when the
    command is malformed or contains no parseable prices. ``account`` is a
    dict, or ``None`` when it is absent or not a valid JSON object.
    """
    body, sep, acc_part = payload.partition("|")
    parts = body.split(":", 2)
    if len(parts) < 2:
        return None
    symbol = parts[0].strip()
    if len(parts) == 3:
        timeframe, csv_part = parts[1].strip(), parts[2]
    else:
        timeframe, csv_part = "M15", parts[1]

    account = None
    if sep:
        try:
            account = json.loads(acc_part)
        except (ValueError, RecursionError):
            log.warning("ANALYZE: ignoring malformed account JSON")
        if not isinstance(account, dict):
            account = None

    try:
        prices = [float(x) for x in csv_part.split(",") if x.strip()]
    except ValueError:
        return None
    if not prices:
        return None
    return symbol, timeframe, prices, account


def handle_client(conn: socket.socket, addr):
    """Serve one MT5 client until STOP, EOF, a 600 s read timeout or an error.

    The first bytes on the socket are treated as the WebSocket upgrade request
    (answered once the blank line ending the headers has arrived). After that,
    every complete frame in the buffer is decoded and dispatched:

      * ``STOP``                             — close this connection;
      * ``MEM_STATUS:SYM[:TF]``              — stats JSON; at most one reply per
                                               2 s, throttled requests get none;
      * ``RESULT:SYM:decision_id:pnl``       — store a trade outcome, send an ack;
      * ``ANALYZE:SYM[:TF]:csv[|account]``   — run the agent chain;
      * anything else                        — ``{"signal": "error", ...}``.

    The socket is always closed on exit.
    """
    log.info(f"Connected: {addr}")
    is_ws          = False
    buffer         = b""
    send_lock      = threading.Lock()
    last_status_ts = 0.0

    def safe_send(data: bytes):
        """Send raw bytes; errors are logged, not raised (best effort)."""
        with send_lock:
            try:
                conn.sendall(data)
            except Exception as e:
                log.error(f"send error: {e}")

    def send_json(payload: dict):
        """Serialize *payload* and send it as one text frame."""
        safe_send(ws_encode(json.dumps(payload, ensure_ascii=False)))

    try:
        conn.settimeout(600.0)
        while True:
            try:
                chunk = conn.recv(65536)
            except socket.timeout:
                log.info("Connection timeout")
                break
            if not chunk:
                break
            buffer += chunk

            # WebSocket handshake: wait for the blank line ending the HTTP headers.
            if not is_ws:
                header_end = buffer.find(b"\r\n\r\n")
                if header_end < 0:
                    continue
                header_end += 4
                request = buffer[:header_end].decode("utf-8", errors="replace")
                conn.sendall(ws_handshake_response(request).encode())
                is_ws  = True
                # Keep frame bytes that arrived in the same packet as the headers
                # instead of silently dropping them.
                buffer = buffer[header_end:]
                log.info("WebSocket handshake OK")

            # Process every complete frame currently in the buffer.
            while len(buffer) >= 2:
                message, consumed = ws_decode(buffer)
                if consumed == 0:
                    break  # frame not complete yet
                buffer  = buffer[consumed:]
                message = message.strip()
                if not message:
                    continue
                log.info(f"Recv: {message[:120]}")

                cmd = message.upper()

                # STOP
                if cmd == "STOP":
                    return  # the finally clause closes the socket

                # MEM_STATUS:SYMBOL:TF
                if cmd.startswith("MEM_STATUS:"):
                    now = time.time()
                    if now - last_status_ts < 2.0:
                        log.info("MEM_STATUS throttled, skipping")
                        continue
                    last_status_ts = now
                    parts  = message[11:].split(":")
                    symbol = parts[0].strip()
                    tf     = parts[1].strip() if len(parts) > 1 else "M15"
                    try:
                        resp = memory.full_status(symbol, tf)
                        data = json.dumps(resp, ensure_ascii=False)
                        log.info(f"MEM_STATUS sending {len(data)} bytes")
                        safe_send(ws_encode(data))
                        log.info("MEM_STATUS sent OK")
                    except Exception as e:
                        log.error(f"MEM_STATUS error: {e}")
                    continue

                # RESULT:SYM:decision_id:pnl
                if cmd.startswith("RESULT:"):
                    parts = message[7:].split(":")
                    if len(parts) >= 3:
                        try:
                            decision_id = int(parts[1])
                            pnl         = float(parts[2])
                            safe_update_result(decision_id, pnl)
                            # Same rule as the memory layer: breakeven is a loss.
                            outcome = "win" if pnl > 0 else "loss"
                            log.info(
                                f"[RESULT] decision_id={decision_id} "
                                f"pnl={pnl:+.3f} "
                                f"outcome={outcome}"
                            )
                            resp = {
                                "signal":      "result_ack",
                                "decision_id": decision_id,
                                "pnl":         pnl,
                                "outcome":     outcome,
                            }
                        except Exception as e:
                            log.error(f"RESULT error: {e}")
                            resp = {"signal": "error", "comment": str(e)}
                    else:
                        resp = {"signal": "error",
                                "comment": "Format: RESULT:SYM:decision_id:pnl"}
                    send_json(resp)
                    continue

                # ANALYZE:SYM[:TF]:csv[|account_json]
                if cmd.startswith("ANALYZE:"):
                    parsed = _parse_analyze_payload(message[8:])
                    if parsed is None:
                        send_json({"signal": "hold", "comment": "parse_error",
                                   "confidence": 0.0, "decision_id": -1})
                        continue
                    sym, tf, prices, account = parsed
                    # A failure inside the chain must not tear down the whole
                    # connection: answer with a hold and keep serving.
                    try:
                        result = analyze_with_langchain_memory(sym, prices, tf, account)
                    except Exception as e:
                        log.exception(f"[{sym}] analysis failed: {e}")
                        result = {"signal": "hold", "comment": "analysis_error",
                                  "confidence": 0.0, "decision_id": -1}
                    send_json(result)
                    continue

                # Unknown
                send_json({"signal": "error",
                           "comment": f"Unknown command: {message[:60]}"})

    except Exception as e:
        log.error(f"Handler error: {e}")
    finally:
        conn.close()
        log.info(f"Disconnected: {addr}")


# ─── Main ──────────────────────────────────────────────────────────────────────
def main():
    """Print the startup banner, then accept MT5 clients until Ctrl+C.

    Each accepted connection is served by ``handle_client`` in its own daemon
    thread; the listening socket is closed when the accept loop ends.
    """
    rule = "═" * 64
    banner_lines = [
        "  LangChain Memory Agent Server  v1.0",
        f"  Address  : ws://{HOST}:{PORT}",
        f"  Model    : {MODEL}  (LangChain / xAI)",
        f"  Database : {DB_PATH}",
        "",
        "  Commands:",
        "    ANALYZE:SYM:TF:csv                  — get signal (chain + memory)",
        "    ANALYZE:SYM:TF:csv|{account_json}   — same + account data",
        "    RESULT:SYM:decision_id:pnl           — report trade result",
        "    MEM_STATUS:SYM:TF                    — agent statistics",
        "    STOP                                 — disconnect",
    ]
    print(rule)
    for banner_line in banner_lines:
        print(banner_line)
    print(rule)

    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind((HOST, PORT))
    listener.listen(5)
    log.info("LangChain Memory Agent started. Waiting for MT5...")

    try:
        while True:
            client_sock, client_addr = listener.accept()
            worker = threading.Thread(
                target=handle_client,
                args=(client_sock, client_addr),
                daemon=True,
            )
            worker.start()
    except KeyboardInterrupt:
        log.info("Server stopped.")
    finally:
        listener.close()


if __name__ == "__main__":
    main()
