package org.simantics.scl.compiler.elaboration.relations;

import static org.simantics.scl.compiler.elaboration.expressions.Expressions.apply;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.if_;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.lambda;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.let;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.letRec;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.tuple;
import static org.simantics.scl.compiler.elaboration.expressions.Expressions.var;

import org.simantics.scl.compiler.common.names.Name;
import org.simantics.scl.compiler.common.names.Names;
import org.simantics.scl.compiler.elaboration.expressions.EApply;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.Variable;
import org.simantics.scl.compiler.elaboration.query.compilation.QueryCompilationContext;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;

/**
 * Relation that computes the closure of {@link #baseRelation} along its first
 * and last parameters. Parameters between the first and the last are passed
 * unchanged to every invocation of the base relation.
 * <p>
 * Type variables, parameter types and the required-variables mask are all
 * delegated to the base relation.
 */
public class TransitiveClosureRelation extends AbstractRelation implements CompositeRelation {

    /** The relation whose closure is computed. */
    SCLRelation baseRelation;
    
    /**
     * @param baseRelation the relation whose closure this relation represents
     */
    public TransitiveClosureRelation(SCLRelation baseRelation) {
        this.baseRelation = baseRelation;
    }

    @Override
    public TVar[] getTypeVariables() {
        return baseRelation.getTypeVariables();
    }

    @Override
    public Type[] getParameterTypes() {
        return baseRelation.getParameterTypes();
    }
    
    /**
     * Heuristic estimate: five times the selectivity of the base relation for
     * the same binding pattern.
     */
    @Override
    public double getSelectivity(int boundVariables) {
        return baseRelation.getSelectivity(boundVariables) * 5.0;
    }
    
    @Override
    public int getRequiredVariablesMask() {
        return baseRelation.getRequiredVariablesMask();
    }

    /**
     * Generates code that enumerates every value reachable from the bound
     * endpoint by repeatedly applying the base relation.
     * <p>
     * Only two binding patterns are supported:
     * <ul>
     * <li>every parameter except the last is bound (search from first to last), or</li>
     * <li>every parameter except the first is bound (search from last to first).</li>
     * </ul>
     * The generated code has the following shape, where {@code cont} is the
     * continuation of {@code context} and {@code baseCode} is the code produced
     * by the base relation with {@code solved} in place of the bound endpoint,
     * a fresh variable {@code tcTemp} in place of the free endpoint, and
     * {@code f tcTemp} as its continuation:
     * <pre>
     * let set = MSet.create ()
     * in  letRec f = \solved -&gt; if MSet.add set solved
     *                           then disjunction(cont, baseCode)
     *                           else failure
     *     in  f bound
     * </pre>
     * The set records values already visited. Presumably {@code MSet.add}
     * returns {@code false} for a value that is already present, so each value
     * is expanded at most once even if the base relation is cyclic.
     * <p>
     * NOTE(review): because {@code f} is first applied to {@code bound} itself,
     * the start value is also passed to the continuation. This gives
     * reflexive-transitive semantics; confirm that this is intended.
     *
     * @throws UnsupportedOperationException if {@code boundVariables} matches
     *         neither supported binding pattern
     */
    @Override
    public void generate(long location,
            QueryCompilationContext context,
            Type[] typeParameters, Variable[] parameters, int boundVariables) {
        int last = parameters.length - 1;
        // Bit i of boundVariables is set when parameters[i] is bound.
        boolean forward = boundVariables == (1 << last) - 1;            // all but the last bound
        boolean backward = boundVariables == (1 << parameters.length) - 2; // all but the first bound
        if(!forward && !backward) {
            throw new UnsupportedOperationException("boundVariables = " + boundVariables);
        }

        // forward takes precedence if both patterns happen to match.
        Variable bound = forward ? parameters[0] : parameters[last];
        Variable solved = forward ? parameters[last] : parameters[0];
        
        Type type = baseRelation.getParameterTypes()[0];
        if(typeParameters.length > 0)
            type = type.replace(getTypeVariables(), typeParameters);

        Expression continuation = context.getContinuation();
        Variable set = new Variable("set", Types.apply(Types.con("MSet", "T"), type));
        Variable f = new Variable("f", Types.functionE(type, Types.PROC, continuation.getType()));
        Variable innerSolved = new Variable("tcTemp", solved.getType());
        
        // One step of the base relation: every value it produces from `solved`
        // is bound to `innerSolved` and fed back into f.
        QueryCompilationContext newContext = context.createSubcontext(new EApply(
                new EVariable(f), new EVariable(innerSolved)));

        // The middle parameters are passed through unchanged. The lambda-bound
        // `solved` takes the place of the bound endpoint, and `innerSolved`
        // takes the place of the free one.
        Variable[] innerParameters = parameters.clone();
        if(forward) {
            innerParameters[0] = solved;
            innerParameters[last] = innerSolved;
        }
        else {
            innerParameters[0] = innerSolved;
            innerParameters[last] = solved;
        }
        baseRelation.generate(location,
                newContext,
                typeParameters,
                innerParameters, boundVariables);
        
        // Build the expression from the inside out (see the Javadoc sketch).
        continuation = context.disjunction(continuation, newContext.getContinuation());
        continuation = if_(apply(context.getCompilationContext(), Types.PROC, Names.MSet_add, type,
                var(set), var(solved)),
                continuation,
                context.failure());
        continuation = lambda(Types.PROC, solved, continuation);
        continuation = letRec(f, continuation, apply(var(f), var(bound)));
        continuation = let(set, 
                apply(context.getCompilationContext(), Types.PROC, Names.MSet_create, type, tuple()), 
                continuation);
        context.setContinuation(continuation);
    }

    @Override
    public SCLRelation[] getSubrelations() {
        return new SCLRelation[] { baseRelation };
    }

}
