package org.simantics.scl.compiler.compilation;

import java.util.ArrayList;

import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.expressions.Variable;
import org.simantics.scl.compiler.elaboration.modules.SCLValue;
import org.simantics.scl.compiler.internal.elaboration.constraints.Constraint;
import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
import org.simantics.scl.compiler.types.TPred;
import org.simantics.scl.compiler.types.TVar;

import gnu.trove.impl.Constants;
import gnu.trove.map.hash.THashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.THashSet;
import gnu.trove.set.hash.TIntHashSet;

/**
 * Schedules the order of type checking.
 * <p>
 * Definitions registered with {@link #addTypeInferableDefinition} are grouped
 * into strongly connected components of the graph formed by their references
 * to each other's defined objects. Each component is type checked together as
 * one unit. The order in which components are checked is the order in which
 * {@link StronglyConnectedComponents} reports them. After every component has
 * been checked, the runnables registered with
 * {@link #addPostTypeCheckingRunnable} are run in registration order.
 * 
 * @author Hannu Niemist&ouml;
 */
public class TypeCheckingScheduler {
    private final CompilationContext compilationContext;
    
    private final ArrayList<TypeInferableDefinition> definitions = new ArrayList<TypeInferableDefinition>();
    private final ArrayList<Runnable> postTypeCheckingRunnables = new ArrayList<Runnable>();
    
    /**
     * @param compilationContext context passed to every {@link TypingContext}
     *        created during scheduling; its error log receives the
     *        diagnostics produced here
     */
    public TypeCheckingScheduler(CompilationContext compilationContext) {
        this.compilationContext = compilationContext;
    }

    /**
     * Registers a definition to be type checked by {@link #schedule()}.
     * Its index in registration order identifies it in the dependency graph.
     */
    public void addTypeInferableDefinition(TypeInferableDefinition definition) {
        definitions.add(definition);
    }
    
    /**
     * Registers a runnable to be run by {@link #schedule()} after all
     * definitions have been type checked.
     */
    public void addPostTypeCheckingRunnable(Runnable runnable) {
        postTypeCheckingRunnables.add(runnable);
    }
    
    /**
     * Type checks all registered definitions component by component, then runs
     * the post type checking runnables.
     * <p>
     * Every defined object is mapped to the index of the definition that
     * defines it. That map is handed to
     * {@link TypeInferableDefinition#collectRefs}, which presumably adds the
     * indices of the definitions referenced by a definition into the given set.
     * Those indices are the edges of the dependency graph.
     */
    public void schedule() {
        // Defined object -> index of its defining definition; -1 is the map's no-entry value.
        final TObjectIntHashMap<Object> allRefs = 
                new TObjectIntHashMap<Object>(definitions.size(), Constants.DEFAULT_LOAD_FACTOR, -1);
        
        for(int i=0;i<definitions.size();++i)
            for(Object definedObject : definitions.get(i).getDefinedObjects())
                allRefs.put(definedObject, i);
        new StronglyConnectedComponents(definitions.size()) {
            
            // Scratch buffer reused across findDependencies calls; cleared after each use.
            private final TIntHashSet set = new TIntHashSet();
            
            @Override
            protected void reportComponent(int[] component) {
                typeCheck(component);
            }

            @Override
            protected int[] findDependencies(int u) {
                definitions.get(u).collectRefs(allRefs, set);
                int[] result = set.toArray();
                set.clear();
                return result;
            }
            
        }.findComponents();
        
        for(Runnable runnable : postTypeCheckingRunnables)
            runnable.run();
    }
    
    /**
     * Type checks one strongly connected component of definitions as a unit.
     * <p>
     * All definitions share one {@link TypingContext}. After their types have
     * been checked and their constraints solved, the free type variables and
     * the free evidence constraints of the whole component are collected. Every
     * definition of the component then receives the same variables and
     * constraints through {@code injectEvidence}.
     *
     * @param component indices into {@link #definitions}; assumed non-empty,
     *        because the first definition's location is used for diagnostics
     */
    private void typeCheck(int[] component) {
        // Resolve the definitions once instead of repeating definitions.get(c) in every phase.
        TypeInferableDefinition[] defs = new TypeInferableDefinition[component.length];
        for(int i=0;i<component.length;++i)
            defs[i] = definitions.get(component[i]);
        
        TypingContext context = new TypingContext(compilationContext);
        context.recursiveValues = new THashSet<SCLValue>();
        
        // Each phase must finish for the whole component before the next starts.
        for(TypeInferableDefinition def : defs)
            def.initializeTypeChecking(context);
        for(TypeInferableDefinition def : defs)
            def.checkType(context);
        context.solveSubsumptions(defs[0].getLocation());
        for(TypeInferableDefinition def : defs)
            def.solveConstraints();
        
        TVar[] vars = collectComponentTypeVariables(defs);
        TPred[] constraints = collectComponentConstraints(defs);
        reportAmbiguousConstraints(defs, constraints);
        
        for(TypeInferableDefinition def : defs)
            def.injectEvidence(vars, constraints);
    }
    
    /** Collects the distinct free type variables of all definitions of a component. */
    private static TVar[] collectComponentTypeVariables(TypeInferableDefinition[] defs) {
        THashSet<TVar> varSet = new THashSet<TVar>();
        for(TypeInferableDefinition def : defs)
            def.collectFreeTypeVariables(varSet);
        return varSet.toArray(new TVar[varSet.size()]);
    }
    
    /** Collects the distinct types of the free evidence variables of all definitions of a component. */
    private static TPred[] collectComponentConstraints(TypeInferableDefinition[] defs) {
        THashSet<TPred> constraintSet = new THashSet<TPred>();
        for(TypeInferableDefinition def : defs)
            for(Variable evidence : def.getFreeEvidence())
                constraintSet.add((TPred)evidence.getType());
        return constraintSet.toArray(new TPred[constraintSet.size()]);
    }
    
    /**
     * Logs an error for every constraint that still contains metavariables.
     * <p>
     * The error is reported at the demand location of the matching unsolved
     * constraint. The lookup map from constraint to unsolved constraint is
     * built lazily, only when the first offending constraint is found. If no
     * matching unsolved constraint exists, the location of the component's
     * first definition is used instead.
     */
    private void reportAmbiguousConstraints(TypeInferableDefinition[] defs, TPred[] constraints) {
        THashMap<TPred, Constraint> constraintMap = null;
        for(TPred constraint : constraints) {
            if(!constraint.containsMetaVars())
                continue;
            if(constraintMap == null) {
                constraintMap = new THashMap<TPred, Constraint>();
                for(TypeInferableDefinition def : defs)
                    for(Constraint cons : def.getUnsolvedConstraints())
                        constraintMap.put(cons.constraint, cons);
            }
            String message = "Constraint " + constraint + 
                    " contains free variables not mentioned in the type of the value.";
            Constraint cons = constraintMap.get(constraint);
            if(cons != null)
                compilationContext.errorLog.log(cons.getDemandLocation(), message);
            else
                // No recorded unsolved constraint matches: report at the component instead of
                // crashing with a NullPointerException.
                compilationContext.errorLog.log(defs[0].getLocation(), message);
        }
    }
}
