"""学習データの生成処理。
"""
import glob
import os
import random
import json
from typing import List, NoReturn
import numpy as np
from nn.feature import generate_input_planes, generate_target_data, generate_value_data
from learning_param import BATCH_SIZE, DATA_SET_SIZE


def _save_data(save_file_path: str, input_data: np.ndarray, policy_data: np.ndarray,
               value_data: np.ndarray, log_counter: int) -> None:
    """Write one training data set to a compressed ``.npz`` file.

    Only the first ``DATA_SET_SIZE`` samples of each of ``input_data``,
    ``policy_data`` and ``value_data`` are written; anything beyond that is
    not saved by this call.

    Args:
        save_file_path (str): Destination path. ``np.savez_compressed`` appends
            ``.npz`` when the path does not already end with it.
        input_data (np.ndarray): Network input data, one entry per sample.
            Anything ``np.array`` accepts (e.g. a list of per-sample arrays)
            works, since it is converted here.
        policy_data (np.ndarray): Policy targets, one entry per sample.
        value_data (np.ndarray): Value targets, one entry per sample; stored
            with dtype ``int32``.
        log_counter (int): Number of game-log records counted by the caller
            for this data set; stored as the scalar ``"log_count"`` entry.
    """
    save_data = {
        "input": np.array(input_data[:DATA_SET_SIZE]),
        "policy": np.array(policy_data[:DATA_SET_SIZE]),
        "value": np.array(value_data[:DATA_SET_SIZE], dtype=np.int32),
        "log_count": np.array(log_counter),
    }
    np.savez_compressed(save_file_path, **save_data)

def generate_supervised_learning_data(program_dir: str, log_dir: str) -> None:
    """Build supervised-learning data sets from game logs and save them as npz files.

    Every sub-directory of ``log_dir`` is treated as one game. The JSON on the
    second-to-last line of its ``game.dcl2`` file is passed to
    ``generate_value_data``. Each ``*.json`` file in the directory (visited in
    sorted order) yields at most one sample, and only when its ``log.end`` is
    at most 9.

    Each time ``DATA_SET_SIZE`` samples have accumulated they are written to
    ``<program_dir>/data/sl_data_<n>.npz``. Samples left over at the end are
    written only in whole multiples of ``BATCH_SIZE``; the rest are dropped.

    Args:
        program_dir (str): Directory whose ``data`` sub-directory receives the
            output files.
        log_dir (str): Directory containing one sub-directory per game log.
    """
    input_data = []
    policy_data = []
    value_data = []

    # Number of JSON log files read since the last data set was written.
    log_counter = 0
    data_counter = 0

    # Sort so that the sample order (and therefore each output file's
    # contents) does not depend on the filesystem's directory-listing order.
    for game_name in sorted(os.listdir(log_dir)):
        game_dir = os.path.join(log_dir, game_name)
        if not os.path.isdir(game_dir):
            continue

        with open(os.path.join(game_dir, "game.dcl2"), encoding="utf-8") as dcl_file:
            dcl2_lines = dcl_file.readlines()
        # NOTE(review): assumes the second-to-last line is the record that
        # generate_value_data needs (presumably the final game result) - confirm.
        dcl_json_data = json.loads(dcl2_lines[-2])

        for log_path in sorted(glob.glob(os.path.join(game_dir, "*.json"))):
            with open(log_path, "r", encoding="utf-8") as log_file:
                log = json.load(log_file)["log"]

            # NOTE(review): records with end > 9 are skipped; presumably these
            # are extra ends - confirm the intended cut-off.
            if log["end"] <= 9:
                input_data.append(generate_input_planes(log["simulator_storage"]["stones"], log["shot"]))
                policy_data.append(generate_target_data(log["selected_move"]))
                value_data.append(generate_value_data(dcl_json_data, log["end"], log["shot"]))
            log_counter += 1

            # Each file adds at most one sample, so a single check per file
            # is enough to flush every full data set.
            if len(value_data) >= DATA_SET_SIZE:
                print(f"sl_data{data_counter}")
                _save_data(os.path.join(program_dir, "data", f"sl_data_{data_counter}"),
                           input_data, policy_data, value_data, log_counter)
                input_data = input_data[DATA_SET_SIZE:]
                policy_data = policy_data[DATA_SET_SIZE:]
                value_data = value_data[DATA_SET_SIZE:]
                # The file that triggered this save was already counted in the
                # set just written, so the next set starts from zero.
                log_counter = 0
                data_counter += 1

    # Write the remainder, truncated to a whole number of batches.
    n_usable = len(value_data) // BATCH_SIZE * BATCH_SIZE
    if n_usable > 0:
        _save_data(os.path.join(program_dir, "data", f"sl_data_{data_counter}"),
                   input_data[:n_usable], policy_data[:n_usable],
                   value_data[:n_usable], log_counter)


        
