import numpy as np
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import itertools
import pandas as pd
import re
import time
import os


## setting the global variables

LOWER_BOUND = 0.0      # lower bound applied to every constraint weight during L-BFGS-B re-weighting
UPPER_BOUND = None     # upper bound applied to every constraint weight (None = unbounded)
MAXIterations = 100    # maximum number of EM iterations
CON_EPSILON = 1e-3     # stop once the change in incomplete-data log-likelihood is below this
                       # NOTE(review): originally spelled 10e-4 (== 1e-3); confirm 1e-4 was not intended
L2_PRIOR = True        # add the L2 (Gaussian prior) penalty to the weight objective
MU = 0                 # mean of the Gaussian prior on each weight
SIGMASQUARE = 1e5      # variance of the Gaussian prior
LAM = 2                # multiplier applied to the L2 penalty term
ECM = True             # if True, re-run the E-step with the new weights before re-estimating UR probabilities


class EM():
    def __init__(self):
        '''Create an EM learner.

        No instance state is set up here; the methods receive the data,
        weights and UR probabilities they need as arguments.
        '''
    
    def harmony(self, vlist, weights):
        return np.dot(vlist, weights)
    
    def sr_given_ur(self, weights, data):
        ''' This function calculates the predicted probability of a sr, given a ur, under current constraint weights. 
            Which corresponds to Pr(SR|UR;W)

            Parameters
            ----------
                weights: A list of constraing weights
                data: a dictionary which stores the ur, sr, and the violation for each (ur, sr) pair 

            Returns
            -------
                opt: a list of tuples [((UR, SR, MORPH), predicted probability)]
        '''
        opt = [] 
        for form, urs in data.items():
            for curr_ur, srs_viols in urs.items():
                harmonies = np.array([-self.harmony(vlist, weights) for _, vlist in srs_viols.items()])
                curr_form = [((curr_ur, sr, form)) for sr, _ in srs_viols.items()] 
                #using the exp-normalize tric from here: https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/
                #borrowed from Tim Hunter's log-lin class: https://github.com/timhunter/loglin/blob/master/opt.py
                b = harmonies.max()
                shifted_exp_scores = np.exp(harmonies - b)
                normalized = shifted_exp_scores/shifted_exp_scores.sum()
                curr_opt = list(zip(curr_form, normalized))
                opt+=curr_opt
        return opt
    

    def hidden_and_observation(self, ur_sr_prob_pair, thetas):
        ''' Compute the joint probability P(UR, SR|Morph; Weights, UR probabilities).

            The joint is P(SR|UR, MORPH; Weights) multiplied by the product of
            P(UR|Morph) over the morphemes of the (possibly polymorphemic) form.

            Parameters
            ----------
                ur_sr_prob_pair: a list of tuples [((UR, SR, MORPH), predicted probability)],
                                 as returned by sr_given_ur
                thetas: a dictionary which stores the UR probability {morph:{ur:P(ur|morph)}}

            Returns
            -------
                conjoined: a dictionary {form: {sr: {ur: joint probability}}}
        '''
        conjoined = {}
        for (ur, sr, form), prob in ur_sr_prob_pair:
            # A polymorphemic form such as 'root_suffix' is paired positionally with
            # the segments of a UR such as '/abc-de/'.
            morphs = form.split('_')
            segments = ['/' + seg + '/' for seg in re.split('/|-', ur) if seg != '']
            local_thetas = [thetas[morph][seg] for morph, seg in zip(morphs, segments)]
            # np.prod: the np.product alias was deprecated in NumPy 1.25 and removed in 2.0.
            conjoined.setdefault(form, {}).setdefault(sr, {})[ur] = prob * np.prod(local_thetas)
        return conjoined
    
    def hidden_given_observation(self, conjoined):
        ''' Turn joint probabilities into posteriors P(UR|SR), i.e. the "responsibility"
            of each UR for deriving an SR.

            Parameters
            ----------
                conjoined: a dictionary {form: {sr: {ur: joint probability}}}

            Returns
            -------
                posterior: a dictionary {form: {sr: {ur: P(UR|SR)}}}; when all joints for
                           an SR sum to 0, each of its URs gets 0
        '''
        posterior = {}
        for form, realizations in conjoined.items():
            by_sr = {}
            for sr, ur_joint in realizations.items():
                total = sum(ur_joint.values())
                by_sr[sr] = {ur: (0 if total == 0 else joint / total)
                             for ur, joint in ur_joint.items()}
            posterior[form] = by_sr
        return posterior
    

    def expected_values(self, posterior, obs, data):
        ''' Split each observation among the candidate URs to get expected counts.

            The expected count of (UR, SR, form) is the observed frequency of
            (SR, form) times the posterior P(UR|SR).

            Parameters
            ----------
                 posterior: a dictionary {form: {sr: {ur: posterior probability}}}
                 obs: a dictionary of observations {form: {sr: observed frequency}}
                 data: a dictionary {form: {ur: {sr: violation}}}; fixes the output order

            Returns
            -------
                eV: a list of tuples ((ur, sr, form), expected count of the (ur, sr) pair)
        '''
        scaled = {}
        for form, realizations in posterior.items():
            per_sr = scaled.setdefault(form, {})
            for sr, ur_pos in realizations.items():
                freq = obs[form][sr]
                per_sr[sr] = {ur: freq * pos_value for ur, pos_value in ur_pos.items()}

        return [
            ((curr_ur, sr, form), scaled[form][sr][curr_ur])
            for form, urs in data.items()
            for curr_ur, srs_viols in urs.items()
            for sr in srs_viols
        ]


    def E_step(self, weights, data, obs, thetas):

        ''' This function puts the previous intermediate steps altogether. 
            Parameters
            ----------
                weights: A list of constraing weights
                data: a dictionary which stores the ur, sr, and the violation for each (ur, sr) pair 
                obs: a dictionary of observations, {form: {sr:observed frequency pair}}    
                thetas: a dictionary which stores {morph:{ur:P(ur|morph)}} 

            Returns
            ------- 
                a tuple of (pos, ev)
                posterior: a dictionary which stores {form:{sr:ur:{P(UR|SR)}}}
                eV: a list of tuples ((ur, sr, form), expected counts of (ur, sr) pair)
        '''
        prior_G = self.sr_given_ur(weights, data)
        conjoined = self.hidden_and_observation(prior_G, thetas)
        pos = self.hidden_given_observation(conjoined)
        ev = self.expected_values(pos, obs, data)
        return (pos,ev)


    def loglikelihood(self, weights, data, eV):
        ''' This function calculates the log-likelihood of 'filled-in' data. 
            Parameters
            ----------
                 weights: a list of constraint weights
                 data : a dictionary {form:{ur: {sr: violation}}}
                 eV: a list of tuples ((ur, sr, form), expected counts of (ur, sr) pair)

            Returns
            ------- 
                the loglikelighood of the 'filled-in' data
        '''
        gopt = self.sr_given_ur(weights, data)
        if [gopt[i][0] == eV[i][0] for i in range(len(gopt))]: #sanity checking: are the orders of the two lists match up
            le_probs = [y for (x, y) in gopt] #get predicted probability
            log_prob = np.log(np.array(le_probs)) #log of predicted probability
            efreq = np.array([y for (x, y) in eV]) #get expected counts
            return np.dot(log_prob, efreq) 
        else: #if fail to match up, re-order
            opt = [] 
            for (x, y) in eV:
                for (x1, y1) in gopt:
                    if x1==x:
                        opt.append(y)
                        break
            return np.dot(np.log(np.array(opt), np.array([y for y in eV])))


    def penalty(self, weights):
        ''' L2 (Gaussian prior) penalty on the constraint weights.

            Returns LAM / (2 * SIGMASQUARE) * ||weights - MU||^2, using the
            module-level constants LAM, SIGMASQUARE and MU.
        '''
        scale = LAM / (2 * SIGMASQUARE)
        deviation = weights - [MU] * len(weights)
        return scale * np.linalg.norm(deviation) ** 2

    def constraint_weight(self, initial_weights, data, eVals, CONS, obs):
        ''' Re-estimate the constraint weights on the current 'filled-in' data ((18) in the main text).

            Minimizes the negative expected-count log-likelihood, plus the L2
            penalty when L2_PRIOR is set, using L-BFGS-B. Each weight is bounded
            by (LOWER_BOUND, UPPER_BOUND). Every round starts from a vector of 1s,
            not from ``initial_weights``.

            Parameters
            ----------
                initial_weights: current weights; only its length is used (one bound per weight)
                data: a dictionary {form: {ur: {sr: violation vector}}}
                eVals: a list of tuples ((ur, sr, form), expected count)
                CONS: constraint names; its length sets the size of the starting point
                obs: not used here

            Returns
            -------
                A new set of constraint weights that maximize the current 'filled-in' data (expected counts).
        '''
        if L2_PRIOR:
            def objective(w):
                return self.penalty(w) - self.loglikelihood(w, data, eVals)
        else:
            def objective(w):
                return -self.loglikelihood(w, data, eVals)

        starting_point = np.full((len(CONS)), 1)
        weight_bounds = [(LOWER_BOUND, UPPER_BOUND) for _ in initial_weights]
        result = minimize(objective, starting_point, method='L-BFGS-B', bounds=weight_bounds)
        return result.x



    def M_step(self, eVals, thetas):
        ''' Re-estimate the UR probabilities from the current expected counts ((19) in the main text).

            For each morpheme, P(ur|morph) is the expected count attributed to
            ``ur`` at that morpheme's position, divided by the total expected
            count of all forms containing the morpheme (0 if that total is 0).

            Parameters
            ----------
                 eVals: a list of tuples ((ur, sr, form), expected count of the (ur, sr) pair)
                 thetas: a dictionary {morph: {ur: P(ur|morph)}}; updated in place

            Returns
            -------
                thetas: the updated UR probabilities {morph: {ur: P(ur|morph)}}
                storage: the same values as a list of [morph, ur, P] entries (for plotting)
        '''
        # Parse every expected-count entry once: the morphemes of its form,
        # the UR segment for each position, and the expected count itself.
        parsed = [
            (re.split('_', key[2]), [seg for seg in re.split('/|-', key[0]) if seg != ''], count)
            for key, count in eVals
        ]

        storage = []
        for morph, urs in thetas.items():
            denom = np.sum([count for morphs, _, count in parsed if morph in morphs])
            for ur in urs:
                matched = [
                    count
                    for morphs, segments, count in parsed
                    for pos, m in enumerate(morphs)
                    if m == morph and segments[pos] == ur.split('/')[1]
                ]
                new_value = 0 if denom == 0 else np.sum(matched) / denom
                thetas[morph][ur] = new_value
                storage.append([morph, ur, new_value])
        return (thetas, storage)



    #########################################################
    #########################################################
    #########################################################
    ################## The incomplete data ##################
    #########################################################
    #########################################################
    #########################################################


    def incomplete_data_log_likelihood(self, obs, thetas, weights, data):
        ''' This function calculates the log-likelihood of the 'incomplete' data. See (16) in the main text.
        '''
        opt = self.sr_given_ur(weights, data)
        conjoined = self.hidden_and_observation(opt, thetas)
        likelihood = 0
        for form, srs_freq in obs.items():
            for sr, freq in srs_freq.items():
                if freq != 0:
                    likelihood += freq*np.log(sum(conjoined[form][sr].values()))
        return likelihood

    def likelihood(self, obs, thetas, weights, data):
        ''' This function calculates the likelihood of the 'incomplete' data. See (15) in the main text. 
        '''
        opt = self.sr_given_ur(weights, data)
        conjoined = self.hidden_and_observation(opt, thetas)
        likelihood = 1
        for form, srs_freq in obs.items():
            for sr, freq in srs_freq.items():
                if freq != 0:
                    likelihood *= (sum(conjoined[form][sr].values()))**freq
        return likelihood


    ###############################################################
    ###############################################################
    ###############################################################
    ############Functions for plotting and output files############
    ###############################################################
    ###############################################################
    ###############################################################      

    def loglikelihood_plotting(self, loglikelihood, tokens):
        ''' Plot the incomplete-data log-likelihood at each iteration and save it
            to ``tokens + 'loglikelihood.png'``.
        '''
        plt.figure()
        plt.plot(list(range(len(loglikelihood))), loglikelihood, label = "Loglike")
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.xlabel('Iterations')
        plt.ylabel('loglikelihood')
        plt.tight_layout()
        plt.savefig(tokens + 'loglikelihood.png', dpi=400)
        plt.close()

    def constraint_plotting(self, CONS, cache_weights, tokens):
        ''' Plot each constraint weight across iterations (one line per constraint,
            labelled by CONS) and save it to ``tokens + 'Constraints.png'``.
        '''
        steps = list(range(len(cache_weights)))
        plt.figure(0)
        for idx in range(len(cache_weights[0])):
            trajectory = [snapshot[idx] for snapshot in cache_weights]
            plt.plot(steps, trajectory, label=CONS[idx])
        plt.locator_params(axis="both", integer=True, tight=True)
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.xlabel('Iterations')
        plt.ylabel('constraint weights')
        plt.tight_layout()
        plt.savefig(tokens + 'Constraints.png', dpi=400)
        plt.close()


    def parameter_plotting(self, cache_param, tokens):
        ''' Plot the learning trajectory of every UR candidate's probability.

            One figure per morpheme is saved as
            ``tokens + 'URProbability_' + morph + '.png'``.

            Parameters
            ----------
                cache_param: a list of snapshots, one per iteration. Each snapshot is a
                    list of [morph, ur, P(ur|morph)] entries. Assumes every snapshot lists
                    the same (morph, ur) pairs in the same order, as M_step produces them
                    by iterating thetas.
                tokens: the output path prefix

            ``cache_param`` is left unmodified; the caller writes it to other outputs.
        '''
        if not cache_param:
            return
        iterations = list(range(len(cache_param)))

        # Group the UR candidates by morpheme (first-seen order) and pull each
        # candidate's value out of every snapshot by position.
        by_morph = {}
        for idx, entry in enumerate(cache_param[0]):
            morph, ur = entry[0], entry[1]
            trajectory = [snapshot[idx][2] for snapshot in cache_param]
            by_morph.setdefault(morph, []).append((ur, trajectory))

        for morph, candidates in by_morph.items():
            # A fresh figure per morpheme avoids drawing onto a numbered figure
            # that is already open elsewhere.
            fig = plt.figure()
            for ur, trajectory in candidates:
                plt.plot(iterations, trajectory, label=ur.replace('+', '').replace('.', '') + ' ' + morph)
            plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
            plt.xlabel('Iterations')
            plt.ylabel('UR probability')
            plt.tight_layout()
            plt.savefig(tokens + 'URProbability_' + morph + '.png', dpi=400)
            plt.close(fig)


    def UR_distribution_output(self, thetas, gold, tokens):
        ''' This function outputs the final learned UR probabilities.   
        '''
        curr_form = ''
        frames = [['Morpheme','UR Candidate', 'Linguist-preferred','P(UR)']]
        for i in thetas:
            if i[0] != curr_form:
                curr_form = i[0]
                frames += [['', '', '', '']]
                if i[1].replace('.', '') == gold[i[0]]:
                    frames += [[curr_form, i[1], 1, i[2]]]
                else:
                    frames += [[curr_form, i[1], '', i[2]]]
            else:
                if i[1].replace('.', '') == gold[i[0]]:
                    frames += [['', i[1], 1, i[2]]]
                else:
                    frames+= [['', i[1], '', i[2]]]
        file_name=tokens +'FinalURProbabilities.csv'
        frames = pd.DataFrame(frames[1:], columns=frames[0])
        frames.to_csv(file_name, index = False)
 
    
    
    def spreadsheet_opt(self, weights, ur_sr, viol, thetas, CONS, obs, tokens, data):
        ''' This function outputs the final Tableau with winning URs only (P(UR) > 0.01)   
        '''
        predicted = [i[1] for i in self.sr_given_ur(weights, data)]
        final_data = list(zip(list(zip(ur_sr, viol)), predicted))
        poss = []
        for form, ur_sr in data.items():
            form = form.split('_')
            possible_urs = []
            for i in form:
                possible_urs.append([j for j in thetas[i].keys() if thetas[i][j]>0.01])

            possible_urs = list(itertools.product(*possible_urs))

            new = []
            for j in possible_urs:
                new.append([(i.split('/')[1]) for i in j])
            poss+=[['_'.join(form), '/'+'-'.join(i)+'/'] for i in new]

        urs_probs = {}
        for i in poss:
            forms = i[0].split('_')
            urs = [j for j in re.split('/|-', i[1]) if j != '']
            prob = 1
            for pair in list(zip(forms, urs)):
                prob *= thetas[pair[0]]['/'+pair[1]+'/']
            urs_probs[tuple(i)] = prob

        final_data = [i for i in final_data if [i[0][0][2], i[0][0][0]] in poss]

        re_organize = []
        re_organize.append(['Word','Observed Surface Frequency','UR','P(UR)', 'SR']+CONS+['P(SR|UR)', 'P(UR, SR)'])
        re_organize.append(['','','' ,'','']+weights.tolist()+['', ''])
        curr_form = ''
        for i in final_data:
            if i[0][0][2] != curr_form:
                curr_form = i[0][0][2]
                re_organize.append(['','','' ,'','']+['']*len(weights)+['', ''])
            form = i[0][0][2]
            freq = obs[i[0][0][2]][i[0][0][1]]
            ur = i[0][0][0]
            sr = i[0][0][1]
            ur_prob = urs_probs[(form, ur)]
            re_organize.append([form, freq, ur,ur_prob, sr]+i[0][1]+[i[1], i[1]*ur_prob])

        re_organize = pd.DataFrame(re_organize[1:], columns=re_organize[0])
        file_name=tokens +'Tableau_WinningURsOnly.csv'
        re_organize.to_csv(file_name, index=False)


    def file_output(self, CONS, cWeights, cache_weights, cache_param, tokens, ur_sr, viols, thetas, obs, data, loglike, gold):
        ''' Write every plot and output file for a finished run.

            Produces the constraint-weight, UR-probability and log-likelihood plots,
            the winning-UR tableau, the final UR-probability CSV and a plain-text
            ``history.txt``, all named with the prefix ``tokens``. The directory
            ``tokens`` is created first if it does not exist.
        '''
        import copy  # only needed here, to give the plotting step a private copy

        if not os.path.exists(tokens):
            os.makedirs(tokens)
        self.constraint_plotting(CONS, cache_weights, tokens)
        # parameter_plotting may rewrite entries of the list it receives; plot from a
        # deep copy so the UR-probability CSV and history.txt below see the real values.
        self.parameter_plotting(copy.deepcopy(cache_param), tokens)
        self.spreadsheet_opt(cWeights, ur_sr, viols, thetas, CONS, obs, tokens, data)
        self.loglikelihood_plotting(loglike, tokens)
        self.UR_distribution_output(cache_param[-1], gold, tokens)

        sections = (
            ("The constraint weights for each iteration:", cache_weights),
            ("Learnt underlying representation for each iteration:", cache_param),
            ("The log likelihood of the observed data:", loglike),
        )
        with open(tokens+'history.txt', 'w') as f:
            for n, (title, items) in enumerate(sections):
                if n:
                    f.write("\n")  # blank line between sections
                f.write(title + "\n")
                for item in items:
                    f.write("%s\n" % item)
        
        

    def update(self, feed, tokens, weights_initialization, gold):
        ''' Run the full EM search until convergence or until MAXIterations updates, then write all outputs.

        Parameters
        ----------
        feed: a tuple (CONS, viols, ur_sr, obs, (thetas, storage), data):
            CONS    - the constraint names
            viols   - per-candidate violation vectors (passed on to the tableau output)
            ur_sr   - per-candidate (UR, SR, FORM) keys (passed on to the tableau output)
            obs     - the observed frequencies {form: {sr: frequency}}
            thetas  - the initial UR probabilities {morph: {ur: P(ur|morph)}}
            storage - the same probabilities as a list of [morph, ur, P] (first plotting snapshot)
            data    - {form: {ur: {sr: violation vector}}}
        tokens: the identifier for this simulation, used as the prefix of every output file
        weights_initialization: initial constraint weights; a single value used for every
            constraint, or a sequence (list, tuple or array) with one weight per constraint
        gold: a dictionary of gold URs for each observed morpheme

        Returns
        -------
        (thetas, cWeights): the learned UR probabilities and constraint weights
        '''
        start2 = time.time()
        (CONS, viols, ur_sr, obs, (thetas, storage), data) = feed
        # np.full broadcasts its fill value, so this accepts a scalar or a per-constraint
        # sequence and always yields an ndarray (spreadsheet_opt later calls .tolist()).
        cWeights = np.full((len(CONS)), weights_initialization)

        cache_weights = [cWeights]
        cache_param = [storage]
        loglike = [self.incomplete_data_log_likelihood(obs, thetas, cWeights, data)]
        # There is no previous value yet, so the first convergence test uses |log-likelihood| itself.
        diff = abs(loglike[-1])

        converged = False
        for _ in range(MAXIterations):
            if diff < CON_EPSILON:
                # stop updating once the change in log-likelihood is small enough
                converged = True
                break
            _, ev = self.E_step(cWeights, data, obs, thetas)
            start = time.time()
            cWeights = self.constraint_weight(cWeights, data, ev, CONS, obs)
            end = time.time()
            print('Time elapsed for the constraint weighting:'+str(end - start)+'s')
            if ECM:
                # recompute expected counts under the new weights before re-estimating UR probabilities
                _, ev = self.E_step(cWeights, data, obs, thetas)
            (thetas, storage) = self.M_step(ev, thetas)
            loglike.append(self.incomplete_data_log_likelihood(obs, thetas, cWeights, data))
            diff = abs(loglike[-1] - loglike[-2])
            cache_param.append(storage)
            cache_weights.append(cWeights)

        self.file_output(CONS, cWeights, cache_weights, cache_param, tokens, ur_sr, viols, thetas, obs, data, loglike, gold)
        end2 = time.time()
        if converged:
            print('Time elapsed for the learning:'+str(end2 - start2)+'s')
        else:
            print('Time elapsed for the learning:'+str(end2 - start2)+'s \n')
        return (thetas, cWeights)
    
    
    def wug_test(self, new_paradigm, weights):
        ''' Re-estimate UR probabilities for a new paradigm while holding the constraint weights fixed.

        Runs E-step / M-step iterations (no constraint re-weighting) until the change in
        incomplete-data log-likelihood drops below CON_EPSILON, or MAXIterations iterations have run.

        Parameters
        ----------
        new_paradigm: a tuple (CONS, viols, ur_sr, obs, (thetas, storage), data), where CONS are
            the constraint names, obs the observed frequencies {form: {sr: frequency}}, thetas the
            starting UR probabilities {morph: {ur: P(ur|morph)}} and data {form: {ur: {sr: violations}}}
        weights: fixed constraint weights; a single value for every constraint or one per constraint

        Returns
        -------
        thetas: the re-estimated UR probabilities {morph: {ur: P(ur|morph)}}
        '''
        (CONS, _viols, _ur_sr, obs, (thetas, _storage), data) = new_paradigm
        # np.full broadcasts its fill value: accepts a scalar or a per-constraint sequence.
        cWeights = np.full((len(CONS)), weights)

        previous = self.incomplete_data_log_likelihood(obs, thetas, cWeights, data)
        # There is no previous value yet, so the first convergence test uses |log-likelihood| itself.
        diff = abs(previous)
        for _ in range(MAXIterations):
            if diff < CON_EPSILON:
                break
            _, ev = self.E_step(cWeights, data, obs, thetas)
            thetas, _ = self.M_step(ev, thetas)
            current = self.incomplete_data_log_likelihood(obs, thetas, cWeights, data)
            diff = abs(current - previous)
            previous = current
        return thetas
