import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import os
import time
import math

# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen when several libraries bundle their own OpenMP — confirm it is
# still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

class Enc(torch.nn.Module):
    """Four stacked Transformer encoder layers operating on batch-first input.

    Args:
        dim: Feature width (``d_model``) of each token.
        head_num: Number of attention heads in every layer.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, dim, head_num, dropout=0.1):
        super().__init__()
        layer_template = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=head_num,
            dropout=dropout,
            batch_first=True,
        )
        # The template is stored on the module as well, so its parameters are
        # registered under ``TEL`` alongside those of the stack.
        self.TEL = layer_template
        self.transformerEnc = nn.TransformerEncoder(layer_template, num_layers=4)

    def forward(self, x, mask):
        """Encode ``x``; ``mask`` is forwarded as ``src_key_padding_mask``."""
        return self.transformerEnc(x, src_key_padding_mask=mask)
    
class Dec(torch.nn.Module):
    """Four stacked Transformer *encoder* layers (despite the class name).

    Args:
        dim: Feature width (``d_model``) of each token.
        head_num: Number of attention heads in every layer.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, dim, head_num, dropout=0.1):
        super().__init__()
        layer_template = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=head_num,
            dropout=dropout,
            batch_first=True,
        )
        # The template is stored on the module as well, so its parameters are
        # registered under ``TEL`` alongside those of the stack.
        self.TEL = layer_template
        self.transformerEnc = nn.TransformerEncoder(layer_template, num_layers=4)

    def forward(self, x, mask):
        """Run the stack on ``x``; ``mask`` is passed as ``src_key_padding_mask``."""
        return self.transformerEnc(x, src_key_padding_mask=mask)

class Dec2(torch.nn.Module):
    """Four stacked Transformer decoder layers operating on batch-first input.

    Args:
        dim: Feature width (``d_model``) of each token.
        head_num: Number of attention heads in every layer.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, dim, head_num, dropout=0.1):
        super().__init__()
        layer_template = nn.TransformerDecoderLayer(
            d_model=dim,
            nhead=head_num,
            dropout=dropout,
            batch_first=True,
        )
        # The template is stored on the module as well, so its parameters are
        # registered under ``TEL`` alongside those of the stack.
        self.TEL = layer_template
        self.transformerEnc = nn.TransformerDecoder(layer_template, num_layers=4)

    def forward(self, x, y, mask):
        """Decode target ``x`` against memory ``y``.

        ``mask`` is forwarded as ``tgt_key_padding_mask``.
        """
        return self.transformerEnc(tgt=x, memory=y, tgt_key_padding_mask=mask)
    

class Net_transformer(torch.nn.Module):
    """Classifier: embed tokens, append a game token, run ``Dec``, then an MLP.

    Each input token (5 features) is projected to ``dim``; the 9 game
    features are projected to ``output_dim`` (256) and appended as one extra
    token. The sequence goes through ``Dec``, is flattened per sample and
    mapped by ``fc2`` to 11 softmax probabilities.

    Constraints that follow from the layer sizes:
        * ``dim`` must equal ``output_dim`` (256), otherwise the game token
          cannot be concatenated with the embedded tokens.
        * The sequence including the game token must have 16 entries,
          because ``fc2`` expects ``16 * output_dim`` input features.

    Args:
        dim: Token width used by ``vec`` and by the ``Dec`` stack.
        head_num: Number of attention heads.
        dropout: Dropout probability for the transformer and MLP layers.
    """

    def __init__(self, dim, head_num, dropout=0.1):
        super().__init__()
        self.output_dim = 256
        self.dim = dim
        self.vec = nn.Sequential(
            nn.Linear(5, dim),
        )
        self.sm = nn.Softmax(dim=1)
        # NOTE: ``TEL``, ``transformerEnc`` and ``fc1`` are not used by
        # forward(). They are left in place so the module's registered
        # parameters (and therefore its state_dict keys) stay unchanged.
        self.TEL = nn.TransformerEncoderLayer(dim, head_num, dropout=dropout, batch_first=True)
        self.transformerEnc = nn.TransformerEncoder(self.TEL, 4)
        self.fc1 = nn.Sequential(
            nn.Linear(dim, self.output_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.fc_game = nn.Sequential(
            nn.Linear(9, self.output_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(self.output_dim * 16, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 11),
        )
        self.dec = Dec(dim, head_num, dropout)

    def forward(self, input, game_data, mask2):
        """Return class probabilities for a batch.

        Args:
            input: Token features whose last dimension is 5; expected to be
                ``(batch, 15, 5)`` so that the sequence has 16 entries once
                the game token is appended.
            game_data: Game features whose last dimension is 9
                (assumed ``(batch, 9)`` — confirm against the caller).
            mask2: Padding mask handed to ``Dec`` for the full sequence,
                game token included.

        Returns:
            Tensor of shape ``(batch, 11)`` with softmax applied along dim 1.

        Raises:
            RuntimeError: If the flattened per-sample width does not match
                what ``fc2`` expects (wrong number of tokens).
        """
        tokens = self.vec(input)
        game_token = self.fc_game(game_data).view(-1, 1, self.output_dim)
        tokens = torch.cat([tokens, game_token], dim=1)
        tokens = self.dec(tokens, mask=mask2)

        # Flatten per sample. The previous ``view(-1, 16 * output_dim)`` let
        # the batch dimension absorb a wrong sequence length, silently mixing
        # tokens of different samples into one row.
        flat = tokens.flatten(start_dim=1)
        expected_width = self.fc2[0].in_features
        if flat.size(1) != expected_width:
            raise RuntimeError(
                f"expected {expected_width} features per sample after flattening, "
                f"got {flat.size(1)} (sequence of {tokens.size(1)} tokens)"
            )
        return self.sm(self.fc2(flat))
    
class Net(torch.nn.Module):
    """Classifier: project tokens, append a game token, run ``Dec``, then an MLP.

    Each input token (5 features) and the 24 game features are linearly
    projected to ``output_dim`` (256). The game projection is appended as
    one extra token, the sequence goes through ``Dec``, is flattened per
    sample and mapped by ``fc2`` to 11 softmax probabilities.

    Constraints that follow from the layer sizes:
        * ``dim`` must equal ``output_dim`` (256), because ``Dec`` is built
          with ``dim`` but receives tokens of width ``output_dim``.
        * The sequence including the game token must have 16 entries,
          because ``fc2`` expects ``16 * output_dim`` input features.

    Args:
        dim: Token width the ``Dec`` stack is built for.
        head_num: Number of attention heads.
        dropout: Dropout probability for the transformer and MLP layers.
    """

    def __init__(self, dim, head_num, dropout=0.1):
        super().__init__()
        self.output_dim = 256
        self.dim = dim
        self.sm = nn.Softmax(dim=1)
        self.fc1 = nn.Sequential(
            nn.Linear(5, self.output_dim),
        )
        self.fc_game = nn.Sequential(
            nn.Linear(24, self.output_dim),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(self.output_dim * 16, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 11),
        )
        self.dec = Dec(dim, head_num, dropout)

    def forward(self, input, game_data, mask2):
        """Return class probabilities for a batch.

        Args:
            input: Token features whose last dimension is 5; expected to be
                ``(batch, 15, 5)`` so that the sequence has 16 entries once
                the game token is appended.
            game_data: Game features whose last dimension is 24
                (assumed ``(batch, 24)`` — confirm against the caller).
            mask2: Padding mask handed to ``Dec`` for the full sequence,
                game token included.

        Returns:
            Tensor of shape ``(batch, 11)`` with softmax applied along dim 1.

        Raises:
            RuntimeError: If the flattened per-sample width does not match
                what ``fc2`` expects (wrong number of tokens).
        """
        tokens = self.fc1(input)
        game_token = self.fc_game(game_data).view(-1, 1, self.output_dim)
        tokens = torch.cat([tokens, game_token], dim=1)
        tokens = self.dec(tokens, mask=mask2)

        # Flatten per sample. The previous ``view(-1, 16 * output_dim)`` let
        # the batch dimension absorb a wrong sequence length, silently mixing
        # tokens of different samples into one row.
        flat = tokens.flatten(start_dim=1)
        expected_width = self.fc2[0].in_features
        if flat.size(1) != expected_width:
            raise RuntimeError(
                f"expected {expected_width} features per sample after flattening, "
                f"got {flat.size(1)} (sequence of {tokens.size(1)} tokens)"
            )
        return self.sm(self.fc2(flat))