import re
import pandas as pd
import itertools
from interpolation import Interpolation
from lang import Lang

    
class GEN():
    '''Build candidate underlying forms (URs) and surface forms (SRs) for each morpheme of a language.

    The ``urgen`` attribute selects how ``create_urs`` proposes URs:
    'KK-C' uses the observed allomorphs unchanged, 'KK-D' cobbles segments
    slot by slot, 'KK-E' additionally adds phonetically intermediate segments,
    and 'KK-Z' freely substitutes observed alternants.
    '''

    def __init__(self, lang=None, urgen='', lang_name=''):
        '''
        Parameters
        ----------
        lang: the language object; a new ``Lang()`` is created when omitted
        urgen: the UR-generation mode ('KK-C', 'KK-D', 'KK-E' or 'KK-Z')
        lang_name: a prefix used for the name of the output file
        '''
        # The default is built here rather than in the signature: a default written
        # as ``Lang()`` is evaluated once, when the class is defined, and would then
        # be shared by every GEN created without an explicit ``lang``.
        self.lang = Lang() if lang is None else lang
        self.urgen = urgen
        self.lang_name = lang_name

    @staticmethod
    def _unique(items):
        '''Return the distinct items of an iterable as a list, in first-seen order.'''
        return list(dict.fromkeys(items))

    @staticmethod
    def _strip_null(form):
        '''Remove the null placeholder '.' from a form string.'''
        # NOTE(review): ' . ' is deleted together with BOTH neighbouring spaces, so
        # 'a . b' becomes 'ab' rather than 'a b'. Kept as in the original comparison;
        # confirm against the format of the observed forms.
        return form.replace(' . ', '').replace('. ', '').replace(' .', '').replace('.', '')

    def _alternation_conservative(self):
        '''
        the conservative version of the alternation list is the list without '.' as a key
        this is to prevent GEN from inserting all possible segments detected to alternate with a null element.

        Returns
        -------
        alternation: a new dictionary (with new lists), so that later in-place edits
                     do not touch the dictionary returned by the language object
        '''
        alternation = self.lang.get_alternating_segments()
        return {seg: list(alternants) for seg, alternants in alternation.items() if seg != '.'}

    def _gen_segment(self, alternation_seg):
        ''' This function takes in a list of segments at the same slot within the surface forms
        and adds the phonetically intermediate segments to the list.
        Parameters
        ----------
        alternation_seg: a list of surface segments; it is extended IN PLACE


        Returns
        -------
        result: the same list, enriched with the phonetically intermediate segments
        '''
        # itertools.combinations copies its input up front, so appending to
        # alternation_seg inside the loop does not change the pairs visited.
        for first, second in itertools.combinations(alternation_seg, 2):
            if '.' in (first, second):
                continue  # nothing is interpolated against the null element
            intermediate = Interpolation.phonetically_intermediate(first, second, self.lang.feat_dict)
            for seg in intermediate:
                if seg not in alternation_seg:
                    alternation_seg.append(seg)
        return alternation_seg

    def alternation_conservative_interpolative(self):
        ''' This function enriches the conservative alternation list with phonetically intermediate segments.

        Returns
        -------
        alternation: a dictionary; keys are possible UR segments, values are lists of
                     their alternants (without duplicates)
        '''
        lang = self.lang
        alternation = self._alternation_conservative()
        # Iterate over a snapshot of the items: in 'KK-E' mode new keys are added
        # below, and adding keys to a dict while iterating over it raises RuntimeError.
        for seg, alternants in list(alternation.items()):
            # alternants is extended in place below, so the segments added here
            # are themselves visited by this loop.
            for alt in alternants:
                if alt == '.':
                    continue
                # find the phonetically intermediate segments between the target segment and its alternant
                added = Interpolation.phonetically_intermediate(seg, alt, lang.feat_dict)
                alternation[seg] += added
                if self.urgen == 'KK-E':
                    # if KK-E, those phonetically intermediate segments can be possible UR segments
                    for new_seg in added:
                        others = [seg] + [j for j in added if j != new_seg]
                        alternation.setdefault(new_seg, []).extend(others)

        for seg in alternation:
            alternation[seg] = self._unique(alternation[seg])  # remove duplicates
        return alternation

    def _cobble(self, srs):
        ''' This function generates cobbled forms based on an allomorph set.
        Parameters
        ----------
        srs: a list of allomorphs, each given as a list of segments


        Returns
        -------
        cobbled_form: a list of strings (segments separated by ' '), one for every
                      combination of the segments observed at each slot;
                      an empty list when srs is empty
        '''
        if not srs:
            return []
        # At each slot, gather all possible surface realizations
        slots = [self._unique(sr[i] for sr in srs) for i in range(len(srs[0]))]
        # Form all logical combinations and create the strings
        return [' '.join(combo) for combo in itertools.product(*slots)]

    def _feature_cobble(self, srs, feature):
        ''' This function generates featurally composite forms based on an allomorph set.
        Parameters
        ----------
        srs: a list of allomorphs, each given as a list of segments
        feature: not used by this function; _gen_segment reads self.lang.feat_dict


        Returns
        -------
        cobbled_form: a list of strings (segments separated by ' '), containing the
                      featurally cobbled forms; an empty list when srs is empty
        '''
        if not srs:
            return []
        slots = []
        for i in range(len(srs[0])):
            observed = self._unique(sr[i] for sr in srs)
            if len(observed) > 1:
                # several segments are possible at this slot: add the intermediate ones
                observed = self._gen_segment(observed)
            slots.append(observed)
        # Form all logical combinations and create the strings
        return [' '.join(combo) for combo in itertools.product(*slots)]

    def _free_substitution(self, srs, alternations):
        ''' This function generates free-substituted forms according to an alternation list.
        Parameters
        ----------
        srs: a list of strings; an allomorph set for a morpheme
        alternations: a dictionary of OBSERVED segmental alternations,
                      keys being possible UR segments; values as all possible surface alternants for a segment


        Returns
        -------
        new_form: a list of strings without duplicates, containing the free-substituted
                  forms; an empty list when srs is empty
        '''
        forms = [sr.split(' ') for sr in srs]
        if not forms:
            return []

        slots = []
        for i in range(len(forms[0])):
            observed = self._unique(form[i] for form in forms)
            # at each slot, an observed segment may be derived from any of its observed
            # alternants, so those alternants are candidates for the slot as well
            candidates = list(observed)
            for seg in observed:
                if seg in alternations:
                    candidates += [alt for alt in alternations[seg] if alt not in observed]
            slots.append(self._unique(candidates))

        return self._unique(' '.join(combo) for combo in itertools.product(*slots))

    def create_urs(self):
        ''' This function proposes the URs of every morpheme according to self.urgen.

        Returns
        -------
        ur_sr: a dictionary from morpheme to a list of UR strings; for 'KK-C' this is
               the allomorph dictionary of the language object itself

        Raises
        ------
        ValueError: when a morpheme has to be processed and self.urgen is not one of
                    'KK-C', 'KK-D', 'KK-E', 'KK-Z'
        '''
        lang = self.lang
        allomorph = lang.get_allomorph()
        if self.urgen == 'KK-C':
            return allomorph

        alternations = self._alternation_conservative() if self.urgen == 'KK-Z' else None
        ur_sr = {}
        for morph, surface in allomorph.items():
            if self.urgen == 'KK-D':
                ur = self._cobble([i.split(' ') for i in surface])
            elif self.urgen == 'KK-E':
                ur = self._feature_cobble([i.split(' ') for i in surface], lang.feat_dict)
            elif self.urgen == 'KK-Z':
                ur = self._free_substitution(surface, alternations)
            else:
                raise ValueError('unknown URGEN option: %r' % (self.urgen,))
            ur_sr[morph] = ur
        return ur_sr

    def _get_insertion_context(self, allomorph):
        ''' This function takes in a list of allomorphs and records the local context of possible insertion;
            distributions are trigram-based (left, right).
        Parameters
        ----------
        allomorph: a list of strings; an allomorph set for a morpheme


        Returns
        -------
        cxt: a dictionary; keys are (left, right) pairs of neighbouring segments, with
             '#' standing for a word boundary; values are lists of the segments that
             fill the slot where another allomorph has the null element '.'
        '''
        cxt = {}
        forms = [sr.split(' ') for sr in allomorph]

        for gapped in forms:  # for each allomorph
            last = len(gapped) - 1
            for k, seg in enumerate(gapped):  # for each position
                if seg != '.':
                    continue
                for other in forms:  # look at what the allomorphs have at this position
                    filler = other[k]
                    if filler == '.':
                        continue
                    # the neighbours are read from the allomorph that has the segment;
                    # '#' is used at the edges of the word
                    left = '#' if k == 0 else other[k - 1]
                    right = '#' if k == last else other[k + 1]
                    fillers = cxt.setdefault((left, right), [])
                    if filler not in fillers:
                        fillers.append(filler)
        return cxt

    def _get_context_all(self):
        ''' This function collects the insertion contexts of all morphemes.

        Returns
        -------
        cxt: a dictionary as returned by _get_insertion_context; when several
             morphemes share a context, their segment lists are merged
        '''
        cxt = {}
        for allomorphs in self.lang.get_allomorph().values():
            for context, fillers in self._get_insertion_context(allomorphs).items():
                # merge rather than dict.update, which would replace the segments
                # already recorded for this context by an earlier morpheme
                merged = cxt.setdefault(context, [])
                for filler in fillers:
                    if filler not in merged:
                        merged.append(filler)
        return cxt

    def _create_srs(self, ur, morph=None, alternations=None):
        ''' This function generates the candidate SRs of one UR.
        Parameters
        ----------
        ur: a string; the UR, segments separated by ' '
        morph: the morpheme (a key of the allomorph dictionary) the UR belongs to;
               needed only when the UR contains the null element '.'
        alternations: an alternation dictionary; computed with
                      alternation_conservative_interpolative() when omitted


        Returns
        -------
        srs: a list of SR strings without duplicates

        Raises
        ------
        ValueError: when the UR contains '.' and morph is not given
        '''
        if alternations is None:
            alternations = self.alternation_conservative_interpolative()
        # first, perform regular alternation substitution GEN
        srs = self._free_substitution([ur], alternations)
        if '.' not in ur:
            return srs
        if morph is None:
            raise ValueError('morph is required to generate SRs for a UR containing "."')

        ur_segs = ur.split(' ')
        srs = [sr.split(' ') for sr in srs]
        # split observed allomorphs into lists of segments
        allo_list = [allo.split(' ') for allo in self.lang.get_allomorph()[morph]]

        for idx, seg in enumerate(ur_segs):
            if seg == '.':
                # all segments observed at the slot of '.' in the other allomorphs
                idx_seg = [allo[idx] for allo in allo_list if allo != ur_segs]
                # the new forms are built from the current srs before being added to it
                srs += [sr[:idx] + [filler] + sr[idx + 1:] for sr in srs for filler in idx_seg]
        return self._unique(' '.join(sr) for sr in srs)

    def ur_sr_gen(self):
        ''' This function generates, for every morpheme, its URs and the candidate SRs of each UR.

        Returns
        -------
        ur_sr_complete: a dictionary from morpheme to a dictionary from UR string to a
                        list of SR strings; URs and SRs are padded with '.' where a
                        segment may be inserted
        '''
        ur_sr = self.create_urs()
        ur_sr_complete = {}
        cxt = self._get_context_all()
        # the alternation list does not depend on the UR, so it is built once
        alternations = self.alternation_conservative_interpolative()
        for morph, urs in ur_sr.items():
            ur_sr_complete[morph] = {}
            for ur in urs:
                srs = self._create_srs(ur, morph, alternations)
                srs = [i.split(' ') for i in srs]
                # split urs, srs and enrich boundaries
                ur = ur.split(' ')
                new_ur = ['#'] + ur + ['#']
                new_srs = [['#'] + i + ['#'] for i in srs]  # copies: not affected by the inserts below

                add = 0  # number of placeholders inserted so far
                for i in range(len(new_ur) - 1):  # at each position
                    # if there's one generated sr that satisfies the epenthesis context
                    if any((j[i], j[i + 1]) in cxt for j in new_srs):
                        ur.insert(i + add, '.')  # first, enrich the UR with a placeholder
                        # srs grows inside this loop, so the forms added below are visited too
                        for k in srs:
                            new = ['#'] + k + ['#']
                            kc = list(k)  # copy of the sr before the placeholder is inserted

                            if (new[i + add], new[add + i + 1]) in cxt:  # examine whether to epenthesize
                                # enrich a placeholder to keep the length consistent with the ur
                                k.insert(i + add, '.')
                                # add one form per segment recorded for this context
                                srs += [kc[:i + add] + [seg] + kc[i + add:] for seg in cxt[(new[i + add], new[add + i + 1])]]
                            else:
                                # still need a placeholder to keep the length consistent with the ur
                                if len(k) < len(ur):
                                    k.insert(i + add, '.')
                        add += 1  # increment the position
                ur_sr_complete[morph][' '.join(ur)] = self._unique(' '.join(sr) for sr in srs)
        return ur_sr_complete

    def save_ur_sr_pairs(self):
        ''' This function writes every (UR, SR) candidate pair of every observation to
        '<lang_name>_<urgen>_FullInputTableau_NeedCons.csv'.

        Each row holds the stem, the morphemes joined by '_', 1 or 0 depending on
        whether the SR equals the observed form once '.' is removed, the UR as
        '/.../' and the SR as '[...]' (morphemes joined by '-').
        '''
        header = ['Stem', 'Word + MS', 'Observed Surface Frequency', 'UR', 'SR']
        final_data = [header]
        # each observation: (surface form split into morphemes, morpheme labels)
        obs = [[re.split(' - | -', i[0]),
                [j for j in i[1].replace('stem', '').split(' ') if j != '']]
               for i in self.lang.obs]
        ur_sr_complete = self.ur_sr_gen()
        for curr_obs, curr_form in obs:  # for each observation
            # dict keys and values are returned in corresponding order, and
            # itertools.product keeps that correspondence between the two products
            urs = itertools.product(*[ur_sr_complete[j].keys() for j in curr_form])
            srs = itertools.product(*[ur_sr_complete[j].values() for j in curr_form])
            # stems are indicated by an upper-case first letter; exactly one is expected
            [stem] = [j for j in curr_form if j[0].isupper()]
            word = '_'.join(curr_form)
            observed = [self._strip_null(k) for k in curr_obs]
            for ur, sr in zip(urs, srs):
                ur_cell = '/' + '-'.join(k.replace(' ', '') for k in ur) + '/'
                for i in itertools.product(*sr):  # combining SRs if polymorphemic
                    frequency = 1 if [self._strip_null(k) for k in i] == observed else 0
                    sr_cell = '[' + '-'.join(k.replace(' ', '') for k in i) + ']'
                    final_data.append([stem, word, frequency, ur_cell, sr_cell])

        # NOTE(review): the header row is passed both as the column names and as the
        # first data row, so it appears twice in the file. Kept as is because readers
        # of this file may rely on it; confirm whether that is intended.
        result = pd.DataFrame(final_data, columns=final_data[0])
        file_name = self.lang_name + '_' + self.urgen + '_FullInputTableau_NeedCons.csv'
        result.to_csv(file_name, index=False)




