package org.simantics.scl.compiler.parser.generator.table;

import java.util.ArrayList;
import java.util.Arrays;

import org.simantics.scl.compiler.parser.generator.grammar.AnaGrammar;
import org.simantics.scl.compiler.parser.generator.grammar.Prod;

import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.array.TLongArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TLongIntHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.procedure.TIntIntProcedure;
import gnu.trove.procedure.TIntObjectProcedure;
import gnu.trove.procedure.TObjectIntProcedure;
import gnu.trove.set.hash.THashSet;
import gnu.trove.set.hash.TIntHashSet;
import gnu.trove.set.hash.TLongHashSet;

/**
 * Generates a {@link ParseTable} for an {@link AnaGrammar}.
 *
 * <p>Symbols are encoded as ints: terminals are {@code >= 0}, nonterminals are stored as
 * the bitwise complement ({@code ~index}) of their nonterminal index and are therefore
 * negative. The last terminal ({@code terminalNames.length-1}) is treated as the
 * end-of-input symbol: a state with a transition on it is recorded in
 * {@link #finalStates} and receives {@link #ACCEPT_ACTION} for that terminal.
 *
 * <p>Action encoding used in the generated tables:
 * <ul>
 * <li>shift / goto: target state number (decoded with {@link #STATE_MASK}), OR'ed with
 *     the {@link #POP_MASK}/{@link #PUSH_MASK} stack-operation bits of the transition;</li>
 * <li>reduce: {@code REDUCE_MASK | production | (stackPos << 13)}, where
 *     {@code stackPos} is currently always 0;</li>
 * <li>{@link #ACCEPT_ACTION} and {@link #ERROR_ACTION} as special values.</li>
 * </ul>
 *
 * <p>The builder prints progress and diagnostics (transitions, state count, the full table,
 * conflicts and stack overflows) to {@code System.out} / {@code System.err}.
 */
public class ParseTableBuilder {
    /**
     * Largest {@code stackPos} an item may reach. A transition that would push an item past
     * this value is reported on {@code System.err} as a stack overflow and is not created.
     */
    public final static int MAX_STACK_ID = 10;
    
    /** Extracts the target state number from a shift action. */
    private static final int STATE_MASK = 0x0fff;
    /** Marks an action as a reduction; the low bits hold the production index. */
    private static final int REDUCE_MASK = 0x8000;
    /** Stack-operation flag: the transition decremented the items' {@code stackPos}. */
    private static final int POP_MASK = 0x4000;
    /** Stack-operation flag: the transition started from a closure item and incremented {@code stackPos}. */
    private static final int PUSH_MASK = 0x2000;
    /** Table entry meaning "no action"; pre-fills both the action and the goto tables. */
    public static final int ERROR_ACTION = 0xffff;
    /** Action stored for the end-of-input terminal in final states. */
    private static final int ACCEPT_ACTION = 0xfffe;

    /** Grammar the table is generated for. */
    final AnaGrammar grammar;
    /** Maps every closed item set to its state number. */
    private TObjectIntHashMap<ItemSet> states = new TObjectIntHashMap<ItemSet>();
    /** Item set of each state, indexed by state number. */
    private ArrayList<ItemSet> itemSets = new ArrayList<ItemSet>();
    /**
     * Per state: symbol -> action. Initially holds target states of shifts/gotos;
     * {@link #createReduceActions()} adds reduce and accept actions for terminals.
     */
    private ArrayList<TIntIntHashMap> transitions = new ArrayList<TIntIntHashMap>();
    /** Per state: symbol -> POP_MASK/PUSH_MASK bits of the corresponding transition. */
    private ArrayList<TIntIntHashMap> stackOps = new ArrayList<TIntIntHashMap>();
    /** Per state: symbol whose transition created the state ({@code REDUCE_MASK} for initial states). */
    private TIntArrayList backTransSymbols = new TIntArrayList();
    /** Per state: source states of incoming transitions (one entry per transition). */
    private ArrayList<TIntArrayList> backLinks = new ArrayList<TIntArrayList>();
    /** Initial state for each entry of {@code grammar.initialNonterminals}. */
    int[] initialStates;
    /** States that have a transition on the end-of-input terminal. */
    TIntHashSet finalStates = new TIntHashSet(); 

    private ParseTableBuilder(AnaGrammar grammar) {
        this.grammar = grammar;
    }
        
    /** Nonterminals are encoded as negative ints ({@code ~index}). */
    private static boolean isNonterminal(int symbol) {
        return symbol < 0;
    }
    
    /**
     * Closes {@code items} in place: for every nonterminal that can follow an item, adds an
     * item at the initial position of each of that nonterminal's productions, with
     * {@code stackPos == -1}. Items already present are not added twice.
     */
    private void close(ArrayList<Item> items) {
        THashSet<Item> itemSet = new THashSet<Item>(items);
        // items grows while iterating, so newly added items are closed as well.
        for(int i=0;i<items.size();++i) {
            Item item = items.get(i);
            for(int nextSymbol : item.nextSymbols(grammar))
                if(isNonterminal(nextSymbol)) {
                    nextSymbol = ~nextSymbol;
                    // Productions of a nonterminal occupy the range
                    // [nonterminalPos[n], nonterminalPos[n+1]) of grammar.prods.
                    int pEnd = grammar.nonterminalPos[nextSymbol+1];
                    for(int pId=grammar.nonterminalPos[nextSymbol];pId<pEnd;++pId) {
                        Item newItem = new Item(pId, grammar.prods.get(pId).rhs.getInitialState(), -1);
                        if(itemSet.add(newItem))
                            items.add(newItem);
                    }
                }                
        }
    }
    
    /**
     * Returns the state for the closure of {@code items}, creating it (and, recursively, all
     * states reachable from it) if it does not exist yet. Note that {@code items} is closed
     * in place.
     *
     * @param backTransSymbol symbol whose transition leads to the new state
     *        ({@code REDUCE_MASK} is passed for initial states)
     * @param items kernel items of the state
     * @return the state number
     */
    private int getState(int backTransSymbol, ArrayList<Item> items) {
        // Create state
        close(items);
        final ItemSet itemSet = new ItemSet(items);
        if(states.contains(itemSet))
            return states.get(itemSet);
        final int newState = states.size();
        states.put(itemSet, newState);
        itemSets.add(itemSet);
        backTransSymbols.add(backTransSymbol);
        backLinks.add(new TIntArrayList(2));
        
        // Create transitions: group the advanced items by the symbol they follow.
        TIntObjectHashMap<ArrayList<Item>> transitionMap = new TIntObjectHashMap<ArrayList<Item>>();
        //close(items);
        for(Item item : items) {
            for(int s : item.nextSymbols(grammar)) {
                ArrayList<Item> l = transitionMap.get(s);
                if(l == null) {
                    l = new ArrayList<Item>();
                    transitionMap.put(s, l);
                }
                l.add(new Item(item.production, item.nextPosition(grammar, s), item.stackPos));
            }
        }
        
        final TIntIntHashMap trans = new TIntIntHashMap();
        final TIntIntHashMap stackOpMap = new TIntIntHashMap();
        transitions.add(trans);
        stackOps.add(stackOpMap);
        // A transition on the end-of-input terminal is not followed; it marks the state as final.
        if(transitionMap.remove(grammar.terminalNames.length-1)!=null) {
            finalStates.add(newState);
        }
        transitionMap.forEachEntry(new TIntObjectProcedure<ArrayList<Item>>() {
            @Override
            public boolean execute(int a, ArrayList<Item> b) {
                // stackShift: some item was introduced by closure (stackPos == -1) in this state.
                boolean stackShift = false;
                int minStackPos = Integer.MAX_VALUE;
                for(Item item : b) {
                    if(item.stackPos == -1)
                        stackShift = true;
                    else
                        minStackPos = Math.min(minStackPos, item.stackPos);
                }
                int stackOp = 0;
                // All non-closure items are above the bottom: pop and renumber them.
                if(minStackPos > 0 && minStackPos != Integer.MAX_VALUE) {
                    stackOp |= POP_MASK;
                    //System.out.println("minStackPos = " + minStackPos);
                    for(Item item : b)
                        if(item.stackPos >= 0)
                            --item.stackPos;
                }
                boolean stackOverflow = false;
                if(stackShift) {
                    stackOp |= PUSH_MASK;
                    for(Item item : b) {
                        ++item.stackPos;
                        if(item.stackPos > MAX_STACK_ID)
                            stackOverflow = true;
                    }
                }
                stackOpMap.put(a, stackOp);
                System.out.println(newState + " " + grammar.getName(a) + " " + stackOp);
                
                if(stackOverflow) {
                    System.err.println("Stack overflow when following " + grammar.getName(a) + " at");
                    System.err.println(itemSet.toString(grammar));
                }
                else {
                    int state = getState(a, b);
                    trans.put(a, state);
                    backLinks.get(state).add(newState);
                }
                return true;
            }
            
        });
        return newState;
    }

    /** Index -> packed (state, nonterminal index) pair, one entry per nonterminal transition. */
    TLongArrayList sMap = new TLongArrayList();
    /** Inverse of {@link #sMap}. */
    TLongIntHashMap sMapInv = new TLongIntHashMap();
    /** Lookahead terminals of each nonterminal transition, indexed like {@link #sMap}. */
    TIntHashSet[] follow;
        
    /** State part of a pair packed by {@link #getS(int, int)}. */
    private static int getState(long s) {
        return (int)(s >> 32);
    }
    
    /** Symbol part of a pair packed by {@link #getS(int, int)}. */
    private static int getSymbol(long s) {
        return (int)s;
    }
    
    /**
     * Packs a state and a symbol into one long (state in the high 32 bits). Callers pass
     * non-negative nonterminal indices; a negative symbol would sign-extend over the state.
     */
    private static long getS(int state, int symbol) {
        return (((long)state) << 32) | (long)symbol;
    }
    
    /**
     * Computes {@link #follow}: the lookahead terminals of every nonterminal transition.
     * Initial sets are the terminals directly readable in the goto target (plus end-of-input
     * for final targets); they are then propagated with {@code AnaGrammar.gclose} along
     * {@code gread} (transitions on nullable nonterminals) and {@code gla} (relations
     * collected by {@link #traceBack}).
     */
    private void computeFollow() {
        // Number all nonterminal transitions.
        for(int i=0;i<itemSets.size();++i) {
            final int source = i;
            transitions.get(i).forEachEntry(new TIntIntProcedure() {
                @Override
                public boolean execute(int symbol, int target) {
                    if(symbol < 0) {
                        long s = getS(source, ~symbol);
                        int id = sMap.size();
                        sMap.add(s);
                        sMapInv.put(s, id);
                    }
                    return true;
                }
            });
        }

        // initfollow
        follow = new TIntHashSet[sMap.size()];
        final TIntHashSet[] gread = new TIntHashSet[follow.length];
        final TIntHashSet[] gla = new TIntHashSet[sMap.size()];
        for(int i=0;i<follow.length;++i) {
            gread[i] = new TIntHashSet();
            gla[i] = new TIntHashSet();
        }
        for(int i=0;i<follow.length;++i) {
            final int id = i;
            long s = sMap.get(i);
            int source = getState(s);
            int symbol = getSymbol(s);
            final int target = transitions.get(source).get(~symbol);
            final TIntHashSet drSet = new TIntHashSet();
            transitions.get(target).forEachEntry(new TIntIntProcedure() {
                @Override
                public boolean execute(int symbol2, int target2) {
                    if(symbol2 >= 0)
                        drSet.add(symbol2);
                    else if(grammar.nullable[~symbol2])
                        gread[sMapInv.get(getS(target, ~symbol2))].add(id);
                    return true;
                }
            });
            if(finalStates.contains(target))
                drSet.add(grammar.terminalNames.length-1);
            follow[i] = drSet;
            
            ItemSet set = itemSets.get(target);
            for(Item targetItem : set.items) {
                Prod prod = grammar.prods.get(targetItem.production);
                if(grammar.almostAccepts(prod.rhs, targetItem.position)) {
                    // Find the source items that reach targetItem by this transition.
                    for(Item sourceItem : itemSets.get(source).items) {
                        if(sourceItem.production == targetItem.production &&
                                prod.rhs.getTransition(sourceItem.position, ~symbol) == targetItem.position) {
                            TLongHashSet visited = new TLongHashSet(); 
                            traceBack(gla, id, visited, source, sourceItem);
                        }
                    }
                    
                }
            }
        }
        //System.out.println("follow: " + Arrays.toString(follow));
        //System.out.println("gread: " + Arrays.toString(gread));
        //System.out.println("gla: " + Arrays.toString(gla));
        AnaGrammar.gclose(follow, gread);
        AnaGrammar.gclose(follow, gla);
        
        /*System.out.println("Gla:");
        for(int i=0;i<gla.length;++i) {
            int iState = getState(sMap.get(i));
            int iSymbol = getSymbol(sMap.get(i));
            for(int j : gla[i].toArray()) {
                int jState = getState(sMap.get(j));
                int jSymbol = getSymbol(sMap.get(j));
                System.out.println("-- from --");
                System.out.println(itemSets.get(iState).toString(grammar));
                System.out.println("    symbol: " + grammar.nonterminalNames[iSymbol]);
                System.out.println("-- to --");
                System.out.println(itemSets.get(jState).toString(grammar));
                System.out.println("    symbol: " + grammar.nonterminalNames[jSymbol]);
            }
        }*/
        
        /*for(int i=0;i<follow.length;++i) {
            long s = sMap.get(i);
            int source = getState(s);
            int symbol = getSymbol(s);
            System.out.println("------------------------------");
            System.out.println(itemSets.get(source).toString(grammar));
            System.out.println("Symbol: " + grammar.nonterminalNames[symbol]);
            System.out.print("Follow:");
            for(int sym : follow[i].toArray())
                System.out.print(" " + grammar.terminalNames[sym]);
            System.out.println();
        }*/
    }
    
    /**
     * Walks backwards from {@code item} in {@code state} along {@link #backLinks}, following
     * predecessor items of the same production. Every visited item with
     * {@code stackPos == -1} (introduced by closure in its state) contributes
     * {@code initialId} to {@code gla} of the goto on the production's left-hand side.
     * {@code visited} is keyed by (state, position); the production is fixed in a walk.
     */
    private void traceBack(TIntHashSet[] gla, int initialId, TLongHashSet visited, int state, Item item) {
        if(visited.add( (((long)state)<<32) | (long)item.position )) {
            Prod prod = grammar.prods.get(item.production);
            if(item.stackPos == -1) {
                int id = sMapInv.get(getS(state, prod.lhs));
                gla[id].add(initialId);
            }
            
            int backTransSymbol = backTransSymbols.get(state);
            for(int prevState : backLinks.get(state).toArray())
                for(Item prevItem : itemSets.get(prevState).items)
                    if(prevItem.production == item.production &&
                            prod.rhs.getTransition(prevItem.position, backTransSymbol) == item.position)
                        traceBack(gla, initialId, visited, prevState, prevItem);
        }
    }
    
    /**
     * Collects into {@code la} the lookahead terminals for reducing {@code production} at
     * {@code position} in {@code state}. Walks backwards through predecessor states; wherever
     * the production may have started in a predecessor, adds the follow set of that
     * predecessor's goto on the production's left-hand side.
     */
    private void lookback( TLongHashSet visited, TIntHashSet la, int production, int state, int position) {
        if(visited.add( (((long)state)<<32) | (long)position )) {
            int backTransSymbol = backTransSymbols.get(state);
            Prod prod = grammar.prods.get(production);
            // position is reachable in one step from the production's initial position.
            boolean mightBeInitial = prod.rhs.getTransition(prod.rhs.getInitialState(), backTransSymbol) == position;
            for(int prevState : backLinks.get(state).toArray()) {
                for(Item item : itemSets.get(prevState).items) {
                    if(item.production == production &&
                            prod.rhs.getTransition(item.position, backTransSymbol) == position)
                        lookback(visited, la, production, prevState, item.position);
                    if(mightBeInitial && grammar.prods.get(item.production).rhs.getTransition(item.position, ~prod.lhs) >= 0) {
                        int id = sMapInv.get(getS(prevState, prod.lhs));
                        la.addAll(follow[id]);
                    }
                }
            }
        }
    }

    /**
     * Computes lookaheads and adds accept and reduce actions to {@link #transitions}.
     * A shift/reduce conflict is resolved in favour of the reduction only when the
     * production's {@code annotations} map the terminal to {@code 1}; any other annotation
     * keeps the shift silently. Unresolved conflicts are reported on {@code System.err}.
     */
    private void createReduceActions() {
        computeFollow();
        for(int i=0;i<itemSets.size();++i) {
            TIntIntHashMap trans = transitions.get(i);
            TIntIntHashMap stackOpMap = stackOps.get(i);
            if(finalStates.contains(i))
                trans.put(grammar.terminalNames.length-1, ACCEPT_ACTION);
            
            TIntObjectHashMap<TIntHashSet> laMap = new TIntObjectHashMap<TIntHashSet>();
            TIntIntHashMap stackPosMap = new TIntIntHashMap(); 
            for(Item item : itemSets.get(i).items) {
                Prod prod = grammar.prods.get(item.production);
                if(prod.rhs.getAccepts(item.position)) {
                    TIntHashSet la = laMap.get(item.production);
                    if(la == null) {
                        la = new TIntHashSet();
                        laMap.put(item.production, la);
                    }
                    
                    TLongHashSet visited = new TLongHashSet();
                    lookback(visited, la, item.production, i, item.position);
                    
                    if(stackPosMap.containsKey(item.production)) {
                        stackPosMap.put(item.production, Math.max(item.stackPos, stackPosMap.get(item.production))); // TODO arbitrary choice
                    }
                    else
                        stackPosMap.put(item.production, item.stackPos);
                }
            }
            
            // Create transitions
            for(int production : laMap.keys()) {
                int stackPos = 0; //stackPosMap.get(production);
                TIntHashSet la = laMap.get(production);
                for(int symbol : la.toArray()) {
                    if(trans.contains(symbol)) {
                        int oldAction = trans.get(symbol);
                        // All stored actions are non-negative, so the sign cannot tell a
                        // shift apart from a reduce/accept; REDUCE_MASK can.
                        if((oldAction & REDUCE_MASK) == 0) {
                            Prod prod = grammar.prods.get(production);
                            if(prod.annotations.containsKey(symbol)) {
                                byte v = prod.annotations.get(symbol);
                                if(v == 1) {
                                    trans.put(symbol, REDUCE_MASK | production | (stackPos << 13));
                                    // The replaced shift's POP/PUSH bits would otherwise be
                                    // OR'ed into the reduce action by getParseTable().
                                    stackOpMap.remove(symbol);
                                }
                            }
                            else {
                                System.err.println("Shift/reduce conflict when encountering " + grammar.terminalNames[symbol] + " in context");
                                System.err.println(itemSets.get(i).toString(grammar));
                            }
                        }
                        else {
                            System.err.println("Reduce/reduce conflict when encountering " + grammar.terminalNames[symbol] + " in context");
                            System.err.println(itemSets.get(i).toString(grammar));
                        }
                    }
                    else
                        trans.put(symbol, REDUCE_MASK | production | (stackPos << 13));
                }
            }
            
            // Check stacking conflicts
            /*trans.forEachEntry(new TIntIntProcedure() {
                @Override
                public boolean execute(int a, int b) {
                    if(b >= 0) {
                        boolean kernelState = false;
                        boolean nonkernelState = false;
                        for(Item item : itemSets.get(b).items) {
                            Prod prod = grammar.prods.get(item.production);
                            if(item.position == prod.rhs.getTransition(prod.rhs.getInitialState(), a))
                                nonkernelState = true;
                            else if(item.position != prod.rhs.getInitialState())
                                kernelState = true;
                        }
                        
                        
                        if(kernelState && nonkernelState) {
                            System.err.println("Stacking conflict when following " + grammar.getName(a) + " to");
                            System.err.println(itemSets.get(b).toString(grammar));
                        }
                    }
                    return true;
                }
            });*/
        }
    }
    
    /**
     * Builds the parse table for {@code grammar}. One initial state is created per entry of
     * {@code grammar.initialNonterminals}, seeded with production
     * {@code prods.size()-i-1}. Prints the state count and the full table to
     * {@code System.out}.
     */
    public static ParseTable build(AnaGrammar grammar) {
        ParseTableBuilder builder = new ParseTableBuilder(grammar);
        
        builder.initialStates = new int[grammar.initialNonterminals.length];
        for(int i=0;i<grammar.initialNonterminals.length;++i) {
            ArrayList<Item> seed = new ArrayList<Item>();
            int prodId = grammar.prods.size()-i-1;
            seed.add(new Item(prodId, 
                    grammar.prods.get(prodId).rhs.getInitialState(), 0));
            // REDUCE_MASK serves as the back-transition symbol of initial states.
            builder.initialStates[i] = builder.getState(REDUCE_MASK, seed);
        }
        
        builder.createReduceActions();
        
        System.out.println("States: " + builder.itemSets.size());
        
        //builder.visualize();
        
        builder.printParseTable();
        return builder.getParseTable();
    }

    /**
     * Converts the collected transitions into dense action/goto tables (unused entries are
     * {@link #ERROR_ACTION}), merging each transition's stack-operation bits into its action,
     * and renders a textual description of every state.
     */
    private ParseTable getParseTable() {
        int[] productionInfo = new int[grammar.prods.size()];
        for(int i=0;i<productionInfo.length;++i) {
            Prod prod = grammar.prods.get(i);
            productionInfo[i] = prod.lhs;
        }
        
        int[][] actionTable = new int[transitions.size()][];
        int[][] gotoTable = new int[transitions.size()][];
        for(int i=0;i<transitions.size();++i) {
            final int[] actions = new int[grammar.terminalNames.length]; 
            Arrays.fill(actions, ERROR_ACTION);
            actionTable[i] = actions;
            final int[] gotos = new int[grammar.nonterminalNames.length];
            Arrays.fill(gotos, ERROR_ACTION);
            gotoTable[i] = gotos;
            final TIntIntHashMap stackOpMap = stackOps.get(i); 
            transitions.get(i).forEachEntry(new TIntIntProcedure() {
                @Override
                public boolean execute(int a, int b) {
                    int action = b | stackOpMap.get(a);
                    if(a >= 0)
                        actions[a] = action;
                    else
                        gotos[~a] = action;
                    return true;
                }
            });
        }
        
        String[] stateDescriptions = new String[itemSets.size()];
        for(int i=0;i<stateDescriptions.length;++i) {
            final StringBuilder b = new StringBuilder();
            b.append(itemSets.get(i).toString(grammar));
            // Stack operations are kept apart from the transitions; read them from stackOps
            // so POP/PUSH actually appear in the description.
            final TIntIntHashMap stackOpMap = stackOps.get(i);
            transitions.get(i).forEachEntry(new TIntIntProcedure() {
                @Override
                public boolean execute(int symbol, int action) {
                    if(symbol >= 0) {
                        b.append("\n    ").append(grammar.terminalNames[symbol]).append(" ->");
                        if((action & REDUCE_MASK) == 0) {
                            int stackOp = stackOpMap.get(symbol);
                            if((stackOp & POP_MASK) != 0)
                                b.append(" POP");
                            if((stackOp & PUSH_MASK) != 0)
                                b.append(" PUSH");
                            b.append(" SHIFT(").append(action&STATE_MASK).append(")");
                        }
                        else {
                            if(action == ACCEPT_ACTION)
                                b.append(" ACCEPT");
                            else
                                b.append(" REDUCE(").append(action&STATE_MASK).append(")");
                        }
                    }
                    else {
                        b.append("\n    ").append(grammar.nonterminalNames[~symbol]).append(" ->")
                         .append(" GOTO(").append(action).append(")");
                    }
                    return true;
                }
            });
            stateDescriptions[i] = b.toString();
        }
        
        //printParseTable();
        return new ParseTable(itemSets.size(), actionTable, gotoTable, productionInfo,
                initialStates, stateDescriptions);
    }

    /** Prints every state with its item set, stack operations and actions to {@code System.out}. */
    private void printParseTable() {
        final ItemSet[] stateSets = new ItemSet[states.size()];
        states.forEachEntry(new TObjectIntProcedure<ItemSet>() {
            @Override
            public boolean execute(ItemSet a, int b) {
                stateSets[b] = a;
                return true;
            }
        });
        for(int i=0;i<stateSets.length;++i) {
            System.out.println("--- State " + i + " ---");
            System.out.println(stateSets[i].toString(grammar));
            final TIntIntHashMap stackOp = stackOps.get(i);
            transitions.get(i).forEachEntry(new TIntIntProcedure() {
                @Override
                public boolean execute(int a, int b) {
                    int sOp = stackOp.get(a);
                    System.out.print(grammar.getName(a) + " -> ");
                    if(sOp != 0) {
                        System.out.print("[");
                        if((sOp & PUSH_MASK) != 0) {
                            sOp ^= PUSH_MASK;
                            System.out.print("PUSH ");
                        }
                        if((sOp & POP_MASK) != 0) {
                            sOp ^= POP_MASK;
                            System.out.print("POP ");
                        }
                        if(sOp != 0)
                            System.out.print(sOp);
                        System.out.print("] ");
                    }
                    if((b & REDUCE_MASK) != 0) {
                        b ^= REDUCE_MASK;
                        System.out.println("reduce " + b); // grammar.prods.get(~b).toString(grammar));
                    }
                    else {
                        System.out.println("shift " + b);
                    }
                    return true;
                }
            });
        }
    }
}
