package org.simantics.scl.compiler.parser.regexp.automata;

import gnu.trove.impl.Constants;
import gnu.trove.list.array.TByteArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.procedure.TIntIntProcedure;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.hash.TIntHashSet;

import java.util.ArrayList;

import org.simantics.scl.compiler.parser.regexp.RAtom;
import org.simantics.scl.compiler.parser.regexp.Regexp;

/**
 * A deterministic finite automaton over integer symbols.
 * <p>
 * States are numbered consecutively from 0 in the order they are created by
 * {@link #newState()}. Each state owns a transition map from symbol to target
 * state. A symbol without an entry has no transition: {@link #getTransition(int, int)}
 * reports it as {@code -1}, and {@link #minimize()} routes it to an implicit
 * failure state.
 */
public class DFA implements Automata {

    /** No-entry value of every transition map: "this symbol has no transition". */
    private static final int NO_TRANSITION = -1;

    /** Symbol inserted by {@link #toPositionalRegexp(int)} to mark a position. */
    private static final int POSITION_MARKER = 0x80000000;

    private ArrayList<TIntIntHashMap> transitions = new ArrayList<TIntIntHashMap>(); 
    private TByteArrayList accepts = new TByteArrayList();
    private int initialState;

    /**
     * Creates an empty transition map whose {@code get} returns {@link #NO_TRANSITION}
     * for symbols without an entry.
     */
    private static TIntIntHashMap newTransitionMap() {
        return new TIntIntHashMap(
                Constants.DEFAULT_CAPACITY,
                Constants.DEFAULT_LOAD_FACTOR,
                0, NO_TRANSITION);
    }

    /**
     * Adds a new non-accepting state without transitions.
     *
     * @return the id of the new state (equal to the previous {@link #size()})
     */
    public int newState() {
        int stateId = transitions.size();
        transitions.add(newTransitionMap());
        accepts.add((byte)0);
        return stateId;
    }

    /** @return the number of states */
    public int size() {
        return transitions.size();
    }

    /**
     * @return an independent copy of this automaton: transition maps and the
     *         accepting flags are copied, not shared
     */
    public DFA copy() {
        DFA copy = new DFA();
        for(TIntIntHashMap t : transitions)
            copy.transitions.add(new TIntIntHashMap(t));
        copy.accepts = new TByteArrayList(accepts);
        copy.initialState = initialState;
        return copy;
    }

    /** Sets (or replaces) the transition of {@code sourceId} on {@code symbol}. */
    public void addTransition(int sourceId, int symbol, int targetId) {
        transitions.get(sourceId).put(symbol, targetId);
    }

    /**
     * @return the target of {@code sourceId} on {@code symbol}, or {@code -1} if there
     *         is none (for maps created by this class)
     */
    public int getTransition(int sourceId, int symbol) {
        return transitions.get(sourceId).get(symbol);
    }

    /** Calls {@code proc.execute(symbol, target)} for every outgoing transition of {@code source}. */
    public void forEachTransition(int source, final TIntIntProcedure proc) {        
        transitions.get(source).forEachEntry(proc);
    }

    /**
     * Returns the keys of the transition map of state {@code id}.
     * NOTE(review): these are the input symbols on which {@code id} has a transition,
     * not target state ids, despite the method name; confirm against callers.
     */
    public int[] nextStates(int id) {
        return transitions.get(id).keys();
    }

    /** Marks state {@code id} as accepting or non-accepting. */
    public void setAccepts(int id, boolean accepts) {
        this.accepts.set(id, accepts ? (byte)1 : (byte)0);
    }

    /** @return true if state {@code id} is accepting */
    public boolean getAccepts(int id) {
        return accepts.get(id)==1;
    }

    public void setInitialState(int initialState) {
        this.initialState = initialState;
    }

    public int getInitialState() {
        return initialState;
    }

    /**
     * Returns a minimal DFA equivalent to this one, using Hopcroft's partition refinement.
     * <p>
     * Missing transitions are treated as transitions to an implicit failure state.
     * That state loops to itself on every symbol. States equivalent to it (states that
     * can never reach an accepting state) are dropped, and transitions into them are
     * omitted. If the initial state itself can never accept, the result consists of a
     * single non-accepting state without transitions.
     */
    public DFA minimize() {
        // Compute the input symbols that occur on some transition; only these can
        // distinguish states. symbolMap maps each symbol to its dense index.
        final TIntIntHashMap symbolMap = new TIntIntHashMap();
        final int[] symbolArray;
        {
            final TIntArrayList l = new TIntArrayList();
            TIntProcedure proc = new TIntProcedure() {
                @Override
                public boolean execute(int value) {
                    if(!symbolMap.containsKey(value)) {
                        symbolMap.put(value, l.size());
                        l.add(value);
                    }
                    return true;
                }
            };
            for(TIntIntHashMap tMap : transitions)
                tMap.forEachKey(proc);
            symbolArray = l.toArray();
        }
        final int symbolCount = symbolArray.length;
        final int stateCount = transitions.size();
        // Index of the implicit failure state that absorbs all missing transitions.
        final int failureState = stateCount;

        // Inverse automaton: inverse[t][j] lists the states that move to t on symbolArray[j].
        // The failure state is included as a source with a self-loop on every symbol,
        // so the transition function is total. Without these self-loops, dead states
        // get split off from the failure state and the result is not minimal.
        final TIntArrayList[][] inverse = new TIntArrayList[stateCount+1][symbolCount];
        for(int sourceId=0;sourceId<=stateCount;++sourceId) {
            TIntIntHashMap tMap = sourceId < stateCount ? transitions.get(sourceId) : null;
            for(int j=0;j<symbolCount;++j) {
                int targetId = tMap == null ? NO_TRANSITION : tMap.get(symbolArray[j]);
                if(targetId < 0)
                    targetId = failureState;
                TIntArrayList l = inverse[targetId][j];
                if(l == null) {
                    l = new TIntArrayList();
                    inverse[targetId][j] = l;
                }
                l.add(sourceId);
            }
        }

        // Partition bookkeeping: partition p consists of
        // ids[partitionBegin[p] .. partitionEnd[p]), and memPartition[s] is the
        // partition containing state s (including the failure state).
        int[] ids = new int[stateCount+1];
        final int[] memPartition = new int[stateCount+1];
        TIntArrayList partitionBegin = new TIntArrayList();
        TIntArrayList partitionEnd = new TIntArrayList();
        TIntArrayList stack = new TIntArrayList();      // partitions pending use as splitters
        TIntArrayList scheduled = new TIntArrayList();  // scheduled[p]==1 iff p is on the stack

        // Initial partition: 0 = non-accepting states plus the failure state,
        // 1 = accepting states (may be empty).
        {
            int min = 0;
            int max = stateCount;
            ids[min++] = failureState;
            memPartition[failureState] = 0;
            for(int i=0;i<stateCount;++i) {
                if(accepts.get(i)==1) {
                    ids[max--] = i;
                    memPartition[i] = 1;
                }
                else {
                    ids[min++] = i;
                    memPartition[i] = 0;
                }
            }
            partitionBegin.add(0);
            partitionBegin.add(min);
            partitionEnd.add(min);
            partitionEnd.add(ids.length);
            scheduled.add(0);
            scheduled.add(0);
            // Scheduling only the smaller initial block suffices because the
            // transition function (with the failure self-loops) is total.
            int first = min < ids.length/2 ? 0 : 1;
            stack.add(first);
            scheduled.set(first, 1);
        }

        // Refinement
        while(!stack.isEmpty()) {
            int splitterId = stack.removeAt(stack.size()-1);
            scheduled.set(splitterId, 0);
            int begin = partitionBegin.get(splitterId);
            int end = partitionEnd.get(splitterId);

            for(int j=0;j<symbolCount;++j) {
                // States that move into the splitter on symbol j.
                TIntHashSet invStates = new TIntHashSet();
                for(int i=begin;i<end;++i) {
                    TIntArrayList inv = inverse[ids[i]][j];
                    if(inv != null)
                        invStates.addAll(inv);
                }
                int[] invStatesArray = invStates.toArray();

                // Only partitions intersecting the preimage can be split.
                TIntHashSet touched = new TIntHashSet();
                for(int s : invStatesArray)
                    touched.add(memPartition[s]);

                for(int p : touched.toArray()) {
                    int pBegin = partitionBegin.get(p);
                    int pEnd = partitionEnd.get(p);
                    boolean splits = false;
                    for(int k=pBegin;k<pEnd;++k)
                        if(!invStates.contains(ids[k])) {
                            splits = true;
                            break;
                        }
                    if(!splits)
                        continue;

                    // Move the states of the preimage to the tail of p's range;
                    // the tail becomes the new partition p2.
                    int p2 = partitionBegin.size();
                    int p2End = pEnd;
                    for(int k=pBegin;k<pEnd;++k) {
                        int s = ids[k];
                        if(invStates.contains(s)) {
                            memPartition[s] = p2;
                            --pEnd;
                            ids[k] = ids[pEnd];
                            ids[pEnd] = s;
                            --k; // re-examine the state swapped into position k
                        }
                    }
                    partitionEnd.set(p, pEnd);
                    partitionBegin.add(pEnd);
                    partitionEnd.add(p2End);

                    if(scheduled.get(p) == 1) {
                        // p is still pending, so both halves must be processed.
                        scheduled.add(1);
                        stack.add(p2);
                    }
                    else if(pEnd - pBegin <= p2End - pEnd) {
                        // Otherwise scheduling the smaller half is enough.
                        stack.add(p);
                        scheduled.add(0);
                        scheduled.set(p, 1);
                    }
                    else {
                        stack.add(p2);
                        scheduled.add(1);
                    }
                }
            }
        }

        // Build the quotient automaton. The failure partition becomes "no transition".
        // An empty partition (the accepting block when nothing accepts) has no state
        // to represent it.
        final DFA aut = new DFA();
        final int partitionCount = partitionBegin.size();
        final int failurePartition = memPartition[failureState];
        final int[] sArray = new int[partitionCount];
        for(int i=0;i<partitionCount;++i) {
            if(i == failurePartition || partitionBegin.get(i) == partitionEnd.get(i))
                sArray[i] = NO_TRANSITION;
            else
                sArray[i] = aut.newState();
        }
        for(int i=0;i<partitionCount;++i) {
            final int sourceId = sArray[i];
            if(sourceId < 0)
                continue;
            // All states of a partition are equivalent, so any one can represent it.
            int bId = ids[partitionBegin.get(i)];
            forEachTransition(bId, new TIntIntProcedure() {
                @Override
                public boolean execute(int a, int b) {
                    int targetId = b < 0 ? NO_TRANSITION : sArray[memPartition[b]];
                    // Transitions into the failure partition are omitted, not stored as -1.
                    if(targetId >= 0)
                        aut.addTransition(sourceId, a, targetId);
                    return true;
                }
            });
            aut.setAccepts(sourceId, getAccepts(bId));
        }
        int newInitial = sArray[memPartition[initialState]];
        if(newInitial < 0) {
            // The initial state can never accept: represent the empty language
            // by a single non-accepting state without transitions.
            newInitial = aut.newState();
        }
        aut.setInitialState(newInitial);
        return aut;
    }

    /**
     * Builds a regular expression for the words leading from state {@code from}
     * to any state {@code s} with {@code to[s] == 1}.
     * <p>
     * Each state i has an unknown X_i with X_i = sum_j a[i][j] X_j + b[i]. Unknowns
     * are eliminated from the last to the first using Arden's rule
     * (X = A X + B implies X = A* B). State {@code from} is swapped to index 0 so that
     * its expression is the one left at the end.
     *
     * @param from start state
     * @param to   per-state target flags (1 = target), indexed by state id
     * @return the (unsimplified) expression
     */
    public Regexp toRegexp(int from, byte[] to) {
        int stateCount = size();
        // order is a transposition of 0 and from, hence its own inverse.
        final int[] order = new int[stateCount];
        for(int i=0;i<stateCount;++i)
            order[i] = i;
        order[0] = from;
        order[from] = 0;

        final Regexp[][] a = new Regexp[stateCount][stateCount];
        final Regexp[] b = new Regexp[stateCount];
        for(int i=0;i<stateCount;++i) {
            b[i] = to[order[i]]==1 ? Regexp.ONE : Regexp.ZERO;
            for(int j=0;j<stateCount;++j)
                a[i][j] = Regexp.ZERO;
            final Regexp[] row = a[i];
            forEachTransition(order[i], new TIntIntProcedure() {
                @Override
                public boolean execute(int symbol, int targetId) {
                    // Negative targets mean "no transition" and contribute nothing.
                    if(targetId < 0)
                        return true;
                    targetId = order[targetId];
                    row[targetId] = Regexp.or(row[targetId], new RAtom(symbol));
                    return true;
                }
            });
        }

        for(int n=stateCount-1;n>=0;--n) {
            // Arden's rule on row n, then substitute X_n into the earlier rows.
            Regexp ss = Regexp.star(a[n][n]);
            b[n] = Regexp.seq(ss, b[n]);
            for(int j=0;j<stateCount;++j)
                a[n][j] = Regexp.seq(ss, a[n][j]);
            for(int i=0;i<n;++i) {
                b[i] = Regexp.or(b[i], Regexp.seq(a[i][n], b[n]));
                for(int j=0;j<n;++j)
                    a[i][j] = Regexp.or(a[i][j], Regexp.seq(a[i][n], a[n][j]));
            }
        }

        return b[0];
    }

    /** @return a simplified regular expression for the language accepted by this automaton */
    public Regexp toRegexp() {
        return toRegexp(initialState, accepts.toArray()).simplify();
    }

    /** @return a simplified expression for the words leading from the initial state to {@code position} */
    public Regexp toRegexpTo(int position) {
        int stateCount = size();
        byte[] targetArray = new byte[stateCount];
        targetArray[position] = 1;
        return toRegexp(initialState, targetArray).simplify();
    }

    /** @return a simplified expression for the words accepted when starting from {@code position} */
    public Regexp toRegexpFrom(int position) {
        return toRegexp(position, accepts.toArray()).simplify();
    }

    /**
     * Returns a regular expression for the accepted language in which every visit of
     * state {@code position} is marked with the symbol {@code 0x80000000}.
     * <p>
     * In a copy of this automaton, the outgoing transitions and the accepting flag of
     * {@code position} are moved to a fresh state. {@code position} is left with a single
     * marker transition to that state.
     */
    public Regexp toPositionalRegexp(int position) {
        DFA aut = copy();
        int s = transitions.size();
        aut.transitions.add(aut.transitions.get(position));
        // Use the standard map so that missing symbols read as -1, as for every other state.
        aut.transitions.set(position, newTransitionMap());
        aut.addTransition(position, POSITION_MARKER, s);
        aut.accepts.add(accepts.get(position));
        aut.accepts.set(position, (byte)0);
        return aut.toRegexp();
    }
}
