import socket
import torch
import numpy as np
import copy
import json
import time
import os
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
script_dir = os.path.dirname(os.path.abspath(__file__))
print(script_dir)
import sys
sys.path.append(script_dir)
#import simulator
import net
# Centre of the house in sheet coordinates; distances in the sheet builders are
# measured from this point (client instances carry the same two values).
y_center=38.405
x_center=0
# Rotation index that gets the -0.69875 lateral offset for the fixed-y
# "outshot" candidates; the other rotation uses -1.875.
rot_r=0
# Number of lateral grid columns trimmed from each side of the candidate grid:
# `cut` is used by model_base, `cut2` by rule_base.
cut=6
cut2=4
# Run the network on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class client:
    def __init__(self):
        """Load the lookup tables and the network, then connect to the simulator.

        Side effects: reads ``shot_table.npy``, ``outshot_table.npy`` and
        ``basemodel_9.pth`` from the script directory, prompts on stdin for
        the simulator port offset, and opens a TCP connection to port
        ``10000 + offset`` on the local host name.
        """
        self.inshot_table=np.load(script_dir+'/shot_table.npy')
        self.outshot_table=np.load(script_dir+'/outshot_table.npy')
        self.model=net.Net(dim=256,head_num=8).to(device=device)
        # map_location lets a checkpoint that was saved from CUDA tensors be
        # loaded when this process has fallen back to the CPU device.
        checkpoint = torch.load(script_dir+'/basemodel_9.pth',map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        self.src_mask = nn.Transformer.generate_square_subsequent_mask(16)[:16,1:]
        print("input sim port")
        sim_port=int(input())
        self.s=socket.socket(socket.AF_INET,socket.SOCK_STREAM)
        self.s.connect((socket.gethostname(),10000+sim_port))
        self.x_min=22
        self.y_min=41#before41
        self.stone_rad=0.145
        self.y_center=38.405
        self.x_center=0
        wt=[[0.959,0.939,0.919,0.771,0.609,0.340,0.162,0.034,0.015,0.010,0.005],
        [0.989,0.969,0.946,0.794,0.557,0.279,0.122,0.021,0.014,0.009,0.004],
        [1.000,0.999,0.962,0.881,0.677,0.260,0.042,0.011,0.011,0.001,0.000],
        [1.000,1.000,1.000,1.000,1.000,0.220,0.000,0.000,0.000,0.000,0.000]]
        self.wtable_1st=np.array(wt)
        # Second table is the complement of the first with its columns reversed.
        self.wtable_2nd=1-self.wtable_1st[:,::-1]

        # Unpack the 43 packed weights of every (sx, ty, rot) entry into a
        # 5x11 window: rows 0 and 4 hold 5 values in columns 3..7, rows 1..3
        # hold 11 values each; the remaining cells stay zero.
        packed=self.inshot_table[:41,:59,:2]
        self.shot_table=np.zeros((41,59,2,5,11))
        self.shot_table[:,:,:,0,3:8]=packed[...,0:5]
        self.shot_table[:,:,:,1,:]=packed[...,5:16]
        self.shot_table[:,:,:,2,:]=packed[...,16:27]
        self.shot_table[:,:,:,3,:]=packed[...,27:38]
        self.shot_table[:,:,:,4,3:8]=packed[...,38:43]
        


    def calc_score(self,sheet,num_stone):
        score=0
        No1=sheet[0][5]
        for i in range(16):
            if sheet[i][2]>1.829+self.stone_rad or sheet[i][5]==0:
                break
            if No1==sheet[i][5]:
                score+=1
            else:
                break
        score*=No1
        return score
    
    def calc_score2(self,sheet,end,player):
        score=0
        for i in range(15):
            if player==1 and sheet[i][5]==-1:
                sheet[i][5]=1
                break
            if player==-1 and sheet[i][5]==1:
                sheet[i][5]=-1
                break
        No1=sheet[0][5]
        sum=0
        for i in range(15):
            if sheet[i][2]>1.829+self.stone_rad or sheet[i][5]==0:
                break
            if No1==sheet[i][5]:
                score+=1
            else:
                break
        score*=No1
        for i in range(15):
            if sheet[i][2]>1.829+self.stone_rad or sheet[i][5]==0:
                break
            if sheet[i,5]==player:
                sum-=1
            else:
                sum+=1
        #print(score,sum)
        return (score+sum)/2

    def create_sheet_data(self,sheet,shot_result,player,full=0):
        """Build a 15x7 sheet array from a simulator result, sorted by distance.

        Columns written here: 0 = x, 1 = y, 2 = distance from
        (x_center, y_center), 5 = owner.  Rows 0..13 come from
        ``shot_result["stone1"]`` .. ``["stone14"]`` with the owner copied from
        ``sheet``; row 14 comes from ``shot_result["stone16"]`` and is owned by
        ``player``.  Rows whose owner in ``sheet`` is 0 get distance 150 so they
        sort last.

        With ``full==1`` extra bookkeeping is done: rows outside the x/y limits
        or beyond the house past y=38.405 are zeroed, column 3 marks rows that
        were kept, column 4 marks owned rows closer than 1.829, and column 6
        marks every row from the first owner change onwards.
        """
        board=np.zeros((15,7))
        for idx in range(14):
            pos=shot_result["stone"+str(idx+1)]
            board[idx,0]=pos[0]
            board[idx,1]=pos[1]
            dist=np.sqrt((pos[0]-x_center)**2+(pos[1]-y_center)**2)
            board[idx,2]=150 if sheet[idx,5]==0 else dist
            board[idx,5]=sheet[idx,5]
        thrown=shot_result["stone16"]
        board[14,0]=thrown[0]
        board[14,1]=thrown[1]
        board[14,2]=np.sqrt((thrown[0]-x_center)**2+(thrown[1]-y_center)**2)
        board[14,5]=player

        if full!=1:
            return board[np.argsort(board[:,2])]

        for idx in range(15):
            x_pos=board[idx,0]
            y_pos=board[idx,1]
            if x_pos<-2.375 or x_pos>2.375 or y_pos>40.234 or y_pos<32.004:
                board[idx,:]=0
                board[idx,2]=150
            elif board[idx,2]>1.829+0.145 and y_pos>38.405:
                # NOTE(review): this row is zeroed without the 150 sentinel, so
                # it sorts to the front with distance 0 - confirm intended.
                board[idx,:]=0
            else:
                board[idx,3]=1
        board=board[np.argsort(board[:,2])]
        for idx in range(15):
            if board[idx,2]<1.829 and board[idx,5]!=0:
                board[idx,4]=1
            elif board[idx,2]>=100:
                board[idx,2]=0
        first_owner=board[0,5]
        owner_changed=False
        for idx in range(1,15):
            if board[idx,5]==0:
                break
            if board[idx,5]!=first_owner:
                owner_changed=True
            if owner_changed:
                board[idx,6]=1
        return board
    
    def create_sheet_data2(self,sheet,shot_result,player):
        """Build the 15x5 sheet array that is passed to the network evaluator.

        Columns: 0 = x, 1 = y, 2 = distance from (x_center, y_center),
        3 = owner multiplied by ``player`` (so 1 means "same as player"),
        4 = 1 when the distance is below 1.829+0.145.  Rows 0..13 come from
        ``stone1``..``stone14`` of ``shot_result``; row 14 is ``stone16`` and
        belongs to ``player``.  After sorting by distance, rows that are off
        the sheet, carry the 150 "empty" sentinel, or lie outside the house
        past y=38.405 are zeroed; surviving rows get y re-based by -38.405.
        """
        house_edge=1.829+0.145
        board=np.zeros((15,5))
        for idx in range(14):
            pos=shot_result["stone"+str(idx+1)]
            board[idx,0]=pos[0]
            board[idx,1]=pos[1]
            board[idx,2]=np.sqrt((board[idx,0]-x_center)**2+(board[idx,1]-y_center)**2)
            if sheet[idx,5]==0:
                board[idx,2]=150
            board[idx,3]=sheet[idx,5]*player
            if board[idx,2]<house_edge:
                board[idx,4]=1
        thrown=shot_result["stone16"]
        board[14,0]=thrown[0]
        board[14,1]=thrown[1]
        board[14,2]=np.sqrt((board[14,0]-x_center)**2+(board[14,1]-y_center)**2)
        board[14,3]=player
        if board[14,2]<house_edge:
            board[14,4]=1
        board=board[np.argsort(board[:,2])]
        for idx in range(15):
            x_pos=board[idx,0]
            y_pos=board[idx,1]
            dist=board[idx,2]
            off_sheet=x_pos<-2.375 or x_pos>2.375 or y_pos>40.234+0.145 or y_pos<32.004 or dist>100
            if off_sheet or (dist>house_edge and y_pos>38.405):
                board[idx,:]=0
            else:
                board[idx,1]=y_pos-38.405
        return board

    
    def calc_winrate(self,rest_end,score_dist,score,player):
        winrate=0
        score_dist_aftershot=int(np.clip(score_dist+3+score,0,10))
        if rest_end==0 and score_dist_aftershot!=5:
            if score_dist_aftershot>5:
                winrate=1 
            else:
                winrate=0
        else:
            if score>0:
                winrate=self.wtable_1st[3-rest_end,10-score_dist_aftershot]
            elif score==0:
                if player==1:
                    winrate=self.wtable_2nd[3-rest_end,10-score_dist_aftershot]
                else:
                    winrate=self.wtable_1st[3-rest_end,10-score_dist_aftershot]
                
            else:
                winrate=self.wtable_2nd[3-rest_end,10-score_dist_aftershot]
        return winrate
    
    def calc_score_rulebase_new(self,sheet_after_shot,player,end):
        alpha=1.83+self.stone_rad
        beta=1.22
        gamma=1.83+self.stone_rad
        a=0.915
        b=1.83
        w1=1
        w2=1
        if player==-1:
            mu=0.1
        else:
            mu=0.2
        if player==-1:
            e=self.calc_score2(sheet_after_shot.copy(),end,player)
            #print(e)
            if e==-1 and sheet_after_shot[0][5]==1:
                e=0.1
        else:
            e=self.calc_score(sheet_after_shot,15)
            sum=0
            for i in range(15):
                if sheet_after_shot[i][2]>1.829+self.stone_rad or sheet_after_shot[i][5]==0:
                    break
                if sheet_after_shot[i,5]==player:
                    sum+=1
                else:
                    sum-=1
            e=(e+sum)/2
        num_enemy_stone=0
        for i in range(14):
            if sheet_after_shot[i][5]==0:
                break
            ns=0
            for j in range(i):
                if sheet_after_shot[j][5]==-1:
                    ns+=1
                else:
                    break
            ks=1/(1+ns)
            hx=0.0

            if player==-1:
                if abs(sheet_after_shot[i][0]-self.x_center)<alpha:
                    hx=1-abs(sheet_after_shot[i][0]-self.x_center)/alpha

            else:
                if sheet_after_shot[i][2]<alpha:
                    hx=abs(sheet_after_shot[i][0]-self.x_center)/alpha
                

            hy=0.0
            y=-(self.y_center-sheet_after_shot[i][1])+4.88
            if 4.88-beta < y and y<4.88:
                hy=(y-4.88+beta)*(y-4.88+beta)/(beta*beta)
            elif 4.88<y and y<4.88+gamma:
                hy=1-(y-4.88)*(y-4.88)/(gamma*gamma)
            
            js=w1*ks+w2*hx*hy
            #print(hy)
            ds=1
            Tmax=0.01
            if sheet_after_shot[i][5]==-1 and sheet_after_shot[i][2]<1.829+0.145:
                num_enemy_stone+=1
            
            for j in range(15):
                if sheet_after_shot[j][5]==0:
                    break
                
                if i!=j:
                    if sheet_after_shot[i][5]==sheet_after_shot[j][5]:
                        d=np.linalg.norm(sheet_after_shot[i,:2]-sheet_after_shot[j,:2])
                        D2=a+b*abs(sheet_after_shot[i][1]-sheet_after_shot[j][1])/d
                        dist=1
                        if d<D2:
                            dist=0.5+0.5*d*d/(D2*D2)
                        if dist<ds:
                            ds=dist
                        if abs(sheet_after_shot[j][1]-sheet_after_shot[i][1])<0.145 and abs(sheet_after_shot[j][0]-sheet_after_shot[i][0])>0.145*4 and sheet_after_shot[i,2]<1.829 and sheet_after_shot[j,2]<1.829:
                            Tmax=0.1
                        """if sheet_after_shot[j][1]-sheet_after_shot[i][1]>8*self.stone_rad and sheet_after_shot[j][1]-sheet_after_shot[i][1]<1.829:
                            T=0
                            if abs(sheet_after_shot[i][0]-sheet_after_shot[j][0])<1*self.stone_rad:
                                T=1
                            elif abs(sheet_after_shot[i][0]-sheet_after_shot[j][0])<alpha:
                                T=1/(abs(sheet_after_shot[i][0]-sheet_after_shot[j][0])+1)
                            if Tmax<T:
                                Tmax=T"""
            N=js/ds
            e+=mu*N*sheet_after_shot[i][5]*Tmax
        #print(e)
        #print(player,num_enemy_stone)
        if player==1:
            e-=num_enemy_stone*0.5
        return e
    
    def calc_score_rulebase(self,sheet_after_shot,player,end):
        alpha=1.83+self.stone_rad
        beta=1.22
        gamma=1.83+self.stone_rad
        a=0.915
        b=1.83
        w1=1
        w2=1
        if player==-1:
            mu=0.1
        else:
            mu=0.2
        if player==-1:
            e=self.calc_score2(sheet_after_shot.copy(),end,player)
            #print(e)
            if e==-1 and sheet_after_shot[0][5]==1:
                e=0.1
        else:
            e=self.calc_score(sheet_after_shot,15)
        e*=1
        num_enemy_stone=0
        for i in range(14):
            if sheet_after_shot[i][5]==0:
                break
            ns=0
            for j in range(i):
                if sheet_after_shot[j][5]==-1:
                    ns+=1
                else:
                    break
            ks=1/(1+ns)
            hx=0.0
            if player==-1:
                if abs(sheet_after_shot[i][0]-self.x_center)<alpha:
                    hx=1-abs(sheet_after_shot[i][0]-self.x_center)/alpha

            else:
                if sheet_after_shot[i][2]<alpha:
                    hx=abs(sheet_after_shot[i][0]-self.x_center)/alpha
                

            hy=0.0
            y=-(self.y_center-sheet_after_shot[i][1])+4.88
            if 4.88-beta < y and y<4.88:
                hy=(y-4.88+beta)*(y-4.88+beta)/(beta*beta)
            elif 4.88<y and y<4.88+gamma:
                hy=1-(y-4.88)*(y-4.88)/(gamma*gamma)
            
            js=w1*ks+w2*hx*hy
            #print(hy)
            ds=1
            Tmax=0.01
            if sheet_after_shot[i][5]==-1 and sheet_after_shot[i][2]<1.829:
                num_enemy_stone+=1
            
            for j in range(15):
                if sheet_after_shot[j][5]==0:
                    break
                
                if i!=j:
                    if sheet_after_shot[i][5]==sheet_after_shot[j][5]:
                        d=np.linalg.norm(sheet_after_shot[i,:2]-sheet_after_shot[j,:2])
                        D2=a+b*abs(sheet_after_shot[i][1]-sheet_after_shot[j][1])/d
                        dist=1
                        if d<D2:
                            dist=0.5+0.5*d*d/(D2*D2)
                        if dist<ds:
                            ds=dist
                        if sheet_after_shot[j][1]-sheet_after_shot[i][1]>8*self.stone_rad and sheet_after_shot[j][1]-sheet_after_shot[i][1]<1.829:
                            T=0
                            if abs(sheet_after_shot[i][0]-sheet_after_shot[j][0])<1*self.stone_rad:
                                T=1
                            elif abs(sheet_after_shot[i][0]-sheet_after_shot[j][0])<alpha:
                                T=1/(abs(sheet_after_shot[i][0]-sheet_after_shot[j][0])+1)
                            if Tmax<T:
                                Tmax=T
                
            N=js/ds
            e+=mu*N*sheet_after_shot[i][5]*Tmax
        #print(e)
        #print(player,num_enemy_stone)
        if player==1:
            e-=num_enemy_stone*0.5
        return e
    
    def calc_score_modelbase(self,sheet_after_shot,player,remain_end,score_dist,shot_num):
        """Evaluate a 15x5 sheet with the network and fold it into a win-rate value.

        The network receives the sheet, a 24-wide vector (one-hot
        ``score_dist`` over 5 slots, one-hot ``remain_end`` over 4 slots and a
        15-slot shot one-hot whose last slot is always set) and a 16-wide
        boolean mask derived from ``self.src_mask`` at the row given by the
        number of leading rows with a non-zero column 3.  Its 11 outputs are
        treated as weights for end scores -5..5.

        ``player`` and ``shot_num`` are accepted but not used here.

        Returns:
            Minus the output-weighted sum of calc_winrate over those scores.
        """
        d1=np.zeros((1,15,5))
        d2=np.zeros((1,24))
        d3=np.zeros((1,16))
        score_dist_one_hot=np.zeros(5)
        score_dist_one_hot[score_dist]=1
        remain_end_one_hot=np.zeros(4)
        remain_end_one_hot[remain_end]=1
        shot_num_onehot=np.zeros(15)
        shot_num_onehot[14]=1
        d2[0]=np.concatenate([score_dist_one_hot,remain_end_one_hot,shot_num_onehot])
        stone_num=0
        for i in range(15):
            if sheet_after_shot[i,3]!=0:
                stone_num+=1
            else:
                break
        d1[0]=sheet_after_shot
        d3[0]=np.append(self.src_mask[stone_num]!=0,False)
        d1=torch.from_numpy(d1).float().to(device)
        d2=torch.from_numpy(d2).float().to(device)
        d3=torch.from_numpy(d3).bool().to(device)
        # Pure inference: skip autograd bookkeeping, since this runs once per
        # candidate shot and no gradients are ever requested.
        with torch.no_grad():
            predict=self.model(d1,d2,d3).cpu().detach().numpy().copy()[0]
        e=0
        for score in range(11):
            e-=self.calc_winrate(remain_end,score_dist,score-5,1)*predict[score]
        return e
    
    def rule_base(self,player,shot_num,stones,field):
        """Pick a shot by simulating a grid of targets and scoring them heuristically.

        Every grid target (both rotations) and every fixed-y "outshot" target
        is sent to the simulator, turned into a sheet with create_sheet_data
        and scored with calc_score_rulebase_new (grid) or calc_score_rulebase
        (outshots).  While ``shot_num < 6`` a candidate is scored -1 if it
        reduces the number of team -1 stones that sit short of y_center and
        outside the house - presumably a guard-protection rule; confirm.  The
        scores are then weighted with ``self.shot_table`` /
        ``self.outshot_table`` and the best entry wins.

        Returns:
            ``opt`` = [sx, ty, rot]; ``ty == -1`` marks an outshot entry.
        """
        def count_guards(board):
            # Team -1 stones short of y_center and outside the house radius.
            total=0
            for row in range(15):
                if board[row][5]==0:
                    break
                if board[row][5]==-1 and board[row][1]<y_center and board[row][2]>1.829+self.stone_rad:
                    total+=1
            return total

        def evaluate(board,scorer):
            if shot_num<6 and count_guards(field)-count_guards(board)>0:
                return -1
            return scorer(board,player,shot_num)

        opt=np.zeros(3)
        grid_value=np.full((self.x_min*2+1,69,2),-10.0)
        out_value=np.full((79,2),-10.0)
        for rot in range(2):
            for dx in range(-self.x_min+cut2,self.x_min+1-cut2):
                gx=dx+self.x_min
                for dy in range(-30,20,2):
                    gy=dy+self.y_min
                    target_x=dx*self.stone_rad
                    target_y=dy*self.stone_rad+self.y_center
                    shot_result=self.shot(stones,target_x,target_y,rot)
                    board=self.create_sheet_data(field,shot_result,1)
                    value=evaluate(board,self.calc_score_rulebase_new)
                    # dy advances in steps of two, so the skipped cell takes
                    # the same value as its evaluated neighbour.
                    grid_value[gx,gy,rot]=value
                    grid_value[gx,gy+1,rot]=value
            for dx in range(79):
                if rot==rot_r:
                    target_x=x_center+self.stone_rad*dx*0.25-0.69875-0.25*self.stone_rad*4
                else:
                    target_x=x_center+self.stone_rad*dx*0.25-1.875-0.25*self.stone_rad*4
                target_y=-(38.405-6.039+4.88)
                shot_result=self.shot(stones,target_x,target_y,rot)
                board=self.create_sheet_data(field,shot_result,1)
                out_value[dx,rot]=evaluate(board,self.calc_score_rulebase)

        best_win_rate=-10
        for rot in range(2):
            for sx in range((self.x_min-2)*2+1):
                for ty in range(0,59):
                    win_rate=np.vdot(grid_value[sx:sx+5,ty:ty+11,rot],self.shot_table[sx,ty,rot])
                    if best_win_rate<win_rate:
                        best_win_rate=win_rate
                        opt[0]=sx
                        opt[1]=ty
                        opt[2]=rot
            for sx in range(71):
                win_rate=np.inner(out_value[sx:sx+9,rot],self.outshot_table[sx,rot])
                if best_win_rate<win_rate:
                    best_win_rate=win_rate
                    opt[0]=sx
                    opt[1]=-1
                    opt[2]=rot
        print(best_win_rate,opt)
        return opt
    def model_base(self,player,shot_num,stones,field,score_dist,end):
        """Pick a shot by simulating a grid of targets and scoring them with the network.

        Every grid target (both rotations) and every fixed-y "outshot" target
        is sent to the simulator, converted with create_sheet_data2 and scored
        with calc_score_modelbase.  Cells that are never evaluated keep the
        value -1.  The scores are then weighted with ``self.shot_table`` /
        ``self.outshot_table`` and the best entry wins.

        Returns:
            ``opt`` = [sx, ty, rot]; ``ty == -1`` marks an outshot entry.
        """
        opt=np.zeros(3)
        grid_value=np.zeros((self.x_min*2+1,69,2))-1
        out_value=np.zeros((79,2))-1
        for rot in range(2):
            for dx in range(-self.x_min+cut,self.x_min+1-cut):
                gx=dx+self.x_min
                for dy in range(-10,20,2):
                    gy=dy+self.y_min
                    target_x=dx*self.stone_rad
                    target_y=dy*self.stone_rad+self.y_center
                    shot_result=self.shot(stones,target_x,target_y,rot)
                    board=self.create_sheet_data2(field,shot_result,player)
                    value=self.calc_score_modelbase(board,player,end,score_dist,shot_num)
                    # dy advances in steps of two, so the skipped cell takes
                    # the same value as its evaluated neighbour.
                    grid_value[gx,gy,rot]=value
                    grid_value[gx,gy+1,rot]=value
            for dx in range(79):
                if rot==rot_r:
                    target_x=x_center+self.stone_rad*dx*0.25-0.69875-0.25*self.stone_rad*4
                else:
                    target_x=x_center+self.stone_rad*dx*0.25-1.875-0.25*self.stone_rad*4
                target_y=-(38.405-6.039+4.88)
                shot_result=self.shot(stones,target_x,target_y,rot)
                board=self.create_sheet_data2(field,shot_result,player)
                out_value[dx,rot]=self.calc_score_modelbase(board,player,end,score_dist,shot_num)

        best_win_rate=-10
        for rot in range(2):
            for sx in range((self.x_min-2)*2+1):
                for ty in range(0,59):
                    win_rate=np.vdot(grid_value[sx:sx+5,ty:ty+11,rot],self.shot_table[sx,ty,rot])
                    if best_win_rate<win_rate:
                        best_win_rate=win_rate
                        opt[0]=sx
                        opt[1]=ty
                        opt[2]=rot
            for sx in range(71):
                win_rate=np.inner(out_value[sx:sx+9,rot],self.outshot_table[sx,rot])
                if best_win_rate<win_rate:
                    best_win_rate=win_rate
                    opt[0]=sx
                    opt[1]=-1
                    opt[2]=rot
        print(best_win_rate,opt)
        return opt

    

    def shot(self,stones,target_x,target_y,rot):
        if rot==0:
            stones["shot"]=[target_x,target_y,1]
        else:
            stones["shot"]=[target_x,target_y,-1]
        json_shot=(json.dumps(stones)+'\n').encode('utf-8')
        self.s.send(bytes(json_shot))
        response = self.s.recv(1024)
        shot_result=json.loads(response.decode('utf-8'))
        return shot_result


    def Expectimax(self,field,shot_num,end,score_dist,remain_time):
        """Choose the next shot, dispatching on shot number and remaining time.

        Order of decisions:
          1. ``shot_num == 1``: return a fixed opening entry.
          2. Build the ``stones`` dict for the simulator from ``field``.
          3. ``remain_time < 10``: return a fixed entry without simulating.
          4. ``remain_time`` below ``25*end + 2.5*(15-shot_num)``: rule_base.
          5. ``shot_num < 13`` or ``== 14``: rule_base; ``13`` or ``15``:
             model_base.
          6. Otherwise: simulate a full grid, score each outcome with
             calc_score / calc_winrate and pick the best table-weighted entry,
             breaking win-rate ties by the raw score at the aimed cell.

        Returns:
            ``(win_rate, opt)`` where ``opt`` = [sx, ty, rot] and ``ty == -1``
            marks an outshot entry.  Every path except the last reports a
            win rate of 0.
        """
        if shot_num==1:
            opt=np.zeros(3)
            opt[0]=20
            opt[1]=11
            opt[2]=0
            return 0,opt
        his=np.zeros(11)
        #print(end)
        # Odd shot numbers are thrown as team -1, even ones as team 1.
        if shot_num%2==1:
            shot_player=-1
        else:
            shot_player=1
        flag=0
        enemy_stone_num=0
        enemy_stone_num2=0
        my_stone_num=0
        score=0
        opt=np.zeros(3)
        stones={}
        stones["result"]=[0,0]
        
        # Copy stones into the simulator request until the first empty row
        # (owner 0); the per-team counters are tallied but not read below.
        for i in range(15):
            stones["stone"+str(i+1)]=[field[i][0],field[i][1],field[i][5]]
            if field[i][5]==-1:
                enemy_stone_num+=1
                if field[i][2]<1.829+self.stone_rad:
                    enemy_stone_num2+=1
                flag=1
            elif field[i][5]==1:
                my_stone_num+=1
            else:# field[i][5]==0:
                break
        # Index of the first empty row, or 14 when the loop ran to the end.
        num_stone=i
        if remain_time<10:
            print("notime")
            opt[0]=20
            opt[1]=30
            opt[2]=0
            return 0,opt
        # Fall back to the rule-based search when the remaining time is below
        # this per-end / per-shot budget.
        if remain_time<25*(end)+2.5*(15-shot_num):
            print("notime mode")
            opt=self.rule_base(shot_player,shot_num,stones,field)
            return 0,opt
        if shot_num<13 or shot_num==14:
            opt=self.rule_base(shot_player,shot_num,stones,field)
            return 0,opt
        elif shot_num==15 or shot_num==13:
            opt=self.model_base(shot_player,shot_num,stones,field,score_dist,end)
            return 0,opt
        else:
            # Exhaustive search: score the sheet after each candidate and
            # convert the score to a win rate.
            field_win_rate=np.zeros((self.x_min*2+1,69,2))
            outshot_winrate=np.zeros((79,2))
            sheet_score_out=np.zeros((79,2))
            sheet_score=np.zeros((self.x_min*2+1,69,2))
            num_stone=i
            for rot in range(2):
                for dx in range(-self.x_min,self.x_min+1):
                    for dy in range(-16,25,2):
                        target_x=dx*self.stone_rad
                        target_y=dy*self.stone_rad+self.y_center
                        shot_result=self.shot(stones,target_x,target_y,rot)
                        sheet_after_shot=self.create_sheet_data(field,shot_result,1)
                        score=self.calc_score(sheet_after_shot,num_stone)
                        sheet_score[dx+self.x_min,dy+self.y_min,rot]=score
                        field_win_rate[dx+self.x_min,dy+self.y_min,rot]=self.calc_winrate(end,score_dist,score,shot_player)
                        # dy advances in steps of two: the first row is also
                        # copied two cells back, every other row fills the
                        # skipped cell after it.
                        if dy==-16:
                            field_win_rate[dx+self.x_min,dy+self.y_min-2:dy+self.y_min+2,rot]=field_win_rate[dx+self.x_min,dy+self.y_min,rot]
                        else:
                            field_win_rate[dx+self.x_min,dy+self.y_min+1,rot]=field_win_rate[dx+self.x_min,dy+self.y_min,rot]
                            sheet_score[dx+self.x_min,dy+self.y_min+1,rot]=sheet_score[dx+self.x_min,dy+self.y_min,rot]
                
                # Fixed-y "outshot" candidates; the lateral offset depends on
                # whether this rotation equals rot_r.
                for dx in range(79):
                    if rot==rot_r:
                        target_x=x_center+self.stone_rad*dx*0.25-0.69875-self.stone_rad
                    else:
                        target_x=x_center+self.stone_rad*dx*0.25-1.875-self.stone_rad
                    target_y=-37.246
                    shot_result=self.shot(stones,target_x,target_y,rot)
                    sheet_after_shot=self.create_sheet_data(field,shot_result,1)
                    score=self.calc_score(sheet_after_shot,num_stone)
                    sheet_score_out[dx,rot]=score
                    outshot_winrate[dx,rot]=self.calc_winrate(end,score_dist,score,shot_player)
            best_win_rate=0
            best_score=-10

            for rot in range(2):
                for sx in range((self.x_min-2)*2+1):
                    for ty in range(18,59):
                        win_rate=0
                        # Raw score at the centre of the 5x11 window.
                        score=sheet_score[sx+2,ty+5,rot]
                        """j=-2
                        k=-2
                        for i in range(43):
                            win_rate+=field_win_rate[sx+j+2,ty+k+5,rot]*self.inshot_table[sx,ty,rot,i]
                            if i==4 or i==15 or i==26:
                                j+=1
                                k=-5
                            elif i==37:
                                j+=1
                                k=-2
                            else:
                                k+=1"""
                        # Table-weighted win rate over the window.
                        win_rate=np.vdot(field_win_rate[sx:sx+5,ty:ty+11,rot],self.shot_table[sx,ty,rot])
                        if best_win_rate<win_rate or (best_win_rate==win_rate and score>best_score):
                            best_win_rate=win_rate
                            best_score=score
                            opt[0]=sx
                            opt[1]=ty
                            opt[2]=rot

                for sx in range(71):
                    win_rate=0
                    # Raw score at the centre of the 9-wide outshot window.
                    score=sheet_score_out[sx+4,rot]
                    win_rate=np.inner(outshot_winrate[sx:sx+9,rot],self.outshot_table[sx,rot])
                    if best_win_rate<win_rate or (best_win_rate==win_rate and score>best_score):
                        best_win_rate=win_rate
                        best_score=score
                        opt[0]=sx
                        opt[1]=-1
                        opt[2]=rot
            return best_win_rate,opt


# Tables loaded from the script directory; not read in this part of the
# script - presumably used further down to convert a chosen entry into a
# throw (TODO confirm).
shottable=np.load(script_dir+'/draw_shot_velo.npy')
outshottable=np.load(script_dir+'/take_shot_velo.npy')
# Constructing the client prompts for the simulator port and connects to it.
client1=client()
print("INPUT SERVER port")
server_port=int(input())
s=socket.socket(socket.AF_INET,socket.SOCK_STREAM)
# Socket operations on the match-server connection time out after 600 seconds.
s.settimeout(600)
#host_name=socket.gethostbyname("digitalcurling.ddns.net")
#print(host_name)
s.connect(("180.35.242.250",server_port))
#s.connect((socket.gethostname(),15000+server_port))

# First message from the server is only echoed.
response = s.recv(1024)
print(response)

# Reply with the "dc_ok" command and this client's display name.
dc={
        "cmd": "dc_ok",
        "name": "Jiritsukun-Jr_ver3"
    }
json_dc=(json.dumps(dc)+'\n').encode('utf-8')
#print(json_dc)
time.sleep(0.1)
s.send(json_dc)

# The next message carries the team assigned to this client.
response = s.recv(4096)
print(response)
response=json.loads(response.decode('utf-8'))
team=response["team"]
print(team)
# Any team other than 'team1' is treated as 'team0', with 'team1' as opponent.
if team=='team1':
    enemy_team='team0'
else:
    enemy_team='team1'
# Send the "ready_ok" command together with the chosen player order.
ready={
    "cmd": "ready_ok",
    "player_order": [
        3,  
        1,  
        2, 
        0 
    ]
}
json_ready=(json.dumps(ready)+'\n').encode('utf-8')
time.sleep(0.5)
s.send(json_ready)
response = s.recv(4096)
print('ready')
print(response)
# Number of per-end score entries summed by the game loop, which also compares
# the current end against num_end+1.
num_end=9
#response = s.recv(4096)
#print(response)
# Starting value of the score-difference index; reset to 2 at the start of
# each end by the game loop.
score_dist=2
# Counter of messages processed by the game loop.
num=0
def _recv_line(sock, pending):
    """Read one newline-terminated message from ``sock``.

    Args:
        sock: Connected socket object providing ``recv``.
        pending: Bytes received earlier that have not been consumed yet.

    Returns:
        ``(line, rest)`` where ``line`` is the decoded text before the first
        newline and ``rest`` holds the bytes after it (to be passed back in on
        the next call).  ``line`` is ``None`` when the peer closed the
        connection before a complete line arrived.
    """
    while b"\n" not in pending:
        chunk = sock.recv(4096 * 2)
        if not chunk:
            # recv() returning no data means the connection was closed;
            # retrying would spin forever.
            return None, pending
        pending += chunk
    line, _, rest = pending.partition(b"\n")
    # Decode only complete lines so a multi-byte character split across two
    # recv() calls cannot break the decoding.
    return line.decode("utf-8"), rest


def _clamp_score_dist(value):
    """Clamp the shifted score difference to the range 0..4."""
    return min(max(value, 0), 4)


# Bytes already received from the server but not yet parsed.  This must live
# outside the loop: several messages can arrive in a single recv(), and the
# ones after the first newline have to be kept for the following iterations.
_pending = b""
while True:
    time.sleep(0.1)
    message, _pending = _recv_line(s, _pending)
    if message is None:
        print("connection closed by server")
        break
    if not message.strip():
        # Ignore blank lines between messages.
        continue

    response = json.loads(message)
    num += 1

    if response.get("cmd") == "game_over":
        print("gameover")
        break
    # Only act on messages announcing that it is our turn; messages without a
    # "next_team" entry are ignored as well.
    if response.get("next_team") != team:
        continue

    state = response["state"]
    end = int(state["end"])
    shotnum = int(state["shot"])
    print('shot_number', shotnum)
    stones = state["stones"][team]
    stones2 = state["stones"][enemy_team]
    my_score = state["scores"][team]
    enemy_score = state["scores"][enemy_team]
    print(state["thinking_time_remaining"])
    remain_time = state["thinking_time_remaining"][team]

    # Recompute the shifted score difference on our first shot of an end
    # (shot 0 or 1, depending on who throws first).
    if shotnum == 0 or shotnum == 1:
        score_dist = 2
        if end > num_end + 1:
            # NOTE(review): this branch sums only the first `num_end` regular
            # ends and clamps before adding the extra-end scores, while the
            # other branch sums `end` ends.  Kept exactly as before; confirm
            # against the server's scoring layout.
            my_extrascore = state["extra_end_scores"][team]
            enemy_extrascore = state["extra_end_scores"][enemy_team]
            for i in range(num_end):
                score_dist += my_score[i] - enemy_score[i]
            score_dist = _clamp_score_dist(score_dist)
            for i in range(end - num_end):
                score_dist += my_extrascore[i] - enemy_extrascore[i]
            print('score', score_dist - 2)
            score_dist = _clamp_score_dist(score_dist)
        else:
            for i in range(end):
                score_dist += my_score[i] - enemy_score[i]
            print('score', score_dist - 2)
            score_dist = _clamp_score_dist(score_dist)
        print('end:', end)

    # One row per stone slot (ours first, then the opponent's).  Columns:
    #   0: x, 1: y, 2: distance from (0, y_center), 3: stone present,
    #   4: distance below 1.829, 5: owner (+1 ours, -1 opponent, 0 empty),
    #   6: set for every stone from the first owner change onwards when the
    #      stones are ordered by distance.
    field = np.zeros((16, 7))
    i = 0
    for owner, team_stones in ((1, stones), (-1, stones2)):
        for stone in team_stones:
            if stone is not None:
                field[i, 0] = stone["position"]["x"]
                field[i, 1] = stone["position"]["y"]
                field[i, 2] = np.sqrt((field[i, 0])**2 + (field[i, 1] - y_center)**2)
                field[i, 3] = 1
                field[i, 5] = owner
            else:
                # Placeholder distance so empty slots sort behind real stones.
                field[i, 2] = 150
            i += 1

    # Order rows by distance, closest first.
    field = field[np.argsort(field[:, 2])]
    for i in range(16):
        if field[i, 2] < 1.829 and field[i, 5] != 0:
            field[i, 4] = 1
        elif field[i, 2] == 150:
            # Clear the placeholder now that sorting is done.
            field[i, 2] = 0

    flag = False
    player = field[0, 5]
    for i in range(1, 16):
        if field[i, 5] == 0:
            break
        if player != field[i, 5]:
            flag = True
        if flag:
            field[i, 6] = 1

    t1 = time.time()
    # Convert the end index into `num_end - end`, limited to 0..3.
    end = min(max(num_end - end, 0), 3)
    print()
    value, opt = client1.Expectimax(field=field[0:15], shot_num=shotnum+1, end=end, score_dist=score_dist, remain_time=remain_time)
    print(time.time() - t1)

    # opt = (x index, y index, rotation index); a negative y index selects the
    # `outshottable` velocity table instead of `shottable`.
    x = int(opt[0])
    y = int(opt[1])
    rot = int(opt[2])
    print(x, y, value)
    if y < 0:
        velocity_x = outshottable[x, rot, 0]
        velocity_y = outshottable[x, rot, 1]
    else:
        velocity_x = shottable[x, y, rot, 0]
        velocity_y = shottable[x, y, rot, 1]

    move = {
        "cmd": "move",
        "move": {
            "type": "shot",
            # float() keeps json.dumps independent of the table's numpy dtype.
            "velocity": {"x": float(velocity_x), "y": float(velocity_y)},
            "rotation": "ccw" if rot == 0 else "cw",
        },
    }
    json_move = (json.dumps(move) + '\n').encode('utf-8')
    time.sleep(0.1)
    # sendall() writes the whole message; send() may stop after a partial write.
    s.sendall(json_move)


