#Dual Networkの実装

from typing import Tuple
from torch import nn
import torch

from my_program_test.board.constant import BOARD_SIZE
from my_program_test.nn.network.res_block import ResidualBlock
from my_program_test.nn.network.head.policy_head import PolicyHead
from my_program_test.nn.network.head.value_head import ValueHead

class DualNet(nn.Module):
    def __init__(self, device: torch.device, board_size: int=BOARD_SIZE):
        #Dual Networkの実装クラス

        super().__init__()
        filters = 32  
        blocks = 9    

        self.device = device

        self.conv_layer = nn.Conv2d(in_channels=29, out_channels=filters, \
            kernel_size=3, padding=1, bias=False)
        self.bn_layer = nn.BatchNorm2d(num_features=filters)
        self.relu = nn.ReLU()
        self.blocks = make_common_blocks(blocks, filters)
        self.policy_head = PolicyHead(board_size, filters)
        self.value_head = ValueHead(board_size, filters)

        self.softmax = nn.Softmax(dim=1)
        self.softmax2 = nn.Softmax(dim=2)

    def forward(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        #前向き伝播処理を実行する。
        blocks_out = self.blocks(self.relu(self.bn_layer(self.conv_layer(input_plane))))

        return self.policy_head(blocks_out), self.value_head(blocks_out)

    def forward_for_sl(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        #前向き伝搬処理を実行する。教師有り学習で利用する。
        policy, value = self.forward(input_plane)
        return self.softmax(policy), value


    def forward_with_softmax(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        #前向き伝搬処理を実行する。
        policy, value = self.forward(input_plane)
        policy = policy.view(256, 2048)
        return self.softmax(policy), self.softmax(value)

    def forward_with_softmax2(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        #前向き伝搬処理を実行する。
        policy, value = self.forward(input_plane)
        return self.softmax(policy), self.softmax(value)

    def inference(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        #前向き伝搬処理を実行する。探索用に使うメソッドのため、デバイス間データ転送も内部処理する。
        policy, value = self.forward(input_plane.to(self.device))
        return self.softmax(policy).cpu(), self.softmax(value).cpu()


    def inference_with_policy_logits(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        #前向き伝搬処理を実行する。Gumbel AlphaZero用の探索に使うメソッドのため、
        #デバイス間データ転送も内部処理する。
        policy, value = self.forward(input_plane.to(self.device))
        return policy.cpu(), self.softmax(value).cpu()

    '''def forward_original(self, input_plane: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        #自分で作った前向き伝播処理
        policy, value = self.forward(input_plane) '''


def make_common_blocks(num_blocks: int, num_filters: int) -> torch.nn.Sequential:
    #DualNetで用いる残差ブロックの塊を構成して返す。
    
    blocks = [ResidualBlock(num_filters) for _ in range(num_blocks)]
    return nn.Sequential(*blocks)