package org.simantics.scl.compiler.internal.elaboration.constraints;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.types.TCon;
import org.simantics.scl.compiler.types.TMetaVar;
import org.simantics.scl.compiler.types.TPred;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.UnificationException;
import org.simantics.scl.compiler.types.util.TConComparator;
import org.simantics.scl.compiler.types.util.TypeUnparsingContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import gnu.trove.map.hash.THashMap;
import gnu.trove.set.hash.THashSet;

public class ConstraintSolver {

    private static final Logger LOGGER = LoggerFactory.getLogger(ConstraintSolver.class);
    
    public static THashSet<TCon> DEFAULTS_IGNORE = new THashSet<TCon>(); 
    public static THashMap<List<TCon>, Type> DEFAULTS = new THashMap<List<TCon>, Type>();
    
    static {
        DEFAULTS_IGNORE.add(Types.SHOW);
        DEFAULTS_IGNORE.add(Types.con("Json2", "JSON"));
        DEFAULTS_IGNORE.add(Types.VEC_COMP);
        DEFAULTS_IGNORE.add(Types.ORD);
        DEFAULTS_IGNORE.add(Types.TYPEABLE);
        DEFAULTS_IGNORE.add(Types.SERIALIZABLE);
        DEFAULTS_IGNORE.add(Types.con("Formatting", "FormatArgument"));
        
        DEFAULTS.put(Arrays.asList(Types.ADDITIVE), Types.INTEGER);
        DEFAULTS.put(Arrays.asList(Types.RING), Types.INTEGER);
        DEFAULTS.put(Arrays.asList(Types.ORDERED_RING), Types.INTEGER);
        DEFAULTS.put(Arrays.asList(Types.INTEGRAL), Types.INTEGER);       
        DEFAULTS.put(Arrays.asList(Types.REAL), Types.DOUBLE);
        
        { // Some R -module specific hacks
            TCon RCOMPATIBLE = Types.con("R/RExp", "RCompatible");
            TCon REXP = Types.con("R/RExp", "RExp");
            DEFAULTS.put(Arrays.asList(RCOMPATIBLE), REXP);
            DEFAULTS.put(Arrays.asList(RCOMPATIBLE, Types.ADDITIVE), Types.DOUBLE);
            DEFAULTS.put(Arrays.asList(RCOMPATIBLE, Types.RING), Types.DOUBLE);
            DEFAULTS.put(Arrays.asList(RCOMPATIBLE, Types.ORDERED_RING), Types.DOUBLE);
            DEFAULTS.put(Arrays.asList(RCOMPATIBLE, Types.INTEGRAL), Types.DOUBLE);
            DEFAULTS.put(Arrays.asList(RCOMPATIBLE, Types.REAL), Types.DOUBLE);
        }
    }
    
    public static ReducedConstraints solve(
            ConstraintEnvironment environment,
            ArrayList<TPred> given,
            ArrayList<EVariable> demands,
            boolean applyDefaults) {
        TypeUnparsingContext tuc = SCLCompilerConfiguration.TRACE_CONSTRAINT_SOLVER ? 
                new TypeUnparsingContext() : null;
        if(SCLCompilerConfiguration.TRACE_CONSTRAINT_SOLVER) {
            LOGGER.info("");
            LOGGER.info("GIVEN:");
            for(TPred g : given)
                LOGGER.info("    " + g.toString(tuc));
            LOGGER.info("DEMANDS:");
            for(EVariable demand : demands)
                LOGGER.info("    " + demand.getType().toString(tuc));
            LOGGER.info("==>");
        }
        
        ConstraintSet cs = new ConstraintSet(environment);
        ArrayList<Constraint> givenConstraints =
                new ArrayList<Constraint>(given.size());
        
        for(TPred g : given)
            givenConstraints.add(cs.addGiven(g));
        
        for(EVariable d : demands)
            cs.addDemand(d);
        
        cs.reduce();
        
        ArrayList<Constraint> unsolvedConstraints = new ArrayList<Constraint>();
        ArrayList<Constraint> solvedConstraints = new ArrayList<Constraint>();
        cs.collect(unsolvedConstraints, solvedConstraints);
        
        // Apply defaults
        if(applyDefaults && !unsolvedConstraints.isEmpty()) {
            ArrayList<ArrayList<Constraint>> groups = 
                    groupConstraintsByCommonMetavars(unsolvedConstraints);
            if(SCLCompilerConfiguration.TRACE_CONSTRAINT_SOLVER) {
                LOGGER.info("DEFAULT GROUPS:");
                for(ArrayList<Constraint> group : groups) {
                    for(Constraint c : group)
                        LOGGER.info("    " + c.constraint.toString(tuc));
                    LOGGER.info("    --");
                }
            }
            
            unsolvedConstraints.clear();
            ArrayList<Constraint> newSolvedConstraints = new ArrayList<Constraint>(unsolvedConstraints.size() + solvedConstraints.size()); 
            for(ArrayList<Constraint> group : groups) {
                // Special rule for Typeable
                /*if(group.size() == 1 && group.get(0).constraint.typeFunction == Types.TYPEABLE) {
                    Type parameter = Types.canonical(group.get(0).constraint.parameters[0]);
                    if(parameter instanceof TMetaVar) {
                        try {
                            ((TMetaVar)parameter).setRef(Types.INTEGER);
                        } catch (UnificationException e) {
                            throw new InternalCompilerError(e);
                        }

                        Constraint constraint = group.get(0);
                        Reduction reduction = environment.reduce(constraint.constraint);
                        if(reduction.parameters.length > 0)
                            throw new InternalCompilerError();
                        constraint.setGenerator(Constraint.STATE_HAS_INSTANCE,
                                reduction.generator, reduction.parameters);
                        newSolvedConstraints.add(constraint);
                    }
                    continue;
                }*/
                
                // Standard rule
                ArrayList<TCon> cons = new ArrayList<TCon>(group.size());
                for(Constraint constraint : group)
                    if(!DEFAULTS_IGNORE.contains(constraint.constraint.typeClass))
                        cons.add(constraint.constraint.typeClass);
                Collections.sort(cons, TConComparator.INSTANCE);
                
                Type defaultType = DEFAULTS.get(cons);
                if(defaultType != null) {
                    TMetaVar var = null;
                    for(Constraint constraint : group) {
                        if(constraint.constraint.parameters.length != 1) {
                            var = null;
                            break;
                        }
                        Type parameter = Types.canonical(constraint.constraint.parameters[0]);
                        if(!(parameter instanceof TMetaVar)) {
                            var = null;
                            break;
                        }
                        if(var == null)
                            var = (TMetaVar)parameter;
                    }
                    if(var != null) {
                        try {
                            var.setRef(defaultType);
                        } catch (UnificationException e) {
                            throw new InternalCompilerError();
                        }
                        for(Constraint constraint : group) {
                            Reduction reduction = environment.reduce(constraint.demandLocation, constraint.constraint);
                            if(reduction.demands.length > 0)
                                throw new InternalCompilerError();
                            constraint.setGenerator(Constraint.STATE_HAS_INSTANCE,
                                    reduction.generator, reduction.parameters);
                            newSolvedConstraints.add(constraint);
                        }  
                        continue;
                    }                                          
                }
                unsolvedConstraints.addAll(group);
            }
            
            Collections.sort(unsolvedConstraints, ConstraintComparator.INSTANCE);
            
            newSolvedConstraints.addAll(solvedConstraints);
            solvedConstraints = newSolvedConstraints;                    
        }

        if(SCLCompilerConfiguration.TRACE_CONSTRAINT_SOLVER) {
            LOGGER.info("UNSOLVED:");
            for(Constraint c : unsolvedConstraints)
                LOGGER.info("    " + c.constraint.toString(tuc));  
            LOGGER.info("SOLVED:");
            for(Constraint c : solvedConstraints)
                LOGGER.info("    " + c.constraint.toString(tuc) + " <= " + c.generator);
            //LOGGER.info("APPLY DEFAULTS: " + applyDefaults);
        }
        
        return new ReducedConstraints(givenConstraints, 
                solvedConstraints,
                unsolvedConstraints);
    }

    private static <K,V> void add(
            THashMap<K, ArrayList<V>> map, 
            K k, V v) {
        ArrayList<V> list = map.get(k);
        if(list == null) {
            list = new ArrayList<V>(2);
            map.put(k, list);
        }
        list.add(v);
    }
    
    private static TMetaVar canonical(
            THashMap<TMetaVar, TMetaVar> cMap,
            TMetaVar v) {
        while(true) {
            TMetaVar temp = cMap.get(v);
            if(temp == null)
                return v;
            else
                v = temp;
        }
    }       
    
    private static void merge(
            THashMap<TMetaVar, TMetaVar> cMap,
            THashMap<TMetaVar, ArrayList<Constraint>> groups,
            TMetaVar a,
            TMetaVar b) {
        if(a != b) {
            cMap.put(b, a);
            ArrayList<Constraint> listB = groups.remove(b);
            if(listB != null) {
                ArrayList<Constraint> listA = groups.get(a);
                if(listA == null)
                    groups.put(a, listB);
                else
                    listA.addAll(listB);
            }
        }
    }
    
    private static ArrayList<ArrayList<Constraint>> groupConstraintsByCommonMetavars(
            ArrayList<Constraint> constraints) {
        THashMap<TMetaVar, ArrayList<Constraint>> groups =
                new THashMap<TMetaVar, ArrayList<Constraint>>();
        THashMap<TMetaVar, TMetaVar> cMap = new THashMap<TMetaVar, TMetaVar>();
        
        ArrayList<TMetaVar> vars = new ArrayList<TMetaVar>(); 
        for(Constraint constraint : constraints) {
            constraint.constraint.collectMetaVars(vars);
            if(vars.isEmpty()) {
                add(groups, null, constraint);
            } 
            else {
                TMetaVar first = canonical(cMap, vars.get(0));
                for(int i=1;i<vars.size();++i)
                    merge(cMap, groups, first, canonical(cMap, vars.get(i)));
                vars.clear();
                add(groups, first, constraint);                
            }
        }
        
        return new ArrayList<ArrayList<Constraint>>(groups.values());
    }
    
}
