import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

# Compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Hyperparameters ---
# Problem / network size
STATE_DIM = 1     # features per state (1-D state)
ACTION_DIM = 1    # components per action (1-D action)
HIDDEN_DIM = 128  # width of every hidden layer

# Optimization
LR_ACTOR = 1e-4   # Adam learning rate for the policy network
LR_CRITIC = 1e-3  # Adam learning rate for the Q-network
GAMMA = 0.99      # discount factor applied to the bootstrapped target
TAU = 1e-3        # soft (Polyak) target-network update rate

# Experience replay
BUFFER_SIZE = 100_000
BATCH_SIZE = 128

# Replay Buffer
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Backed by ``deque(maxlen=capacity)``: once full, each push evicts the
    oldest transition.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one ``(state, action, reward, next_state, done)`` transition.

        Fields may be NumPy arrays, Python scalars, or torch tensors; they are
        stored as-is and only converted when a batch is sampled.
        """
        self.buffer.append((state, action, reward, next_state, done))

    @staticmethod
    def _stack(items):
        """Stack one transition field across a batch into a float32 tensor on ``device``."""
        # Tensors are detached and moved to host memory so they can be mixed
        # with NumPy arrays / scalars in a single np.array call.
        host = np.array([x.detach().cpu().numpy() if torch.is_tensor(x) else x for x in items])
        return torch.as_tensor(host, dtype=torch.float32, device=device)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of float32
            tensors on ``device``. Each has the batch on dim 0, followed by the
            shape of the stored per-transition values.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        # zip(*batch) regroups the list of transitions into one tuple per field.
        states, actions, rewards, next_states, dones = (self._stack(field) for field in zip(*batch))
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# Critic Network (Q-function)
class Critic(nn.Module):
    """State-action value network Q(s, a).

    ``state`` and ``action`` are concatenated along dim 1 and passed through
    two ReLU hidden layers of width ``hidden_dim``; the head emits one value
    per row.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        # Layer attribute names (and creation order) determine state_dict keys
        # and parameter initialization order, so they are kept stable.
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        self.relu = nn.ReLU()

    def forward(self, state, action):
        """Return Q-values of shape ``(batch, 1)`` for 2-D state/action batches."""
        features = torch.cat([state, action], dim=1)
        for layer in (self.fc1, self.fc2):
            features = self.relu(layer(features))
        return self.fc3(features)

# Actor Network (Policy)
class Actor(nn.Module):
    """Deterministic policy network mapping states to actions in [-1, 1].

    Two ReLU hidden layers of width ``hidden_dim`` feed a linear head of size
    ``action_dim`` whose output is squashed by tanh.
    """

    def __init__(self, state_dim, action_dim, hidden_dim):
        super().__init__()
        # Keep attribute names/order stable: they define state_dict keys.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, state):
        """Return one action per input row; tanh bounds every component to [-1, 1]."""
        hidden = self.relu(self.fc2(self.relu(self.fc1(state))))
        return self.tanh(self.fc3(hidden))

# DDPG Agent
class DDPGAgent:
    """Deep Deterministic Policy Gradient agent.

    Owns an actor/critic pair, target copies of both that track the online
    networks via soft updates with rate ``TAU``, one Adam optimizer per online
    network, and a replay buffer of capacity ``BUFFER_SIZE``.
    """

    def __init__(self, state_dim, action_dim):
        self.actor = Actor(state_dim, action_dim, HIDDEN_DIM).to(device)
        self.actor_target = Actor(state_dim, action_dim, HIDDEN_DIM).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=LR_ACTOR)

        self.critic = Critic(state_dim, action_dim, HIDDEN_DIM).to(device)
        self.critic_target = Critic(state_dim, action_dim, HIDDEN_DIM).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=LR_CRITIC)

        self.replay_buffer = ReplayBuffer(BUFFER_SIZE)
        self.action_dim = action_dim

    def select_action(self, state, noise_scale=0.1):
        """Return an exploratory action for a single (unbatched) ``state``.

        Adds Gaussian noise with standard deviation ``noise_scale`` to the
        actor's output and clips the result to [-1, 1].

        Returns:
            1-D NumPy array of length ``action_dim``.
        """
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        # Pure inference: skip autograd so no graph is built on every env step.
        with torch.no_grad():
            action = self.actor(state).cpu().numpy().flatten()
        action += noise_scale * np.random.randn(self.action_dim)
        return np.clip(action, -1, 1)

    def update(self):
        """Run one critic step, one actor step and a soft target update.

        Does nothing until the replay buffer holds at least ``BATCH_SIZE``
        transitions. The critic regresses onto the TD target
        ``r + GAMMA * (1 - done) * Q_target(s', actor_target(s'))``; the actor
        minimizes ``-mean(Q(s, actor(s)))``.
        """
        if len(self.replay_buffer) < BATCH_SIZE:
            return

        states, actions, rewards, next_states, dones = self.replay_buffer.sample(BATCH_SIZE)

        # Flatten per-sample structure to [B, features] *before* any network
        # sees it, so online and target actor/critic all get the same layout
        # (previously next_states reached actor_target unflattened).
        states = states.reshape(states.shape[0], -1)
        next_states = next_states.reshape(next_states.shape[0], -1)
        # Scalar actions would otherwise be [B] and break torch.cat in Critic.
        actions = actions.reshape(actions.shape[0], -1)
        # Rewards/dones may arrive as [B] (scalars) or [B, k] (arrays);
        # normalize to [B, 1]. Only the first reward column is used.
        rewards = rewards.unsqueeze(1) if rewards.dim() == 1 else rewards
        rewards = rewards[:, 0:1]
        dones = dones.unsqueeze(1) if dones.dim() == 1 else dones

        # --- Critic update ---
        # The TD target is a fixed regression label: compute it without
        # autograd instead of building and then detaching a graph through
        # both target networks.
        with torch.no_grad():
            next_actions = self.actor_target(next_states)
            if next_actions.dim() > 2:
                next_actions = next_actions.squeeze(1)
            target_q = self.critic_target(next_states, next_actions)
            target_q = target_q.unsqueeze(1) if target_q.dim() == 1 else target_q
            target_q = rewards + (1 - dones) * GAMMA * target_q

        current_q = self.critic(states, actions)

        # Internal invariant check (disabled under `python -O`).
        assert current_q.shape == target_q.shape, \
            f"Shape mismatch: {current_q.shape} vs {target_q.shape}"

        critic_loss = nn.MSELoss()(current_q, target_q)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # --- Actor update ---
        # This backward also accumulates gradients in the critic's parameters;
        # they are cleared by critic_optimizer.zero_grad() before the next
        # critic step, so they never reach the critic update.
        actor_loss = -self.critic(states, self.actor(states)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # --- Target networks ---
        self._soft_update(self.actor_target, self.actor)
        self._soft_update(self.critic_target, self.critic)

    @staticmethod
    def _soft_update(target_net, online_net):
        """In place: target <- TAU * online + (1 - TAU) * target, parameter-wise."""
        with torch.no_grad():
            for target, param in zip(target_net.parameters(), online_net.parameters()):
                target.copy_(TAU * param + (1 - TAU) * target)

    def save(self, filename):
        """Save the state_dicts of all four networks to ``filename``.

        Optimizer state is not included.
        """
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'actor_target': self.actor_target.state_dict(),
            'critic_target': self.critic_target.state_dict(),
        }, filename)

    def load(self, filename):
        """Restore all four networks from a checkpoint written by :meth:`save`.

        Tensors are mapped onto the current ``device``, so a checkpoint saved
        on a GPU can be loaded on a CPU-only machine (and vice versa).
        """
        checkpoint = torch.load(filename, map_location=device)
        self.actor.load_state_dict(checkpoint['actor'])
        self.critic.load_state_dict(checkpoint['critic'])
        self.actor_target.load_state_dict(checkpoint['actor_target'])
        self.critic_target.load_state_dict(checkpoint['critic_target'])

# Training loop with random data
if __name__ == "__main__":
    # Smoke run: feed the agent random states and rewards (no real environment).
    agent = DDPGAgent(STATE_DIM, ACTION_DIM)
    episodes = 10
    max_steps = 200

    for episode in range(episodes):
        state = np.random.randn(STATE_DIM)
        episode_reward = 0

        for step in range(max_steps):
            action = agent.select_action(state)

            # Fake environment step: independent Gaussian next state and reward.
            next_state = np.random.randn(STATE_DIM)
            reward = np.random.randn(1)  # shape (1,); sampled batches stack it to [B, 1]
            is_last = step == max_steps - 1
            done = int(is_last)  # episode terminates on its final step

            agent.replay_buffer.push(state, action, reward, next_state, done)
            agent.update()  # no-op until the buffer holds BATCH_SIZE transitions

            episode_reward += reward[0]
            state = next_state

            if is_last:
                break

        print(f"Episode: {episode+1}, Reward: {episode_reward:.2f}")
