package org.simantics.scl.compiler.elaboration.query.compilation;

import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Just;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.Nothing;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.common.names.Name;
import org.simantics.scl.compiler.common.names.Names;
import org.simantics.scl.compiler.compilation.CompilationContext;
import org.simantics.scl.compiler.constants.BooleanConstant;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.expressions.Case;
import org.simantics.scl.compiler.elaboration.expressions.EApply;
import org.simantics.scl.compiler.elaboration.expressions.EConstant;
import org.simantics.scl.compiler.elaboration.expressions.EIf;
import org.simantics.scl.compiler.elaboration.expressions.ELiteral;
import org.simantics.scl.compiler.elaboration.expressions.EMatch;
import org.simantics.scl.compiler.elaboration.expressions.ESimpleLambda;
import org.simantics.scl.compiler.elaboration.expressions.ESimpleLet;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.Variable;
import org.simantics.scl.compiler.elaboration.java.Builtins;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.types.TPred;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;

/**
 * Mutable state used while compiling a query into an {@link Expression}.
 * <p>
 * The generated code is built around {@link #continuation}. Every step method
 * ({@link #let}, {@link #condition(Expression)}, {@link #match}, {@link #iterateMaybe},
 * {@link #iterateList}, {@link #iterateVector}, {@link #iterateMSet}, ...) replaces the
 * continuation by a new expression that wraps the previous one. The most recently added
 * step therefore ends up outermost.
 * <p>
 * The {@link QueryCompilationMode} decides the shape of the generated code. In particular,
 * it decides what "no solution" looks like (see {@link #failure()}) and how alternatives are
 * combined (see {@link #disjunction(Expression, Expression)}).
 * <p>
 * The context also accumulates a branching factor and a cost estimate through
 * {@link #updateCost(double, double)}.
 */
public class QueryCompilationContext {
    /** Typing context used to look up constants and to register constraint demands. */
    TypingContext context;
    /** Selects what kind of expression the query is compiled to. */
    QueryCompilationMode mode;
    /**
     * Type parameter used when instantiating the failure constants in GET_FIRST and
     * GET_ALL modes. It is {@code null} for contexts made by {@link #createCheckContext()}.
     */
    Type resultType;
    /** The expression built so far; each step method wraps it. */
    Expression continuation;
    /** Accumulated branching factor; starts at 1. */
    double branching = 1.0;
    /** Accumulated cost estimate; starts at 0. */
    double cost = 0.0;
    
    public QueryCompilationContext(
            TypingContext context,
            QueryCompilationMode mode,
            Type resultType,
            Expression continuation) {
        this.context = context;
        this.mode = mode;
        this.resultType = resultType;
        this.continuation = continuation;
    }

    /**
     * Returns the expression produced when the query has no (further) solution.
     * <ul>
     * <li>ITERATE: the constant {@code Builtins.TUPLE_CONSTRUCTORS[0]} (presumably the unit value)</li>
     * <li>GET_FIRST: {@code Nothing}, instantiated with {@link #resultType}</li>
     * <li>GET_ALL: {@code Builtins.LIST_CONSTRUCTORS[0]}, instantiated with {@link #resultType}
     *     (presumably the empty list)</li>
     * <li>CHECK: the literal {@code false}</li>
     * </ul>
     *
     * @throws InternalCompilerError if the mode is not one of the above
     */
    public Expression failure() {
        switch(mode) {
        case ITERATE: return new EConstant(Builtins.TUPLE_CONSTRUCTORS[0]);
        case GET_FIRST: return new EConstant(Builtins.Nothing, resultType);
        case GET_ALL: return new EConstant(Builtins.LIST_CONSTRUCTORS[0], resultType);
        case CHECK: return new ELiteral(new BooleanConstant(false));
        default: throw new InternalCompilerError("failure could not handle mode " + mode);
        }
    }
    
    /**
     * Combines two alternative query results according to the mode.
     * <ul>
     * <li>ITERATE: {@code let _ = a in b}; both are evaluated, in order</li>
     * <li>GET_FIRST: {@code match a with Nothing -> b; temp -> temp}</li>
     * <li>GET_ALL: {@code appendList a b}</li>
     * <li>CHECK: {@code if a then True else b}</li>
     * </ul>
     *
     * @throws InternalCompilerError if the mode is not one of the above, or (GET_ALL)
     *         if the type of {@code a} is not a list type
     */
    public Expression disjunction(Expression a, Expression b) {
        switch(mode) {
        case ITERATE: return new ESimpleLet(new Variable("_", Types.UNIT), a, b);
        case GET_FIRST: {
            Variable var = new Variable("temp", a.getType());
            return new EMatch(a,
                new Case(new EConstant(Builtins.Nothing), b), 
                new Case(new EVariable(var), new EVariable(var)));
        }
        case GET_ALL: {
            try {
                return new EApply(context.getCompilationContext().getConstant(Names.Prelude_appendList, 
                        Types.matchApply(Types.LIST, a.getType())), a, b);
            } catch (MatchException e) {
                // Keep the cause so the reason for the type mismatch is not lost.
                throw new InternalCompilerError(e);
            }
        }
        case CHECK: return new EIf(a, new ELiteral(new BooleanConstant(true)), b);
        default: throw new InternalCompilerError("disjunction could not handle mode " + mode);
        }
    }
    
    /**
     * Returns {@code if condition then continuation else failure()}.
     * Does not modify this context.
     */
    public Expression condition(Expression condition, Expression continuation) {
        return new EIf(condition, continuation, failure());
    }
    
    /** Guards the current continuation with {@code condition}; otherwise the result is {@link #failure()}. */
    public void condition(Expression condition) {
        continuation = condition(condition, continuation);
    }
    
    /**
     * Guards the current continuation with {@code a == b}. The comparison uses
     * {@code Names.Builtin_equals} instantiated with the type of {@code a}.
     *
     * @param location source location attached to the generated equality call
     */
    public void equalityCondition(long location, Expression a, Expression b) {
        Type type = a.getType();
        condition(new EApply(
                location,
                Types.PROC,
                context.getCompilationContext().getConstant(Names.Builtin_equals, type),
                new Expression[] {
                    a,
                    b
                }
                ));
    }
    
    /** Wraps the current continuation in {@code let variable = value in continuation}. */
    public void let(Variable variable, Expression value) {
        continuation = new ESimpleLet(variable, value, continuation);
    }
    
    /**
     * Wraps the current continuation in
     * {@code match value with Nothing -> failure(); Just variable -> continuation}.
     */
    public void iterateMaybe(Variable variable, Expression value) {
        continuation = new EMatch(value,
                new Case(Nothing(variable.getType()), failure()),
                new Case(Just(var(variable)), continuation));
    }
    
    /**
     * Wraps the current continuation in {@code match value with pattern -> continuation}.
     *
     * @param mayFail if {@code true}, a catch-all case {@code _ -> failure()} is added
     */
    public void match(Expression pattern, Expression value, boolean mayFail) {
        if(mayFail)
            continuation = new EMatch(value,
                    new Case(pattern, continuation),
                    new Case(new EVariable(new Variable("_", pattern.getType())), failure()));
        else
            continuation = new EMatch(value,
                    new Case(pattern, continuation));
    }
    
    /**
     * Builds {@code function (\variable -> continuation) collection} from the current
     * continuation. Does not modify this context.
     */
    private Expression applyWithContinuationLambda(Expression function, Variable variable, Expression collection) {
        return new EApply(
                Locations.NO_LOCATION,
                Types.PROC,
                function,
                new Expression[] {
                    new ESimpleLambda(Types.PROC, variable, continuation),
                    collection
                });
    }
    
    /**
     * Makes the current continuation run for each element of {@code list}, with each
     * element bound to {@code variable}. The Prelude function used depends on the mode:
     * {@code iterList} (ITERATE), {@code any} (CHECK), {@code concatMap} (GET_ALL) or
     * {@code mapFirst} (GET_FIRST).
     *
     * @throws InternalCompilerError for any other mode, or if the continuation type
     *         does not have the shape the mode requires
     */
    public void iterateList(Variable variable, Expression list) {
        CompilationContext compilationContext = context.getCompilationContext();
        Type elementType = variable.getType();
        Expression function;
        try {
            // The constant is resolved before the continuation is replaced, because the
            // GET_ALL/GET_FIRST instantiations depend on the current continuation type.
            switch(mode) {
            case ITERATE:
                function = compilationContext.getConstant(Names.Prelude_iterList, elementType, Types.PROC, Types.tupleConstructor(0));
                break;
            case CHECK:
                function = compilationContext.getConstant(Names.Prelude_any, elementType, Types.PROC);
                break;
            case GET_ALL:
                function = compilationContext.getConstant(Names.Prelude_concatMap, elementType, Types.PROC, 
                        Types.matchApply(Types.LIST, continuation.getType()));
                break;
            case GET_FIRST:
                function = compilationContext.getConstant(Names.Prelude_mapFirst, elementType, Types.PROC,
                        Types.matchApply(Types.MAYBE, continuation.getType()));
                break;
            default: throw new InternalCompilerError("iterateList could not handle mode " + mode);
            }
        } catch(MatchException e) {
            throw new InternalCompilerError(e);
        }
        continuation = applyWithContinuationLambda(function, variable, list);
    }
    
    /**
     * Makes the current continuation run for each element of {@code vector}, with each
     * element bound to {@code variable}. The Vector function used depends on the mode:
     * {@code iterVector} (ITERATE), {@code anyVector} (CHECK), {@code concatMapVector}
     * (GET_ALL) or {@code mapFirstVector} (GET_FIRST).
     *
     * @throws InternalCompilerError for any other mode, or if the continuation type
     *         does not have the shape the mode requires
     */
    public void iterateVector(Variable variable, Expression vector) {
        CompilationContext compilationContext = context.getCompilationContext();
        Type elementType = variable.getType();
        Expression function;
        try {
            // Resolved before the continuation is replaced; see iterateList.
            switch(mode) {
            case ITERATE:
                function = compilationContext.getConstant(Names.Vector_iterVector, elementType, Types.PROC, continuation.getType());
                break;
            case CHECK:
                function = compilationContext.getConstant(Names.Vector_anyVector, elementType, Types.PROC);
                break;
            case GET_ALL:
                function = compilationContext.getConstant(Names.Vector_concatMapVector, elementType, Types.PROC,
                        Types.matchApply(Types.LIST, continuation.getType()));
                break;
            case GET_FIRST:
                function = compilationContext.getConstant(Names.Vector_mapFirstVector, elementType, Types.PROC, 
                        Types.matchApply(Types.MAYBE, continuation.getType()));
                break;
            default: throw new InternalCompilerError("iterateVector could not handle mode " + mode);
            }
        } catch(MatchException e) {
            throw new InternalCompilerError(e);
        }
        continuation = applyWithContinuationLambda(function, variable, vector);
    }
    
    /**
     * Makes the current continuation run for each element of the MSet {@code set}, with
     * each element bound to {@code variable}. Only the ITERATE ({@code MSet.iter}) and
     * GET_FIRST ({@code MSet.mapFirst}) modes are supported.
     *
     * @throws InternalCompilerError for any other mode, or if the continuation type
     *         does not have the shape the mode requires
     */
    public void iterateMSet(Variable variable, Expression set) {
        try {
            switch(mode) {
            case ITERATE:
                continuation = apply(context.getCompilationContext(), Types.PROC, Names.MSet_iter, variable.getType(), Types.PROC, continuation.getType(),
                        lambda(Types.PROC, variable, continuation), 
                        set
                        );
                break;
            case GET_FIRST:
                continuation = apply(context.getCompilationContext(), Types.PROC, Names.MSet_mapFirst, variable.getType(), Types.PROC, 
                        Types.matchApply(Types.MAYBE, continuation.getType()),
                        lambda(Types.PROC, variable, continuation), 
                        set
                        );
                break;
            default: throw new InternalCompilerError("iterateMSet could not handle mode " + mode);
            }
        } catch(MatchException e) {
            throw new InternalCompilerError(e);
        }
    }

    /**
     * Folds one more step into the running estimates.
     * The branching factor is multiplied by {@code localBranching}. The cost so far is
     * scaled by {@code localBranching} and then {@code localCost} is added. Because steps
     * wrap the existing continuation, the accumulated cost presumably belongs to steps
     * that run once per result of the new step.
     */
    public void updateCost(double localBranching, double localCost) {
        branching *= localBranching;
        cost *= localBranching;
        cost += localCost;
    }
    
    /** Looks up {@code name} in the compilation context, instantiated with {@code typeParameters}. */
    public Expression getConstant(Name name, Type[] typeParameters) {
        return context.getCompilationContext().getConstant(name, typeParameters);
    }
    
    /**
     * Creates a fresh CHECK-mode context that shares the typing context. Its initial
     * continuation is the literal {@code true} and its result type is {@code null}.
     */
    public QueryCompilationContext createCheckContext() {
        return new QueryCompilationContext(context, QueryCompilationMode.CHECK,
                null, new ELiteral(new BooleanConstant(true)));
    }
    
    /** Returns the accumulated branching factor. */
    public double getBranching() {
        return branching;
    }
    
    /** Returns the accumulated cost estimate. */
    public double getCost() {
        return cost;
    }

    /**
     * Creates a fresh context with the same typing context, mode and result type, whose
     * continuation is {@code innerExpression}. Cost estimates start over.
     */
    public QueryCompilationContext createSubcontext(Expression innerExpression) {
        return new QueryCompilationContext(context, mode, resultType, innerExpression);
    }
    
    /** Replaces the current continuation. */
    public void setContinuation(Expression continuation) {
        this.continuation = continuation;
    }
    
    /** Returns the expression built so far. */
    public Expression getContinuation() {
        return continuation;
    }

    /**
     * Right-folds {@link #disjunction(Expression, Expression)} over {@code disjuncts},
     * ending with {@link #failure()}. An empty array yields {@link #failure()}.
     */
    public Expression disjunction(Expression[] disjuncts) {
        Expression result = failure();
        for(int i=disjuncts.length-1;i>=0;--i)
            result = disjunction(disjuncts[i], result);
        return result;
    }
    
    /** Returns the typing context this query is compiled in. */
    public TypingContext getTypingContext() {
        return context;
    }

    /**
     * Creates a placeholder variable of type {@code pred} at {@code location} and
     * registers it with the typing context as a constraint demand. Presumably the typing
     * context later fills in the evidence — TODO confirm.
     */
    public EVariable getEvidence(long location, TPred pred) {
        EVariable evidence = new EVariable(location, null);
        evidence.setType(pred);
        context.addConstraintDemand(evidence);
        return evidence;
    }

    /** Returns the compilation context of the underlying typing context. */
    public CompilationContext getCompilationContext() {
        return context.getCompilationContext();
    }
}
