package org.simantics.scl.compiler.compilation;

import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.applyTypes;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.loc;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.vars;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;

import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.expressions.EAmbiguous;
import org.simantics.scl.compiler.elaboration.expressions.EPlaceholder;
import org.simantics.scl.compiler.elaboration.expressions.ETransformation;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.Variable;
import org.simantics.scl.compiler.elaboration.modules.SCLValue;
import org.simantics.scl.compiler.elaboration.query.Query;
import org.simantics.scl.compiler.elaboration.relations.ConcreteRelation;
import org.simantics.scl.compiler.elaboration.relations.ConcreteRelation.QuerySection;
import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
import org.simantics.scl.compiler.elaboration.rules.MappingRelation;
import org.simantics.scl.compiler.elaboration.rules.TransformationRule;
import org.simantics.scl.compiler.environment.Environment;
import org.simantics.scl.compiler.internal.elaboration.constraints.Constraint;
import org.simantics.scl.compiler.internal.elaboration.constraints.ConstraintEnvironment;
import org.simantics.scl.compiler.internal.elaboration.constraints.ConstraintSolver;
import org.simantics.scl.compiler.internal.elaboration.constraints.ExpressionAugmentation;
import org.simantics.scl.compiler.internal.elaboration.constraints.ReducedConstraints;
import org.simantics.scl.compiler.module.ConcreteModule;
import org.simantics.scl.compiler.types.TPred;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.kinds.Kinds;
import org.simantics.scl.compiler.types.util.Polarity;

import gnu.trove.map.hash.THashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.THashSet;
import gnu.trove.set.hash.TIntHashSet;

/**
 * Type checks the contents of a {@link ConcreteModule}: its values, its concrete
 * relations and its transformation rules.
 * <p>
 * The following are registered with a {@link TypeCheckingScheduler} as
 * {@link TypeInferableDefinition}s:
 * <ul>
 * <li>values without a declared type,</li>
 * <li>concrete relations,</li>
 * <li>the transformation rules of the module, as a single definition.</li>
 * </ul>
 * Values with a declared type, and the groundness check of mapping relation parameter
 * types, are registered as post-type-checking runnables. The whole process is started
 * by {@link #typeCheck()}.
 */
public class TypeChecking {
    final CompilationContext compilationContext;
    final Environment environment;
    final ConcreteModule module;
    
    /** Constraint environment used when solving type class constraints; created in {@link #typeCheck()}. */
    ConstraintEnvironment ce;
    /** Scheduler that receives the definitions and runnables; created in {@link #typeCheck()}. */
    TypeCheckingScheduler scheduler;
    
    /**
     * @param compilationContext context providing the environment and the error log
     * @param module             the module whose contents are type checked
     */
    public TypeChecking(CompilationContext compilationContext, ConcreteModule module) {
        this.compilationContext = compilationContext;
        this.environment = compilationContext.environment;
        this.module = module;
    }
    
    /**
     * Registers every module value that has an expression with the scheduler.
     * <p>
     * Values without a declared type are added as type inferable definitions, so
     * their types (including type class constraints) are inferred. Values with a
     * declared type are checked against that type in a post-type-checking runnable.
     */
    private void typeCheckValues() {
        for(final SCLValue value : module.getValues()) {
            if(value.getExpression() != null) {
                if(value.getType() == null) {
                    scheduler.addTypeInferableDefinition(new TypeInferableDefinition() {
                        /** Placeholders for recursive references to the values being inferred (set in checkType). */
                        ArrayList<EPlaceholder> recursiveReferences;
                        /** Constraint demand collected during checkType, or null if there was none. */
                        ArrayList<EVariable> constraintDemand;
                        /** Evidence variables of the constraints left unsolved (set in solveConstraints). */
                        ArrayList<Variable> freeEvidence;
                        /** Constraints left unsolved (set in solveConstraints). */
                        ArrayList<Constraint> unsolvedConstraints;
                        
                        @Override
                        public long getLocation() {
                            return value.getExpression().getLocation();
                        }
                        
                        @Override
                        public Collection<Object> getDefinedObjects() {
                            return Collections.<Object>singleton(value);
                        }
                        
                        @Override
                        public void collectRefs(TObjectIntHashMap<Object> allRefs, TIntHashSet refs) {
                            value.getExpression().collectRefs(allRefs, refs);
                        }
                        
                        @Override
                        public void initializeTypeChecking(TypingContext context) {
                            // Start from a fresh meta variable; it is unified during checkType
                            value.setType(Types.metaVar(Kinds.STAR));
                            context.recursiveValues.add(value);
                        }
                        
                        @Override
                        public void checkType(TypingContext context) {
                            context.recursiveReferences = 
                                    this.recursiveReferences = new ArrayList<EPlaceholder>();

                            Expression expression = value.getExpression();
                            context.pushEffectUpperBound(expression.location, Types.PROC);
                            expression = expression.checkType(context, value.getType());
                            context.popEffectUpperBound();
                            for(EAmbiguous overloaded : context.overloadedExpressions)
                                overloaded.assertResolved(compilationContext.errorLog);
                            value.setExpression(expression);
                            
                            // Take ownership of the constraint demand of this definition
                            ArrayList<EVariable> constraintDemand = context.getConstraintDemand();
                            if(!constraintDemand.isEmpty()) {
                                this.constraintDemand = constraintDemand;
                                context.resetConstraintDemand();
                            }
                            
                            expression.getType().addPolarity(Polarity.POSITIVE);
                        }
                        
                        @Override
                        public void solveConstraints() {
                            if(constraintDemand != null) {
                                Expression expression = value.getExpression();
                                
                                ReducedConstraints red = ConstraintSolver.solve(
                                        ce, new ArrayList<TPred>(0), constraintDemand,
                                        true);
                                
                                expression = ExpressionAugmentation.augmentSolved(
                                        red.solvedConstraints, 
                                        expression);
                                value.setExpression(expression);
                                value.setType(expression.getType());
                                
                                // A ground constraint cannot be generalized over, so it is an error
                                for(Constraint c : red.unsolvedConstraints)
                                    if(c.constraint.isGround())
                                        compilationContext.errorLog.log(c.getDemandLocation(), "There is no instance for <"+c.constraint+">.");
                                
                                ArrayList<Variable> fe = new ArrayList<Variable>(red.unsolvedConstraints.size());
                                for(Constraint c : red.unsolvedConstraints)
                                    fe.add(c.evidence);
                                unsolvedConstraints = red.unsolvedConstraints;
                                freeEvidence = fe;
                            }
                            else {
                                value.setExpression(value.getExpression().decomposeMatching());
                                freeEvidence = new ArrayList<Variable>(0);
                                // Never hand out null: match the other definitions, which return empty lists
                                unsolvedConstraints = new ArrayList<Constraint>(0);
                            }
                        }
                        
                        @Override
                        public void collectFreeTypeVariables(
                                THashSet<TVar> varSet) {
                            Type type = value.getType();
                            type = type.convertMetaVarsToVars();
                            value.setType(type);
                            varSet.addAll(Types.freeVars(type));
                        }
                        
                        @Override
                        public ArrayList<Variable> getFreeEvidence() {
                            return freeEvidence;
                        }
                        
                        @Override
                        public ArrayList<Constraint> getUnsolvedConstraints() {
                            return unsolvedConstraints;
                        }
                        
                        /**
                         * Generalizes the value over {@code vars} and {@code constraints}.
                         * <p>
                         * The expression is wrapped in a lambda taking one evidence
                         * parameter per constraint, in the order of {@code constraints}.
                         * Every recursive reference is rewritten to pass the same type
                         * arguments and evidence.
                         */
                        @Override
                        public void injectEvidence(TVar[] vars, TPred[] constraints) {
                            // Create evidence array of every value in the group that has the variables
                            // in the same array as in the shared array
                            THashMap<TPred, Variable> indexedEvidence = new THashMap<TPred, Variable>(freeEvidence.size());
                            for(Variable v : freeEvidence)
                                indexedEvidence.put((TPred)v.getType(), v);
                            freeEvidence.clear();
                            for(TPred c : constraints) {
                                Variable var = indexedEvidence.get(c);
                                if(var == null) {
                                    // These are variables that are not directly needed in 
                                    // this definition but in the definitions that are
                                    // recursively called
                                    var = new Variable("evX");
                                    var.setType(c);
                                }
                                // Exactly one evidence parameter per constraint, so the lambda
                                // below has as many evidence parameters as the constrained
                                // type has predicates.
                                freeEvidence.add(var);
                            }
                            
                            // Add evidence parameters to the functions
                            value.setExpression(lambda(Types.NO_EFFECTS, freeEvidence, value.getExpression())
                                    .closure(vars));
                            value.setType(Types.forAll(vars, 
                                    Types.constrained(constraints, value.getType())));
                            
                            // Add evidence parameters to recursive calls
                            for(EPlaceholder ref : recursiveReferences) {
                                ref.expression = loc(ref.expression.location, apply(
                                        Types.NO_EFFECTS,
                                        applyTypes(ref.expression, vars),
                                        vars(freeEvidence)));
                            }
                        }
                    });
                }
                else {
                    scheduler.addPostTypeCheckingRunnable(new Runnable() {
                        /**
                         * Checks the value against its declared type.
                         * <p>
                         * The declared type's quantified variables and given constraints are
                         * stripped, and the expression is checked against the remaining type.
                         * The demanded constraints are then solved using the given ones.
                         * Exceptions thrown during this process are logged as errors.
                         */
                        @Override
                        public void run() {
                            Type type = value.getType();

                            Expression expression = value.getExpression();

                            int errorCountBeforeTypeChecking = compilationContext.errorLog.getErrorCount();
                            int functionArity = expression.getSyntacticFunctionArity();
                            
                            try {
                                ArrayList<TVar> vars = new ArrayList<TVar>();
                                type = Types.removeForAll(type, vars);
                                ArrayList<TPred> givenConstraints = new ArrayList<TPred>();
                                type = Types.removePred(type, givenConstraints);

                                TypingContext context = new TypingContext(compilationContext);
                                context.pushEffectUpperBound(expression.location, Types.PROC);
                                expression = expression.checkType(context, type);
                                context.popEffectUpperBound();
                                for(EAmbiguous overloaded : context.overloadedExpressions)
                                    overloaded.assertResolved(compilationContext.errorLog);
                                expression.getType().addPolarity(Polarity.POSITIVE);
                                context.solveSubsumptions(expression.getLocation());
                                
                                // If type checking this value produced errors, hint at a likely
                                // cause: an arity mismatch between the declaration and the definition
                                if(compilationContext.errorLog.getErrorCount() != errorCountBeforeTypeChecking) {
                                    int typeArity = Types.getArity(type); 
                                    if(typeArity != functionArity)
                                        compilationContext.errorLog.logWarning(value.definitionLocation, "Possible problem: type declaration has " + typeArity + " parameter types, but function definition has " + functionArity + " parameters.");
                                }
                                
                                ArrayList<EVariable> demands = context.getConstraintDemand();
                                if(!demands.isEmpty() || !givenConstraints.isEmpty()) {
                                    ReducedConstraints red = 
                                            ConstraintSolver.solve(ce, givenConstraints, demands, true);
                                    givenConstraints.clear();
                                    // With a declared type, every demanded constraint must be derivable
                                    for(Constraint c :  red.unsolvedConstraints) {
                                        compilationContext.errorLog.log(c.getDemandLocation(), 
                                                "Constraint <"+c.constraint+"> is not given and cannot be derived.");
                                    }
                                    if(compilationContext.errorLog.hasNoErrors()) { // To prevent exceptions
                                        expression = ExpressionAugmentation.augmentSolved(
                                                red.solvedConstraints,
                                                expression);
                                        expression = ExpressionAugmentation.augmentUnsolved(
                                                red.givenConstraints, 
                                                expression); 
                                    }
                                }
                                else {
                                    if(compilationContext.errorLog.hasNoErrors()) // To prevent exceptions
                                        expression = expression.decomposeMatching();
                                }
                                expression = expression.closure(vars.toArray(new TVar[vars.size()]));
                                value.setExpression(expression);
                            } catch(Exception e) {
                                compilationContext.errorLog.log(expression.location, e);
                            }
                        }
                    });
                }
            }
        }
    }    
    
    /**
     * Registers every {@link ConcreteRelation} of the module as a type inferable definition.
     * <p>
     * The relation parameters get fresh meta variable types. The query sections and the
     * optional enforce section are type checked, each under its own effect meta variable.
     */
    private void typeCheckRelations() {
        for(SCLRelation relation_ : module.getRelations()) {
            if(!(relation_ instanceof ConcreteRelation))
                continue;
            final ConcreteRelation relation = (ConcreteRelation)relation_;
            scheduler.addTypeInferableDefinition(new TypeInferableDefinition() {
                
                @Override
                public void initializeTypeChecking(TypingContext context) {
                    // Parameters are both read and written by the relation, hence BIPOLAR
                    for(Variable parameter : relation.parameters) {
                        Type type = Types.metaVar(Kinds.STAR);
                        type.addPolarity(Polarity.BIPOLAR);
                        parameter.setType(type);
                    }
                }
                
                @Override
                public void solveConstraints() {
                    // No constraint solving is done for relations here
                }
                
                @Override
                public void injectEvidence(TVar[] vars, TPred[] constraints) {
                    // Only the generalized type variables are recorded; constraints are ignored
                    relation.typeVariables = vars;
                }
                
                @Override
                public ArrayList<Constraint> getUnsolvedConstraints() {
                    return new ArrayList<Constraint>(0); // TODO
                }
                
                @Override
                public long getLocation() {
                    return relation.location;
                }
                
                @Override
                public ArrayList<Variable> getFreeEvidence() {
                    return new ArrayList<Variable>(0); // TODO
                }
                
                @Override
                public Collection<Object> getDefinedObjects() {
                    return Collections.<Object>singleton(relation);
                }
                
                @Override
                public void collectRefs(TObjectIntHashMap<Object> allRefs, TIntHashSet refs) {
                    for(QuerySection section : relation.getSections())
                        section.query.collectRefs(allRefs, refs);
                }
                
                @Override
                public void collectFreeTypeVariables(THashSet<TVar> varSet) {
                    for(Variable parameter : relation.parameters) {
                        Type parameterType = parameter.getType().convertMetaVarsToVars();
                        varSet.addAll(Types.freeVars(parameterType));
                    }
                }
                
                @Override
                public void checkType(TypingContext context) {
                    for(QuerySection section : relation.getSections()) {
                        section.effect = Types.metaVar(Kinds.EFFECT);
                        context.pushEffectUpperBound(relation.location, section.effect);
                        for(Variable variable : section.existentials)
                            variable.setType(Types.metaVar(Kinds.STAR));
                        section.query.checkType(context);
                        context.popEffectUpperBound();
                    }
                    
                    if(relation.enforceSection != null) {
                        relation.writingEffect = Types.metaVar(Kinds.EFFECT);
                        context.pushEffectUpperBound(relation.location, relation.writingEffect);
                        relation.enforceSection.checkType(context);
                        context.popEffectUpperBound();
                    }
                }
            });
        }
    }
    
    /**
     * Type checks the module.
     * <p>
     * Creates the constraint environment and the scheduler, registers the values,
     * relations and rules, and then runs the scheduler.
     */
    public void typeCheck() {
        ce = new ConstraintEnvironment(environment);
        scheduler = new TypeCheckingScheduler(compilationContext);
        
        typeCheckValues();
        typeCheckRelations();
        typeCheckRules();
        
        scheduler.schedule();
    }
    
    /**
     * Registers the transformation rules of the module as a single type inferable definition.
     * <p>
     * That definition defines {@link ETransformation#TRANSFORMATION_RULES_TYPECHECKED}. If
     * the module has mapping relations, it also registers a post-type-checking runnable.
     * The runnable reports mapping relations whose parameter types are not ground.
     */
    private void typeCheckRules() {
        scheduler.addTypeInferableDefinition(new TypeInferableDefinition() {
            @Override
            public void solveConstraints() {
                // No-op: constraints of the rules are not solved here (see getUnsolvedConstraints)
            }
            
            @Override
            public void injectEvidence(TVar[] vars, TPred[] constraints) {
                // No-op: the rules are not generalized
            }
            
            @Override
            public void initializeTypeChecking(TypingContext context) {
                // No-op: nothing needs to be set up before checkType
            }
            
            @Override
            public ArrayList<Constraint> getUnsolvedConstraints() {
                return new ArrayList<Constraint>(0);
                /*
                ArrayList<EVariable> demands = context.getConstraintDemand();
                if(!demands.isEmpty()) {
                    ReducedConstraints red = 
                            ConstraintSolver.solve(ce, new ArrayList<TPred>(), demands, true);
                    for(Constraint c :  red.unsolvedConstraints) {
                        errorLog.log(c.getDemandLocation(), 
                                "Constraint <"+c.constraint+"> is not given and cannot be derived.");
                    }
                }*/
            }
            
            @Override
            public long getLocation() {
                // TODO: the rules as a whole have no single location; 0 is used as a placeholder
                return 0;
            }
            
            @Override
            public ArrayList<Variable> getFreeEvidence() {
                return new ArrayList<Variable>(0);
            }
            
            @Override
            public Collection<Object> getDefinedObjects() {
                return Collections.singleton(ETransformation.TRANSFORMATION_RULES_TYPECHECKED);
            }
            
            @Override
            public void collectRefs(TObjectIntHashMap<Object> allRefs, TIntHashSet refs) {
                for(TransformationRule rule : module.getRules())
                    for(Query[] queries : rule.sections.values())
                        for(Query query : queries)
                            query.collectRefs(allRefs, refs);
            }
            
            @Override
            public void collectFreeTypeVariables(THashSet<TVar> varSet) {
                // No-op: the rules contribute no type variables to generalize
            }
            
            @Override
            public void checkType(TypingContext context) {
                // Each rule is checked under its own effect meta variable, which becomes its effect
                for(TransformationRule rule : module.getRules()) {
                    context.pushEffectUpperBound(rule.location, Types.metaVar(Kinds.EFFECT));
                    rule.checkType(context);
                    rule.setEffect(Types.canonical(context.popEffectUpperBound()));
                }
            }
        });
        
        if(!module.getMappingRelations().isEmpty()) {
            scheduler.addPostTypeCheckingRunnable(new Runnable() {
                @Override
                public void run() {
                    // Report each mapping relation at most once (break leaves the parameter loop)
                    for(MappingRelation mappingRelation : module.getMappingRelations())
                        for(Type parameterType : mappingRelation.parameterTypes)
                            if(!parameterType.isGround()) {
                                compilationContext.errorLog.log(mappingRelation.location, "Parameter types of the mapping relation are not completely determined.");
                                break;
                            }
                }
            });
        }
    }
}
