package org.simantics.scl.compiler.elaboration.query.compilation;

import java.util.ArrayList;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.Variable;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;

import gnu.trove.map.hash.THashMap;
import gnu.trove.procedure.TObjectObjectProcedure;
import gnu.trove.set.hash.TIntHashSet;

public class ExpressionConstraint extends QueryConstraint {
    Variable variable;
    Expression expression;
    boolean isPattern;
    
    long forwardVariableMask;
    long backwardVariableMask;
    
    ArrayList<Variable> globalVariables;
    
    public ExpressionConstraint(final ConstraintCollectionContext context, Variable variable,
            Expression expression, boolean isPattern) {
        this.variable = variable;
        this.expression = expression;
        this.isPattern = isPattern;
        
        final TIntHashSet vars = new TIntHashSet();
        expression.collectVars(context.getVariableMap(), vars);
        
        int var1 = context.variableMap.get(variable);
        vars.add(var1);
        backwardVariableMask = 1L << var1;
        
        variables = vars.toArray();
        
        for(int v : variables)
            forwardVariableMask |= 1L << v;
        forwardVariableMask ^= backwardVariableMask;
        
        this.globalVariables = context.variables;
    }
    
    private boolean canBeSolvedForwards(long boundVariables) {
        return (forwardVariableMask & boundVariables) == forwardVariableMask;
    }
    
    private boolean canBeSolvedBackwards(long boundVariables) {
        return (backwardVariableMask & boundVariables) == backwardVariableMask;
    }
    
    @Override
    public boolean canBeSolvedFrom(long boundVariables) {
        return canBeSolvedForwards(boundVariables) || (isPattern && canBeSolvedBackwards(boundVariables));  
    }
    
    @Override
    public double getSolutionCost(long boundVariables) {
        return 1.0;
    }
    
    @Override
    public double getSolutionBranching(long boundVariables) {
        if(canBeSolvedForwards(boundVariables))
            return (boundVariables&1)==0 ? 1.0 : 0.95;
        else if(isPattern && canBeSolvedBackwards(boundVariables))
            return 0.95;
        else
            return Double.POSITIVE_INFINITY;
    }
    
    @Override
    public void generate(final QueryCompilationContext context) {
        if(canBeSolvedForwards(finalBoundVariables)) {
            if(canBeSolvedBackwards(finalBoundVariables))
                context.equalityCondition(expression.location, new EVariable(variable), expression);
            else
                context.let(variable, expression);
        }
        else if(canBeSolvedBackwards(finalBoundVariables)) {
            Expression pattern = expression;
            
            long mask = forwardVariableMask & finalBoundVariables;
            THashMap<Variable, Expression> map = new THashMap<Variable, Expression>();
            if(mask != 0L) {
                for(int variableId : variables)
                    if( ((mask >> variableId)&1L) == 1L ) {
                        Variable original = globalVariables.get(variableId);
                        Variable newVariable = new Variable(original.getName() + "_temp", original.getType());
                        map.put(original, new EVariable(newVariable));
                    }
                
                ReplaceContext replaceContext = new ReplaceContext(new THashMap<TVar,Type>(0), map, context.getTypingContext());
                pattern = pattern.replace(replaceContext);
            }
            context.match(pattern, new EVariable(variable), true);
            map.forEachEntry(new TObjectObjectProcedure<Variable, Expression>() {
                @Override
                public boolean execute(Variable a, Expression b) {
                    context.equalityCondition(Locations.NO_LOCATION, new EVariable(a), b);
                    return true;
                }
            });
        }
        else
            throw new InternalCompilerError(expression.location, "Error happened when tried to solve the query.");
    }
}
