package org.simantics.scl.compiler.internal.elaboration.subsumption;

import java.util.ArrayDeque;
import java.util.ArrayList;

import org.simantics.scl.compiler.errors.ErrorLog;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.internal.types.effects.EffectIdMap;
import org.simantics.scl.compiler.types.TMetaVar;
import org.simantics.scl.compiler.types.TUnion;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.util.Polarity;
import org.simantics.scl.compiler.types.util.TypeUnparsingContext;

import gnu.trove.map.hash.THashMap;
import gnu.trove.set.hash.THashSet;

public class SubSolver {

    public static final boolean DEBUG = false;
    
    long globalLoc;
    ErrorLog errorLog;
    ArrayList<Subsumption> subsumptions;
    ArrayList<Type> potentialSingletonEffects;
    EffectIdMap effectIds = new EffectIdMap();
    THashMap<TMetaVar,Var> vars = new THashMap<TMetaVar,Var>();
    ArrayDeque<Var> dirtyQueue = new ArrayDeque<Var>();
    
    TypeUnparsingContext tuc = new TypeUnparsingContext();
    
    public SubSolver(ErrorLog errorLog, ArrayList<Subsumption> subsumptions,
            ArrayList<Type> potentialSingletonEffects,
            long globalLoc) {
        this.errorLog = errorLog;
        this.subsumptions = subsumptions;
        this.potentialSingletonEffects = potentialSingletonEffects;
        this.globalLoc = globalLoc;
        
        //for(Subsumption sub : subsumptions)
        //    System.out.println("[" + sub.a.toString(tuc) + " < " + sub.b.toString(tuc) + "]");
    }    
    
    public void solve() {
        //System.out.println("--------------------------------------------------");
        //printSubsumptions();
        createVar();
        //print();
        reduceChains();
        propagateUpperBounds();
        checkLowerBounds();
        //errorFromUnsolvedEquations();
        //print();
    }
    
    private void markAllDirty() {
        for(Var v : vars.values())
            v.markDirty();
    }
    
    private Var getOrCreateVar(TMetaVar mv) {
        Var var = vars.get(mv);
        if(var == null) {
            var = new Var(mv, mv.toString(tuc).substring(1), this);
            vars.put(mv, var);
        }
        return var;
    }

    private void addVar(Type t) {
        t = Types.canonical(t);
        if(t instanceof TMetaVar)
            getOrCreateVar((TMetaVar)t);
        else if(t instanceof TUnion)
            for(Type c : ((TUnion)t).effects)
                addVar(c);
    }

    ArrayList<TMetaVar> aVars = new ArrayList<TMetaVar>();
    
    private void addSubsumption(long loc, Type a, Type b) {
        int aCons = effectIds.toId(a, aVars);
        if(!aVars.isEmpty()) {
            for(TMetaVar var : aVars)
                getOrCreateVar((TMetaVar)var).addUpperBound(toVUnion(b));
            aVars.clear();
        }        
        if(aCons != 0) {
            VUnion u = toVUnion(b);
            if(u.vars.isEmpty()) {
                testSubsumption(loc, aCons, u.con);
            }
            else
                u.makeLowerBound(aCons);
        }
    }

    ArrayList<TMetaVar> bVars = new ArrayList<TMetaVar>();
    
    private VUnion toVUnion(Type a) {
        int cons = effectIds.toId(a, bVars);
        ArrayList<Var> vars = new ArrayList<Var>(bVars.size());
        for(TMetaVar v : bVars)
            vars.add(getOrCreateVar(v));
        bVars.clear();
        return new VUnion(cons, vars);
    }

    private void createVar() {
        for(Subsumption sub : subsumptions)
            addSubsumption(sub.loc, sub.a, sub.b);
        // In some cases there might be types that are not part of any subsumption, for example related to typeOf
        for(Type t : potentialSingletonEffects)
            addVar(t);
    }
    
    private void reduceChains() {
        markAllDirty();
        while(true) {
            Var v = dirtyQueue.poll();
            if(v == null)
                break;
            
            reduceChains(v);
            v.dirty = false;
        }
    }
    
    private void reduceChains(Var v) {
        if(v.constLowerBound == v.constUpperBound) {
            v.replaceWith(v.constLowerBound);
            return;
        }
        Polarity p = v.original.getPolarity();
        if(!p.isNegative() && v.complexLowerBounds.isEmpty()) {
            if(v.simpleLowerBounds.isEmpty()) {
                if((v.constLowerBound&v.constUpperBound) != v.constLowerBound)
                    errorLog.log(globalLoc, "Subsumption failed.");
                v.replaceWith(v.constLowerBound);
                return;
            }
            else if(v.simpleLowerBounds.size() == 1 && v.constLowerBound == 0) {
                v.replaceDownwards(v.simpleLowerBounds.get(0));
                return;
            }
        }
        if(p == Polarity.NEGATIVE && v.complexUpperBounds.isEmpty()) {
            if(v.simpleUpperBounds.isEmpty()) {
                if((v.constLowerBound&v.constUpperBound) != v.constLowerBound)
                    errorLog.log(globalLoc, "Subsumption failed.");
                v.replaceWith(v.constUpperBound);
                return;
            }
            else if(v.simpleUpperBounds.size() == 1 && v.constUpperBound == EffectIdMap.MAX) {
                v.replaceUpwards(v.simpleUpperBounds.get(0));
                return;
            }
        }
    }
    
    private void propagateUpperBounds() {
        markAllDirty();        
        while(true) {
            Var v = dirtyQueue.poll();
            if(v == null)
                break;
            
            int upperApprox = v.constUpperBound;
            for(Var u : v.simpleUpperBounds)
                upperApprox &= u.upperApprox;
            for(VUnion u : v.complexUpperBounds)
                upperApprox &= u.getUpperApprox();
            
            if(upperApprox != v.upperApprox) {
                v.upperApprox = upperApprox;
                for(Var u : v.simpleLowerBounds)
                    u.markDirty();
                for(VUnion u : v.complexLowerBounds)
                    if(u.low != null)
                        u.low.markDirty();
            }   
            
            v.dirty = false;
        }
    }
    
    private void checkLowerBounds() {
        THashSet<VUnion> lowerBounds = new THashSet<VUnion>(); 
        for(Var v : vars.values()) {
            if((v.constLowerBound & v.upperApprox) != v.constLowerBound)
                testSubsumption(globalLoc, v.constLowerBound, v.upperApprox);
            for(VUnion u : v.complexLowerBounds)
                if(u.low == null)
                    lowerBounds.add(u);
        }
        for(VUnion u : lowerBounds)
            if(u.getUpperApprox() != EffectIdMap.MAX)
                errorLog.log(globalLoc, "Subsumption failed.");
    }

    private void errorFromUnsolvedEquations() {
        for(Var v : vars.values()) {
            if(!v.isFree()) {
                errorLog.log(globalLoc, "Couldn't simplify all effect subsumptions away. " +
                		"The current compiler cannot handle this situation. Try adding more type annotations.");
                break;
            }
        }
    }
    
    private void print() {
        for(Var v : vars.values()) {
            System.out.println(v.name + 
                    ", " + v.original.getPolarity() + 
                    (v.simpleLowerBounds.isEmpty() ? "" : ", simpleLowerRefs: " + v.simpleLowerBounds.size()) +  
                    (v.complexLowerBounds.isEmpty() ? "" : ", complexLowerRefs: " + v.complexLowerBounds.size()) +
                    ", " + v.original.getKind());
            if(v.constLowerBound != EffectIdMap.MIN)
                System.out.println("    > " + v.constLowerBound);
            if(v.constUpperBound != EffectIdMap.MAX)
                System.out.println("    < " + v.constUpperBound);
            for(Var u : v.simpleUpperBounds)
                System.out.println("    < " + u.name);
            for(VUnion u : v.complexUpperBounds)
                System.out.println("    << " + u);
        }
    }
    
    private void printSubsumptions() {
        for(Subsumption sub : subsumptions)
            System.out.println("[" + sub.a.toString(tuc) + " < " + sub.b.toString(tuc) + " : " + 
                    Locations.beginOf(sub.loc) + "," + Locations.endOf(sub.loc) + "]");
    }

    private void testSubsumption(long location, int a, int b) {
        int extraEffects = a & (~b);
        if(extraEffects != 0)
            subsumptionFailed(location, extraEffects);
    }
    
    private void subsumptionFailed(long location , int effects) {
        errorLog.log(location, "Side-effect " + effectIds.toType(effects) + " is forbidden here.");
    }
}
