import time
from copy import deepcopy
import numpy as np
import torch
import os
import math
import signal

from mcts.node import Node
from mcts.config import N_SCORE
from mcts.utils import dist_v_to_exp_v, score_to_idx
from nn.feature import generate_input_planes
from nn.utility import get_torch_device, load_network
from nn.network.dual_net import DualNet
from other.other import generate_move_from_policy, index_to_shot
from board.constant import BOARD_SIZE
from dc3client.models import StoneRotation

current_dir = os.path.dirname(__file__)
# Join path components instead of hard-coding Windows separators
# ("..\\model\\sl-model.bin") so the path also resolves on POSIX systems.
model_path = os.path.join(current_dir, "..", "model", "sl-model.bin")

class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when an environment step runs too long."""

def timeout_handler(signum, frame):
    """Signal-handler callback that aborts the current work with TimeoutException.

    Args:
        signum: Number of the delivered signal (ignored).
        frame: Interrupted stack frame (ignored).

    Raises:
        TimeoutException: Always.

    NOTE(review): the "20 seconds" in the message is fixed text; the actual
    alarm interval is configured wherever this handler is registered, which is
    not visible in this file.
    """
    message = "Timeout: step function exceeded 20 seconds"
    raise TimeoutException(message)

class Player:
    """Root-level Sequential Halving searcher for a single shot decision.

    ``think`` samples ``num_init`` candidate shots from the policy network,
    evaluates each candidate with repeated greedy policy rollouts
    (``play_simulation`` -> ``predict``), prunes the candidates with Sequential
    Halving and returns the shot with the highest mean value.
    """

    def __init__(self, ucb_const, pw_const, init_temperature, out_temperature, num_init, num_sample, l, max_depth, stone_simulator):
        """Load the network and store the search hyper-parameters.

        Args:
            ucb_const: UCB exploration constant (not used by the current search).
            pw_const: Presumably a progressive-widening constant (not used by
                the current search).
            init_temperature: Temperature applied to the policy when sampling
                root candidates in ``prepare_init_actions``.
            out_temperature: Temperature used by ``old_sample_best_action``.
            num_init: Number of root candidates; overwritten by ``reset``.
            num_sample: Stored only; not used by the current search.
            l: Controls the area of the sample space; not used by the current search.
            max_depth: Stored only; not used by the current search.
            stone_simulator: Simulator object passed to every ``env.step`` call.
        """
        self.root_node = Node(0.0, 0.0, 0.0, 0.0)
        self.root_env = None      # environment at the root; set by reset()
        self.env_dict = None      # node_key(node) -> environment cache; set by reset()
        self.num_node = None
        self.num_playout = None

        self.ucb_const = ucb_const
        self.pw_const = pw_const
        self.num_init = num_init
        self.device = get_torch_device(use_gpu=False)
        self.network = load_network(model_path, use_gpu=False)
        self.network.to(self.device)

        self.init_temperature = init_temperature
        self.out_temperature = out_temperature
        self.num_sample = num_sample
        self.l = l  # control the area of sample space
        self.max_depth = max_depth

        self.stone_simulator = stone_simulator

        # Debug counter printed by think(); nothing in this class increments it.
        self.kari = 0

    def reset(self, root_env, num_playout, num_init):
        """Start a fresh search tree rooted at ``root_env``.

        The root environment is deep-copied into ``env_dict`` under the empty
        key (the key ``node_key`` produces for the root).
        """
        self.root_node = Node(0.0, 0.0, 0.0, 0.0)
        self.root_env = root_env
        self.env_dict = {"": deepcopy(self.root_env)}
        self.num_node = 0
        self.num_playout = num_playout
        self.num_init = num_init

    def think(self, root_env, num_playout, num_init):
        """Search from ``root_env`` and return the chosen shot and its value.

        Args:
            root_env: Environment of the position to play from.
            num_playout: Total simulation budget for Sequential Halving.
            num_init: Number of root candidate shots sampled from the policy.

        Returns:
            ``(move, avg_value)`` as returned by
            ``sample_best_action_sequential_halving``.
        """
        self.reset(root_env, num_playout, num_init)

        # 1. Query the policy at the root. Only the policy is needed here, so the
        #    greedy rollout that predict() would run (and whose value was
        #    discarded) is skipped.
        root_prediction_p = self._policy(self.root_env)
        self.prepare_init_actions(self.root_node, root_prediction_p)

        # 2. Create one root child per sampled initial action.
        candidates = []
        for init_action_id, init_action_prob in self.root_node.m_init_infos:
            move = index_to_shot(init_action_id)
            candidates.append(self.root_node.add_node(move[0], move[1], move[2], init_action_prob))
        # Register exactly these candidates as the root's children.
        self.root_node.m_children = candidates
        print(len(self.root_node.m_children))

        # 3. Spend the simulation budget on the candidates via Sequential Halving.
        self.sequential_halving_search(num_playout)
        print("kari: ", self.kari)
        return self.sample_best_action_sequential_halving()

    def sequential_halving_search(self, total_budget):
        """Run Sequential Halving over the root's children.

        Each round gives every surviving candidate an equal share of the
        remaining budget (at least one simulation each), then keeps the better
        half (rounded up) ranked by mean value. ``self.root_node.m_children`` is
        replaced by the survivors after every round.

        Args:
            total_budget: Number of simulations to spend. The floor of one
                simulation per candidate per round can push the actual count
                slightly above this.
        """
        candidates = self.root_node.m_children
        num_candidates = len(candidates)
        if num_candidates == 0:
            print("num_candidates == 0")
            return

        # ceil(log2(n)) halvings reduce n candidates to one. The previous
        # floor(log2(n)) + 1 added a useless final round, spent on a single
        # survivor, whenever n was a power of two. Use at least one round so a
        # lone candidate still gets evaluated.
        rounds = max(1, math.ceil(math.log2(num_candidates)))
        remaining_budget = total_budget

        print("Starting Sequential Halving with {} candidates and {} total simulations.".format(num_candidates, total_budget))

        for round_idx in range(rounds):
            # Never empty: the list only shrinks to ceil(n / 2) >= 1.
            n_candidates = len(candidates)
            # Split the remaining budget evenly over the remaining rounds and
            # candidates, with at least one simulation per candidate.
            simulations_per_candidate = max(1, remaining_budget // (n_candidates * (rounds - round_idx)))
            print("Round {}: {} candidates, {} simulations per candidate.".format(round_idx + 1, n_candidates, simulations_per_candidate))
            for candidate in candidates:
                for _ in range(simulations_per_candidate):
                    candidate_env = self.get_env(candidate)
                    if candidate_env is None:  # step timed out; skip this sample
                        continue
                    # Simulate on a copy so the cached env in env_dict stays intact;
                    # play_simulation updates the candidate's statistics.
                    self.play_simulation(1, deepcopy(candidate_env), candidate)
            remaining_budget = max(0, remaining_budget - n_candidates * simulations_per_candidate)

            # Keep the better half (rounded up) by mean value.
            candidates.sort(key=self._mean_value, reverse=True)
            candidates = candidates[:math.ceil(n_candidates / 2)]
            print("After round {}: {} candidates remain.".format(round_idx + 1, len(candidates)))
            self.root_node.m_children = candidates
            if remaining_budget == 0:
                break

    @staticmethod
    def _mean_value(node):
        """Return ``node.m_v / node.m_visits``, or -inf for an unvisited node."""
        return node.m_v / node.m_visits if node.m_visits > 0 else float('-inf')

    def sample_best_action_sequential_halving(self):
        """Return the surviving root child with the highest mean value.

        Ties keep the earliest child. If no child has been visited, the first
        child is returned with value -inf (previously this crashed).

        Returns:
            ``(move, avg_value)`` of the chosen child.

        Raises:
            RuntimeError: If the root has no children at all.
        """
        children = self.root_node.m_children
        if not children:
            raise RuntimeError("sample_best_action_sequential_halving: root has no candidate moves")
        best_candidate = max(children, key=self._mean_value)
        best_avg = self._mean_value(best_candidate)
        print("Best candidate move:", best_candidate.m_move, "Avg value:", best_avg, "visit: ", best_candidate.m_visits)
        return best_candidate.m_move, best_avg

    def predict(self, cur_env):
        """Return the policy at ``cur_env`` and the value of a greedy rollout.

        From ``cur_env``, the argmax policy shot is played repeatedly until the
        end number differs from the root's. The terminal score for the root's
        end is turned into a one-hot distribution over ``N_SCORE`` bins and
        reduced with ``dist_v_to_exp_v``.

        Returns:
            ``(prediction_p, leaf_black_eval)``. ``prediction_p`` is the
            flattened policy (``BOARD_SIZE * BOARD_SIZE * 2`` entries) at
            ``cur_env``, or None if ``cur_env`` is already past the root's end.
        """
        root_end = self.root_env.game_state["end"]
        root_policy = None
        env = cur_env
        while env.game_state["end"] == root_end:
            policy = self._policy(env)
            if root_policy is None:
                root_policy = policy
            x, y, rotation = index_to_shot(np.argmax(policy))
            # NOTE(review): assumes step() returns the successor env without
            # corrupting `env` -- confirm against the env implementation.
            env = env.step(x, y, rotation, self.stone_simulator)
        leaf_dist_v = [0] * N_SCORE
        leaf_dist_v[score_to_idx(env.game_state["score"][root_end])] = 1
        return root_policy, dist_v_to_exp_v(leaf_dist_v)

    def _policy(self, env):
        """Run the network on ``env`` and return its flattened policy output."""
        input_data = generate_input_planes(env.game_state["stones"], env.game_state["num_shot"])
        # 29 input planes, as expected by the network (shape fixed by this reshape).
        input_plane = torch.tensor(input_data.reshape(1, 29, BOARD_SIZE, BOARD_SIZE))
        with torch.no_grad():  # inference only: don't build an autograd graph
            prediction_p, _ = self.network.forward_with_softmax2(input_plane)
        return prediction_p.reshape(BOARD_SIZE * BOARD_SIZE * 2)

    @staticmethod
    def _to_numpy(distribution):
        """Return a fresh float64 numpy copy of a tensor or array-like."""
        if isinstance(distribution, torch.Tensor):
            return distribution.detach().cpu().numpy().astype(np.float64)
        return np.array(distribution, dtype=np.float64)

    def prepare_init_actions(self, node, prediction_p):
        """Sample distinct initial actions from the policy into ``node.m_init_infos``.

        Each entry appended is ``(action_id, prob)``. ``prob`` is the
        renormalised probability times the product of ``(1 - p)`` of the
        previously drawn actions, which approximates the action's probability
        under the original policy.

        Args:
            node: Node whose ``m_init_infos`` list is extended.
            prediction_p: 1-D policy (tensor or array). It is not modified.
        """
        probs = self._to_numpy(prediction_p)  # private copy: caller's tensor untouched
        chosen = np.zeros(probs.shape[0], dtype=bool)
        factor = 1.0
        prev_init_action_prob = 0.0
        for _ in range(min(self.num_init, probs.shape[0])):
            init_action_id = np.random.choice(probs.shape[0], p=self.apply_temperature(probs, self.init_temperature))
            init_action_prob = float(probs[init_action_id])
            factor *= 1.0 - prev_init_action_prob
            prev_init_action_prob = init_action_prob
            node.m_init_infos.append((init_action_id, init_action_prob * factor))
            chosen[init_action_id] = True
            if chosen.all():
                break
            # Smooth so every unchosen action stays reachable, but keep chosen
            # actions at exactly zero so they cannot be drawn again.
            probs = probs + 1.0E-6
            probs[chosen] = 0.0
            probs = probs / probs.sum()

    def apply_temperature(self, distribution, temperature):
        """Return ``distribution`` sharpened/flattened by ``temperature`` as numpy.

        For ``temperature < 0.1`` the result is one-hot on the argmax. Otherwise
        it is ``p ** (1 / temperature)`` renormalised, computed in log space.
        Accepts torch tensors and numpy arrays alike; zero entries stay zero.
        """
        dist = self._to_numpy(distribution)
        if temperature < 0.1:
            probabilities = np.zeros(dist.shape[0])
            probabilities[np.argmax(dist)] = 1.
            return probabilities
        with np.errstate(divide="ignore"):  # log(0) -> -inf -> probability 0
            log_probabilities = np.log(dist) * (1 / temperature)
        log_probabilities = log_probabilities - log_probabilities.max()
        probabilities = np.exp(log_probabilities)
        return probabilities / probabilities.sum()

    def play_simulation(self, cur_depth, cur_env, cur_node):
        """Evaluate ``cur_env`` once with ``predict`` and back the value into ``cur_node``.

        ``cur_depth`` is accepted for interface compatibility but unused. The
        previous tree-expansion code sat in an unreachable string literal after
        both (identical) return paths and has been removed.

        Returns:
            The rollout value passed to ``cur_node.kr_update``.
        """
        _, leaf_black_eval = self.predict(cur_env)
        cur_node.kr_update(leaf_black_eval)
        return leaf_black_eval

    def get_env(self, cur_node):
        """Step the parent's cached env by ``cur_node``'s move, cache and return it.

        The result overwrites any cached env for ``cur_node``. Returns None if
        the step raised ``TimeoutException``.
        NOTE(review): every call re-runs the step; if the simulator is
        stochastic each call yields a new sample -- confirm this is intended.
        """
        prev_env = self.env_dict[self.node_key(cur_node.m_parent)]
        cur_env = deepcopy(prev_env)
        try:
            cur_env = cur_env.step(cur_node.m_move[0], cur_node.m_move[1], cur_node.m_move[2], self.stone_simulator)
        except TimeoutException as e:
            print(e)
            return None
        self.env_dict[self.node_key(cur_node)] = cur_env
        return cur_env

    def get_env_15(self, cur_node):
        """Like ``get_env`` but without caching the result or catching timeouts."""
        prev_env = self.env_dict[self.node_key(cur_node.m_parent)]
        cur_env = deepcopy(prev_env)
        return cur_env.step(cur_node.m_move[0], cur_node.m_move[1], cur_node.m_move[2], self.stone_simulator)

    def get_env_change(self, cur_env, cur_node):
        """Use ``get_env_15`` when ``cur_env`` is at shot 15, else ``get_env``."""
        if cur_env.game_state["num_shot"] == 15:
            return self.get_env_15(cur_node)
        print("else")
        return self.get_env(cur_node)

    def node_key(self, cur_node):
        """Return the ``env_dict`` key for ``cur_node``: its moves from root down.

        Each move is rendered as ``"(x y rotation)->"`` with x and y to four
        decimals; the root's key is the empty string.
        """
        parts = []
        while cur_node.m_parent is not None:
            parts.append(f"({cur_node.m_move[0]:.4f} {cur_node.m_move[1]:.4f} {cur_node.m_move[2]})->")
            cur_node = cur_node.m_parent
        return "".join(reversed(parts))

    # Old implementation (kept for reference)
    def old_sample_best_action(self):
        """Sample a root child in proportion to visits (tempered by out_temperature)."""
        children = self.root_node.m_children
        kns = np.array([c.m_visits for c in children])
        kns = kns / kns.sum()
        best_id = np.random.choice(len(kns), p=self.apply_temperature(kns, self.out_temperature))
        best = children[best_id]
        value = best.m_v / best.m_visits
        print("v: ", best.m_v, "visit: ", best.m_visits, "value: ", value)
        total_visits = sum(child.m_visits for child in children)
        print("Total visits of all child nodes:", total_visits)
        return best.m_move, value

    def print_winrate(self):
        """Print the visited root child with the highest mean value above -10.

        Unvisited children are skipped (they used to raise ZeroDivisionError).
        If no child qualifies, None is printed instead of raising.
        """
        best_val = -10
        best_child = None
        for child in self.root_node.m_children:
            if child.m_visits == 0:
                continue
            mean = child.m_v / child.m_visits
            if best_val < mean:
                best_val = mean
                print(mean, "aaa: ", child.m_visits)
                best_child = child
        print(best_val)
        print(best_child)
        if best_child is not None:
            print("winrate: ", best_child.m_move)

    def sample_best_action(self):
        """Return ``(move, mean)`` of the visited root child with the best mean above -10.

        Unvisited children are skipped.

        Raises:
            RuntimeError: If no visited child has a mean value above -10.
        """
        best_val = -10
        best_child = None
        for child in self.root_node.m_children:
            if child.m_visits == 0:
                continue
            mean = child.m_v / child.m_visits
            if best_val < mean:
                best_val = mean
                best_child = child
        if best_child is None:
            raise RuntimeError("sample_best_action: no visited child with mean value above -10")
        print("v: ", best_child.m_v, "visit: ", best_child.m_visits, "value: ", best_val)
        return best_child.m_move, best_val
