package org.simantics.scl.compiler.internal.elaboration.subsumption2;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;

import org.simantics.scl.compiler.errors.ErrorLog;
import org.simantics.scl.compiler.internal.elaboration.subsumption.Subsumption;
import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.LowerBoundSource;
import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.Node;
import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.PartOfUnion;
import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.Sub;
import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.UnionNode;
import org.simantics.scl.compiler.internal.elaboration.subsumption2.SubsumptionGraph.VarNode;
import org.simantics.scl.compiler.internal.types.effects.EffectIdMap;
import org.simantics.scl.compiler.types.TMetaVar;
import org.simantics.scl.compiler.types.util.Polarity;

import gnu.trove.map.hash.THashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.hash.THashSet;

/**
 * Solver for effect subsumption constraints {@code a < b}.
 * <p>
 * {@link EffectIdMap} translates each side of a subsumption into an int bit mask of constant effects plus a list
 * of effect meta variables. Every meta variable gets a {@link VarNode} carrying a lower and an upper bound (both
 * bit masks). A right-hand side that is not a single bare variable is represented by a {@link UnionNode}.
 * <p>
 * The solver first propagates bounds and reports violated constraints to the {@link ErrorLog}. If no new errors
 * were logged, it simplifies the constraint graph by replacing variables with constants or with other variables.
 */
public class SubSolver2 {
    public static final boolean DEBUG = false;
    
    // Input
    private final ErrorLog errorLog;
    private final ArrayList<Subsumption> subsumptions;

    // Constraint graph
    private final EffectIdMap effectIds = new EffectIdMap();
    private final THashMap<TMetaVar, VarNode> varNodeMap = new THashMap<TMetaVar, VarNode>();
    private final ArrayList<UnionNode> unionNodes = new ArrayList<UnionNode>(); 

    // Histogram of subsumption-set sizes; only filled by the (currently disabled) code in the constructor
    private static final TIntIntHashMap STATISTICS = new TIntIntHashMap();
    
    private SubSolver2(ErrorLog errorLog, ArrayList<Subsumption> subsumptions) {
        this.errorLog = errorLog;
        this.subsumptions = subsumptions;
        /*if(subsumptions.size() == 1) {
            TypeUnparsingContext tuc = new TypeUnparsingContext();
            Subsumption sub = subsumptions.get(0);
            if(sub.a instanceof TCon && sub.b instanceof TCon)
                System.out.println("caseCC");
            else if(sub.a instanceof TMetaVar && sub.b instanceof TCon)
                System.out.println("caseMC");
            else if(sub.a instanceof TVar && sub.b instanceof TCon)
                System.out.println("caseVC");
            System.out.println("    " + sub.a.toString(tuc) + " < " + sub.b.toString(tuc));
        }
        synchronized(STATISTICS) {
            STATISTICS.adjustOrPutValue(subsumptions.size(), 1, 1);
            showStatistics();
        }*/
    }
    
    /**
     * Prints the collected histogram of subsumption-set sizes to standard output, with the share of each size
     * as a percentage of all recorded solver runs.
     */
    public static void showStatistics() {
        System.out.println("---");
        int[] keys = STATISTICS.keys();
        Arrays.sort(keys);
        int sum = 0;
        for(int key : keys)
            sum += STATISTICS.get(key);
        for(int key : keys) {
            int value = STATISTICS.get(key);
            System.out.println(key + ": " + value + " (" + (value*100.0/sum) + "%)");
        }
    }

    /**
     * Returns {@code true} if effect set {@code a} is contained in effect set {@code b}.
     * Sets are encoded as bit masks.
     */
    private static boolean subsumes(int a, int b) {
        return (a&b) == a;
    }

    /**
     * Translates the input subsumptions into the constraint graph.
     * <p>
     * A subsumption with no variables on the right-hand side is checked immediately. Its left-hand variables only
     * get an upper bound. Otherwise the right-hand side becomes either a single variable node or a union node,
     * and edges or lower bounds are added towards it.
     */
    private void processSubsumptions() {
        ArrayList<TMetaVar> aVars = new ArrayList<TMetaVar>(2);
        ArrayList<TMetaVar> bVars = new ArrayList<TMetaVar>(2);
        for(Subsumption subsumption : subsumptions) {
            // The lists are reused as collectors for toId (presumably it appends the meta variables it finds).
            // Clear them up front so that variables of an earlier subsumption, including one rejected via
            // 'continue' below, never leak into this one.
            aVars.clear();
            bVars.clear();
            int aCons = effectIds.toId(subsumption.a, aVars);
            int bCons = effectIds.toId(subsumption.b, bVars);

            if(bVars.isEmpty()) {
                // Right-hand side is constant: the constant part of a must fit, variables of a get an upper bound
                if(!subsumes(aCons, bCons)) {
                    reportSubsumptionFailure(subsumption.loc, aCons, bCons);
                    continue;
                }
                for(TMetaVar aVar : aVars)
                    getOrCreateNode(aVar).upperBound &= bCons;
            }
            else {
                Node bNode;
                if(bVars.size() == 1 && bCons == 0)
                    bNode = getOrCreateNode(bVars.get(0));
                else
                    bNode = createUnion(subsumption.loc, bCons, bVars);
                if(aCons != 0)
                    setLowerBound(subsumption.loc, aCons, bNode);
                for(TMetaVar aVar : aVars)
                    new Sub(getOrCreateNode(aVar), bNode);
            }
        }
    }

    /** Adds {@code lower} to the lower bound of {@code node} and remembers where the requirement came from. */
    private void setLowerBound(long location, int lower, Node node) {
        node.lowerBound |= lower;
        node.addLowerBoundSource(location, lower);
    }

    /** Creates a union node {@code cons | vars...}, links every variable to it and registers it. */
    private UnionNode createUnion(long location, int cons, ArrayList<TMetaVar> vars) {
        UnionNode node = new UnionNode(location, cons);
        for(TMetaVar var : vars)
            new PartOfUnion(getOrCreateNode(var), node);
        unionNodes.add(node);
        return node;
    }

    /** Returns the node of the given meta variable, creating it on first use. */
    private VarNode getOrCreateNode(TMetaVar var) {
        VarNode node = varNodeMap.get(var);
        if(node == null) {
            node = new VarNode(var);
            varNodeMap.put(var, node);
        }
        return node;
    }

    /**
     * Runs the solver.
     *
     * @return {@code false} if new errors were logged while checking the constraints; the simplification phase is
     *         skipped in that case. Otherwise {@code true}.
     */
    public boolean solve() {
        //System.out.println("------------------------------------------------------");
        int errorCount = errorLog.getErrorCount();

        // Check errors
        processSubsumptions();
        propagateUpperBounds();
        checkLowerBounds();
        
        if(DEBUG)
            print();

        if(errorLog.getErrorCount() != errorCount)
            return false;

        // Simplify constraints
        stronglyConnectedComponents();
        propagateLowerBounds();
        simplify();

        if(DEBUG)
            print();
        
        return true;
    }
    
    /** Schedules every node adjacent to {@code node} (lower, upper and containing unions) for re-examination. */
    private void touchNeighborhood(VarNode node) {
        for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
            touch(cur.a);
        for(Sub cur=node.upper;cur!=null;cur=cur.aNext)
            touch(cur.b);
        for(PartOfUnion cur=node.partOf;cur!=null;cur=cur.aNext)
            touch(cur.b);
    }
    
    /** Schedules every node adjacent to {@code node} (lower nodes and parts) for re-examination. */
    private void touchNeighborhood(UnionNode node) {
        for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
            touch(cur.a);
        for(PartOfUnion cur=node.parts;cur!=null;cur=cur.bNext)
            touch(cur.a);
    }

    // Scratch set used by simplify() to detect duplicate edges; always cleared after use
    THashSet<Node> set = new THashSet<Node>(); 

    /**
     * Worklist-based simplification of the constraint graph.
     * <p>
     * Variables whose bounds coincide are replaced by constants. Redundant or duplicate edges are removed.
     * Depending on polarity, a variable may be replaced by a constant or by its single lower or upper
     * neighbour. Union constraints that are already implied are dropped.
     * <p>
     * NOTE(review): several loops call {@code remove()} on the current link and then advance via its
     * {@code aNext}/{@code bNext}. This assumes removal leaves the link's next pointers intact; confirm in
     * SubsumptionGraph.
     */
    private void simplify() {
        for(VarNode node : sortedNodes) {
            if(node.index == SubsumptionGraph.REMOVED)
                continue;
            activeSet.add(node);
            queue.addLast(node);
        }
        for(UnionNode node : unionNodes) {
            if(node.constPart == SubsumptionGraph.REMOVED)
                continue;
            activeSet.add(node);
            queue.addLast(node);
        }
        
        while(!queue.isEmpty()) {
            Node node_ = queue.removeFirst();
            activeSet.remove(node_);
            if(node_ instanceof VarNode) {
                VarNode node = (VarNode)node_;
                if(node.index == SubsumptionGraph.REMOVED)
                    continue;
                if(node.lowerBound == node.upperBound) {
                    // The variable is fully determined
                    if(DEBUG)
                        System.out.println("replace " + toName(node) + " by " + effectIds.toType(node.lowerBound) + ", node.lowerBound == node.upperBound");
                    touchNeighborhood(node);
                    node.removeConstantNode(effectIds, node.lowerBound);
                    continue;
                }
                // Self loops carry no information
                for(Sub cur=node.upper;cur!=null;cur=cur.aNext)
                    if(cur.b == node)
                        cur.remove();
                if(node.upper != null && node.upper.aNext != null) {
                    // Edge node < b is a duplicate, or is implied because node <= node.upperBound <= b.lowerBound <= b
                    for(Sub cur=node.upper;cur!=null;cur=cur.aNext)
                        if(!set.add(cur.b) || subsumes(node.upperBound, cur.b.lowerBound)) {
                            touch(cur.b);
                            cur.remove();
                        }
                    set.clear();
                }
                if(node.lower != null && node.lower.bNext != null) {
                    // Edge a < node is a duplicate, or is implied because a <= a.upperBound <= node.lowerBound <= node
                    for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
                        if(!set.add(cur.a) || subsumes(cur.a.upperBound, node.lowerBound)) {
                            touch(cur.a);
                            cur.remove();
                        }
                    set.clear();
                }
                Polarity polarity = node.getPolarity();
                if(!polarity.isNegative()) { 
                    if(node.partOf == null) {
                        if(node.lower == null) {
                            // No low nodes
                            if(DEBUG)
                                System.out.println("replace " + toName(node) + " by " + effectIds.toType(node.lowerBound) + ", polarity=" + polarity + ", no low nodes");
                            touchNeighborhood(node);
                            node.removeConstantNode(effectIds, node.lowerBound);
                            continue;
                        }
                        else if(node.lower.bNext == null) {
                            // Exactly one low node
                            VarNode low = node.lower.a;

                            if(low.lowerBound == node.lowerBound) {
                                node.lower.remove();
                                if(DEBUG)
                                    System.out.println("replace " + toName(node) + " by " + toName(low) + ", polarity=" + polarity + ", just one low node");
                                touchNeighborhood(node);
                                node.replaceBy(low);
                                continue;
                            }
                        }
                    }
                }
                else if(polarity == Polarity.NEGATIVE) {
                    if(node.upper != null && node.upper.aNext == null) {
                        // Exactly one high node
                        Node high = node.upper.b;
                        if(node.upperBound == high.upperBound && high instanceof VarNode) {
                            VarNode varHigh = (VarNode)high;
                            
                            node.upper.remove();
                            if(DEBUG)
                                System.out.println("replace " + toName(node) + " by " + toName(varHigh) + ", polarity=" + polarity + ", just one high node");
                            touchNeighborhood(node);
                            node.replaceBy(varHigh);
                            continue;
                        }
                    }
                }
            }
            else {
                UnionNode union = (UnionNode)node_;
                if(union.constPart == SubsumptionGraph.REMOVED)
                    continue;
                if(union.lower == null) {
                    // Only a constant lower bound: drop the union if the parts' lower bounds already cover it
                    int low = union.constPart;
                    for(PartOfUnion partOf=union.parts;partOf!=null;partOf=partOf.bNext)
                        low |= partOf.a.lowerBound;

                    if(subsumes(union.lowerBound, low)) {
                        if(DEBUG) {
                            System.out.print("remove union, " + constToString(union.lowerBound) + " < " + constToString(low));
                            printUnion(union);
                        }
                        touchNeighborhood(union);
                        union.remove();
                        continue;
                    }
                }
                else {
                    // An edge lowNode < union(..., lowNode, ...) is trivially satisfied
                    for(Sub cur=union.lower;cur!=null;cur=cur.bNext) {
                        VarNode lowNode = cur.a;
                        for(PartOfUnion partOf=union.parts;partOf!=null;partOf=partOf.bNext)
                            if(partOf.a == lowNode) {
                                cur.remove();
                                touch(union);
                                break;
                            }
                    }
                }
            }
        }
    }

    /** Checks the direct lower bounds of all nodes against their propagated upper bounds. */
    private void checkLowerBounds() {
        for(VarNode node : varNodeMap.values())
            checkLowerBound(node);
        for(UnionNode node : unionNodes) 
            checkLowerBound(node);
    }

    /**
     * Reports every recorded lower-bound source of {@code node} that is not within the node's upper bound.
     * The sources are discarded afterwards.
     */
    private void checkLowerBound(Node node) {
        int upperBound = node.upperBound;
        if(!subsumes(node.lowerBound, upperBound))
            for(LowerBoundSource source=node.lowerBoundSource;source!=null;source=source.next)
                if(!subsumes(source.lower, upperBound))
                    reportSubsumptionFailure(source.location, source.lower, upperBound);
        node.lowerBoundSource = null;
    }

    /**
     * Propagates lower bounds upwards: if {@code a < b}, then {@code b.lowerBound} must contain
     * {@code a.lowerBound}.
     * <p>
     * Variables are first processed in the order computed by {@link #stronglyConnectedComponents()}. Union
     * nodes are then handled with a worklist. An effect required by a union that neither its constant part nor
     * any other part can provide must come from the remaining part.
     */
    private void propagateLowerBounds() {
        for(VarNode node : sortedNodes) {
            for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
                node.lowerBound |= cur.a.lowerBound;
        }
        if(!unionNodes.isEmpty()) {
            for(UnionNode node : unionNodes) {
                if(node.parts != null && node.parts.bNext != null) {
                    // Remove duplicate parts of the union, might be there because of merging of strongly connected components
                    THashSet<VarNode> varSet = new THashSet<VarNode>(); 
                    for(PartOfUnion cur=node.parts;cur!=null;cur=cur.bNext)
                        if(!varSet.add(cur.a))
                            cur.remove();
                }
                
                for(Sub cur=node.lower;cur!=null;cur=cur.bNext)
                    node.lowerBound |= cur.a.lowerBound;

                activeSet.add(node);
                queue.addLast(node);
            }
            while(!queue.isEmpty()) {
                Node node = queue.removeFirst();
                activeSet.remove(node);
                int lowerBound = node.lowerBound;

                if(node instanceof VarNode) {
                    // Every node above this variable must contain at least its lower bound
                    VarNode var = (VarNode)node;
                    for(Sub cur=var.upper;cur!=null;cur=cur.aNext) {
                        Node highNode = cur.b;
                        int newLowerBound = highNode.lowerBound | lowerBound;
                        if(newLowerBound != highNode.lowerBound) {
                            highNode.lowerBound = newLowerBound;
                            touch(highNode);
                        }
                    }
                }
                else {
                    UnionNode union = (UnionNode)node;
                    // Required effects that the constant part does not provide
                    int uncovered = lowerBound & (~union.constPart);
                    for(PartOfUnion cur=union.parts;cur!=null;cur=cur.bNext) {
                        // ... and that no other part can provide either, so this part must provide them
                        int residual = uncovered;
                        for(PartOfUnion cur2=union.parts;cur2!=null;cur2=cur2.bNext)
                            if(cur2 != cur)
                                residual &= ~cur2.a.upperBound;
                        VarNode partNode = cur.a;
                        int newLowerBound = partNode.lowerBound | residual;
                        if(newLowerBound != partNode.lowerBound) {
                            partNode.lowerBound = newLowerBound;
                            touch(partNode);
                        }
                    }
                }
            }
        }
    }

    /** Logs an error naming the effects of {@code lowerBound} that are not allowed by {@code upperBound}. */
    private void reportSubsumptionFailure(long location, int lowerBound, int upperBound) {
        errorLog.log(location, "Side-effect " + effectIds.toType(lowerBound & (~upperBound)) + " is forbidden here.");        
    }

    // Worklist shared by the propagation and simplification phases; activeSet mirrors the queue content
    private final THashSet<Node> activeSet = new THashSet<>();
    private final ArrayDeque<Node> queue = new ArrayDeque<>(); 

    /** Enqueues {@code node} unless it is already waiting in the worklist. */
    private void touch(Node node) {
        if(activeSet.add(node))
            queue.addLast(node);
    }
    
    /**
     * Propagates upper bounds downwards: if {@code a < b}, then {@code a.upperBound} is restricted to
     * {@code b.upperBound}.
     * <p>
     * The upper bound of a union is recomputed from its constant part and its parts whenever one of the parts
     * changes.
     */
    private void propagateUpperBounds() {
        for(VarNode node : varNodeMap.values())
            if(node.upperBound != EffectIdMap.MAX) {
                activeSet.add(node);
                queue.addLast(node);
            }

        while(!queue.isEmpty()) {
            Node node = queue.removeFirst();
            activeSet.remove(node);
            int upperBound = node.upperBound;

            if(node instanceof VarNode) {
                // Upper bounds for unions are not calculated immediately
                for(PartOfUnion cur=((VarNode)node).partOf;cur!=null;cur=cur.aNext) {
                    UnionNode union = cur.b;
                    touch(union);
                }
            }
            else {
                // New upper bound for union is calculated here
                UnionNode union = (UnionNode)node;
                int newUpperBound = union.constPart;
                for(PartOfUnion cur=union.parts;cur!=null;cur=cur.bNext)
                    newUpperBound |= cur.a.upperBound;
                if(newUpperBound != upperBound)
                    node.upperBound = upperBound = newUpperBound;
                else
                    continue; // No changes in upper bound, no need to propagate
            }

            // Propagate upper bound to smaller variables
            for(Sub cur=node.lower;cur!=null;cur=cur.bNext) {
                VarNode lowNode = cur.a;
                int newUpperBound = lowNode.upperBound & upperBound;
                if(newUpperBound != lowNode.upperBound) {
                    lowNode.upperBound = newUpperBound;
                    touch(lowNode);
                }
            }
        }
    }

    int curIndex;

    /**
     * Finds strongly connected components of the variable graph along the lower edges (Tarjan's algorithm).
     * Every cycle of {@code <} edges is merged into a single node, since all of its variables must be equal.
     * Component roots are collected in {@link #sortedNodes}; each is added after the components reachable
     * from it via lower edges.
     */
    private void stronglyConnectedComponents() {
        sortedNodes = new ArrayList<VarNode>(varNodeMap.size());
        for(VarNode node : varNodeMap.values())
            node.index = -1;
        for(VarNode node : varNodeMap.values())
            if(node.index == -1) {
                curIndex = 0;
                stronglyConnectedComponents(node);
            }
    }

    ArrayList<VarNode> sortedNodes;
    ArrayList<VarNode> stack = new ArrayList<VarNode>(); 

    /**
     * Depth-first step of Tarjan's algorithm.
     *
     * @return the lowest index reachable from {@code node}; a finished component reports
     *         {@link Integer#MAX_VALUE}
     */
    private int stronglyConnectedComponents(VarNode node) {
        int lowindex = node.index = curIndex++;
        stack.add(node);
        for(Sub sub=node.lower;sub != null;sub=sub.bNext) {
            VarNode child = sub.a;
            int childIndex = child.index;
            if(childIndex == -1)
                childIndex = stronglyConnectedComponents(child);
            lowindex = Math.min(lowindex, childIndex);
        }
        if(node.index == lowindex) {
            // root of strongly connected component
            VarNode stackNode = stack.remove(stack.size()-1);
            if(stackNode != node) {
                ArrayList<VarNode> otherInComponent = new ArrayList<VarNode>(4);
                while(stackNode != node) {
                    otherInComponent.add(stackNode);
                    stackNode = stack.remove(stack.size()-1);
                }
                mergeComponent(node, otherInComponent);
            }
            // Finished components must not lower the index of nodes that reach them later
            node.index = Integer.MAX_VALUE;
            sortedNodes.add(node);
        }
        return lowindex;
    }

    /** Merges all nodes of a strongly connected component into {@code root}, joining their lower bounds. */
    private void mergeComponent(VarNode root, ArrayList<VarNode> otherInComponent) {
        // There is no need to merge upper bounds, because they have been propagated
        int lowerBound = root.lowerBound;
        for(VarNode node : otherInComponent)
            lowerBound |= node.lowerBound;
        root.lowerBound = lowerBound;

        for(VarNode node : otherInComponent) {
            if(DEBUG)
                System.out.println("replace " + toName(node) + " by " + toName(root));
            node.replaceBy(root);
        }
    }

    // Dummy debugging functions; the real implementations are kept in the comment below
    private String toName(Node node) {
        return "";
    }
    private void printUnion(UnionNode union) {
    }
    private void print() {
    }
    private String constToString(int cons) {
        return "";
    }
    /*
    private TypeUnparsingContext tuc = new TypeUnparsingContext();
    private THashMap<Node, String> nameMap = new THashMap<Node, String>();
    private char nextChar = 'a';
    
    private String toName(Node node) {
        String name = nameMap.get(node);
        if(name == null) {
            name = new String(new char[] {'?', nextChar++});
            nameMap.put(node, name);
        }
        return name;
    }
    
    private String constToString(int cons) {
        return effectIds.toType(cons).toString(tuc);
    }
    
    private boolean hasContent() {
        for(VarNode node : varNodeMap.values())
            if(node.index != SubsumptionGraph.REMOVED)
//                if(node.lower != null)
                return true;
        for(UnionNode node : unionNodes)
            if(node.constPart != SubsumptionGraph.REMOVED)
                return true;
        return false;
    }
    
    private void print() {
        if(!hasContent())
            return;
        System.out.println("vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv");
        TypeUnparsingContext tuc = new TypeUnparsingContext();
        for(VarNode node : varNodeMap.values()) {
            if(node.index == SubsumptionGraph.REMOVED) {
                //System.out.println(toName(node) + " removed");
                continue;
            }
            System.out.print(toName(node));
            if(node.lowerBound != EffectIdMap.MIN || node.upperBound != EffectIdMap.MAX) {
                System.out.print(" in [");
                if(node.lowerBound != EffectIdMap.MIN)
                    System.out.print(constToString(node.lowerBound));
                System.out.print("..");
                if(node.upperBound != EffectIdMap.MAX) {
                    if(node.upperBound == 0)
                        System.out.print("Pure");
                    else
                        System.out.print(constToString(node.upperBound));
                }
                System.out.print("]");
            }
            System.out.println(" (" + node.getPolarity() + ")");
            
            for(Sub cur=node.upper;cur!=null;cur=cur.aNext) {
                System.out.print("    < ");
                Node highNode = cur.b;
                if(highNode instanceof VarNode) {
                    System.out.println(toName(highNode));
                }
                else
                    printUnion((UnionNode)highNode);
            }
        }
        for(UnionNode node : unionNodes) {
            if(node.lower != null)
                continue;
            System.out.print(constToString(node.lowerBound) + " < ");
            printUnion(node);
        }
        System.out.println("^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^");
    }
    
    private void printUnion(UnionNode union) {
        System.out.print("union(");
        boolean first = true;
        if(union.constPart != EffectIdMap.MIN) {
            System.out.print(constToString(union.constPart));
            first = false;
        }
        for(PartOfUnion part=union.parts;part!=null;part=part.bNext) {
            if(first)
                first = false;
            else
                System.out.print(", ");
            System.out.print(toName(part.a));
        }
        System.out.println(")");
    }
    */
    
    /**
     * Solves the given subsumptions, reporting violated constraints to {@code errorLog}.
     */
    public static void solve(ErrorLog errorLog, ArrayList<Subsumption> subsumptions) {
        new SubSolver2(errorLog, subsumptions).solve();
    }
}
