package org.simantics.scl.compiler.elaboration.query;

import java.util.ArrayList;
import java.util.Set;

import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
import org.simantics.scl.compiler.elaboration.expressions.EApply;
import org.simantics.scl.compiler.elaboration.expressions.ESimpleLambda;
import org.simantics.scl.compiler.elaboration.expressions.ESimpleLet;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.QueryTransformer;
import org.simantics.scl.compiler.elaboration.expressions.Variable;
import org.simantics.scl.compiler.elaboration.query.compilation.ConstraintCollectionContext;
import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
import org.simantics.scl.compiler.elaboration.query.compilation.QueryCompilationContext;
import org.simantics.scl.compiler.elaboration.query.compilation.QueryConstraint;
import org.simantics.scl.compiler.elaboration.query.compilation.UnsolvableQueryException;
import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
import org.simantics.scl.compiler.elaboration.relations.SCLRelation;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.types.Types;

import gnu.trove.map.hash.THashMap;
import gnu.trove.map.hash.TLongObjectHashMap;
import gnu.trove.set.hash.TIntHashSet;


/**
 * A query that is the disjunction of its subqueries.
 * <p>
 * For query compilation the whole disjunction is registered as a single
 * {@link QueryConstraint}. Each disjunct is compiled into its own
 * subcontext whose continuation is a call to a shared {@code continuation}
 * function; {@code generate} then binds that function to the original
 * continuation of the enclosing context.
 */
public class QDisjunction extends QAbstractCombiner {

    public QDisjunction(Query ... queries) {
        super(queries);
    }
    
    /**
     * Compilation result of all disjuncts for one set of bound variables.
     * An unsolvable combination is represented by {@code null} arrays and
     * infinite branching and cost.
     */
    private static class CachedPlan {
        /** Variables that were not bound beforehand; {@code null} if unsolvable. */
        final Variable[] variables;
        /** One compiled subcontext per disjunct; {@code null} if unsolvable. */
        final QueryCompilationContext[] subplans;
        /** Sum of the branching estimates of the subplans. */
        final double totalBranching;
        /** Sum of the cost estimates of the subplans. */
        final double totalCost;
        
        public CachedPlan(Variable[] variables, QueryCompilationContext[] subplans,
                double totalBranching, double totalCost) {
            this.variables = variables;
            this.subplans = subplans;
            this.totalBranching = totalBranching;
            this.totalCost = totalCost;
        }
    }

    /**
     * Adds one constraint covering every variable collected by
     * {@code collectVars}. Plans are built lazily and cached per
     * bound-variable mask, because the cost, branching and solvability
     * callbacks may all be asked about the same mask.
     */
    @Override
    public void collectConstraints(final ConstraintCollectionContext context) {
        TIntHashSet vars = new TIntHashSet();
        collectVars(context.getVariableMap(), vars);
        
        final Variable continuationFunction = new Variable("continuation");
        int[] variables = vars.toArray();
        // Bit mask of the variables this disjunction refers to.
        // NOTE(review): assumes variable indices fit into a 64-bit mask.
        long variableMask_ = 0L;
        for(int v : variables)
            variableMask_ |= 1L << v;
        final long variableMask = variableMask_;
        
        context.addConstraint(new QueryConstraint(variables) {
            
            /** Plans keyed by the bound-variable mask restricted to variableMask. */
            private final TLongObjectHashMap<CachedPlan> cache = new TLongObjectHashMap<CachedPlan>();
            
            /**
             * Compiles every disjunct assuming the variables in
             * {@code boundVariables} are already bound.
             */
            private CachedPlan create(long boundVariables) {
                QueryCompilationContext[] subplans = new QueryCompilationContext[queries.length];
                // Both estimates are accumulated additively over the
                // disjuncts, so both start from the additive identity.
                double totalBranching = 0.0;
                double totalCost = 0.0;
                
                // Variables whose bit is not set in the mask must be solved
                // by the disjuncts themselves.
                ArrayList<Variable> solvedVariablesList = new ArrayList<Variable>();
                for(int v : variables)
                    if( ((boundVariables >> v)&1) == 0 )
                        solvedVariablesList.add(context.getVariable(v));
                Variable[] solvedVariables = solvedVariablesList.toArray(new Variable[solvedVariablesList.size()]);
                
                for(int i=0;i<queries.length;++i) {
                    // Every disjunct ends by calling the shared continuation
                    // function with the variables it has solved.
                    Expression[] parameters = new Expression[solvedVariables.length];
                    for(int j=0;j<solvedVariables.length;++j)
                        parameters[j] = new EVariable(solvedVariables[j]);
                    EApply cont = new EApply(Locations.NO_LOCATION, Types.PROC,
                            new EVariable(continuationFunction), parameters);
                    cont.setType(context.getQueryCompilationContext().getContinuation().getType());
                    subplans[i] = context.getQueryCompilationContext().createSubcontext(cont);
                    try {
                        new QExists(solvedVariables, queries[i]).generate(subplans[i]);
                    } catch (UnsolvableQueryException e) {
                        // A single unsolvable disjunct makes the whole
                        // disjunction unsolvable for this mask.
                        return new CachedPlan(null, null, Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY);
                    }
                    totalBranching += subplans[i].getBranching();
                    totalCost += subplans[i].getCost();
                }
                return new CachedPlan(solvedVariables, subplans, totalBranching, totalCost);
            }
            
            /** Returns the cached plan for the mask, creating it on first use. */
            private CachedPlan get(long boundVariables) {
                // Only the variables of this disjunction matter for the key.
                boundVariables &= variableMask;
                CachedPlan plan = cache.get(boundVariables);
                if(plan == null) {
                    plan = create(boundVariables);
                    cache.put(boundVariables, plan);
                }
                return plan;
            }
            
            @Override
            public double getSolutionCost(long boundVariables) {
                return get(boundVariables).totalCost;
            }
            
            @Override
            public double getSolutionBranching(long boundVariables) {
                return get(boundVariables).totalBranching;
            }
            
            @Override
            public boolean canBeSolvedFrom(long boundVariables) {
                return get(boundVariables).totalCost != Double.POSITIVE_INFINITY;
            }
            
            /**
             * Replaces the continuation of {@code context} by an
             * {@code ESimpleLet} that binds {@code continuationFunction} to
             * the old continuation, wrapped in one lambda per solved
             * variable, and uses the disjunction of the subplans as the
             * remaining expression.
             *
             * @throws IllegalStateException if the plan for the final bound
             *         variables is unsolvable
             */
            @Override
            public void generate(QueryCompilationContext context) {
                CachedPlan plan = get(finalBoundVariables);
                if(plan.subplans == null)
                    throw new IllegalStateException(
                            "QDisjunction cannot be solved from the given bound variables.");
                
                Expression[] disjuncts = new Expression[plan.subplans.length];
                for(int i=0;i<disjuncts.length;++i)
                    disjuncts[i] = plan.subplans[i].getContinuation().copy(context.getTypingContext());
                Expression result = context.disjunction(disjuncts);
                
                // The old continuation becomes the body of the continuation
                // function, so each solved variable is replaced by a fresh
                // variable of the same name and type that serves as a lambda
                // parameter. The old variable is renamed with a "_temp" suffix.
                ReplaceContext replaceContext = new ReplaceContext(context.getTypingContext());
                int variableCount = plan.variables.length;
                Variable[] newVariables = new Variable[variableCount];
                for(int i=0;i<variableCount;++i) {
                    Variable oldVariable = plan.variables[i];
                    Variable newVariable = new Variable(oldVariable.getName(), oldVariable.getType());
                    newVariables[i] = newVariable;
                    oldVariable.setName(oldVariable.getName() + "_temp");
                    replaceContext.varMap.put(oldVariable, new EVariable(newVariable));
                }
                
                Expression functionDefinition = context.getContinuation().replace(replaceContext);
                // Wrap from the last parameter outwards; only the innermost
                // lambda (created first) gets the PROC effect.
                for(int i=variableCount-1;i>=0;--i) {
                    boolean innermost = i == variableCount-1;
                    functionDefinition = new ESimpleLambda(
                            innermost ? Types.PROC /* FIXME */ : Types.NO_EFFECTS,
                            newVariables[i], functionDefinition);
                }
                continuationFunction.setType(functionDefinition.getType());
                
                context.setContinuation(new ESimpleLet(
                        continuationFunction, 
                        functionDefinition, 
                        result));
            }
        });
    }

    /**
     * Derivates every disjunct and returns all resulting diffs concatenated
     * in disjunct order, or {@code NO_DIFF} if no disjunct produced any.
     */
    @Override
    public Diff[] derivate(THashMap<LocalRelation, Diffable> diffables) throws DerivateException {
        Diff[][] diffs = new Diff[queries.length][];
        int totalDiffCount = 0;
        for(int i=0;i<queries.length;++i) {
            diffs[i] = queries[i].derivate(diffables);
            totalDiffCount += diffs[i].length;
        }
        if(totalDiffCount == 0)
            return NO_DIFF;
        
        Diff[] result = new Diff[totalDiffCount];
        int offset = 0;
        for(Diff[] ds : diffs) {
            System.arraycopy(ds, 0, result, offset, ds.length);
            offset += ds.length;
        }
        return result;
    }
    
    /** Returns a new disjunction whose disjuncts have been replaced with the given context. */
    @Override
    public Query replace(ReplaceContext context) {
        Query[] newQueries = new Query[queries.length];
        for(int i=0;i<queries.length;++i)
            newQueries[i] = queries[i].replace(context);
        return new QDisjunction(newQueries);
    }

    /**
     * Removes the given relations from every disjunct. Returns {@code this}
     * when no disjunct changes.
     */
    @Override
    public Query removeRelations(Set<SCLRelation> relations) {
        for(int i=0;i<queries.length;++i) {
            Query newQuery = queries[i].removeRelations(relations);
            if(newQuery != queries[i])
                return rebuildWithoutRelations(i, newQuery, relations);
        }
        return this;
    }
    
    /**
     * Builds the result of {@link #removeRelations(Set)} once the first
     * changed disjunct has been found. Disjuncts before
     * {@code firstChanged} are known to be unchanged and are copied as is;
     * the remaining ones are processed and dropped if they become
     * {@code EMPTY_QUERY}.
     */
    private Query rebuildWithoutRelations(int firstChanged, Query firstReplacement,
            Set<SCLRelation> relations) {
        ArrayList<Query> newQueries = new ArrayList<Query>(queries.length);
        for(int j=0;j<firstChanged;++j)
            newQueries.add(queries[j]);
        if(firstReplacement != EMPTY_QUERY)
            newQueries.add(firstReplacement);
        for(int j=firstChanged+1;j<queries.length;++j) {
            Query newQuery = queries[j].removeRelations(relations);
            if(newQuery != EMPTY_QUERY)
                newQueries.add(newQuery);
        }
        
        if(newQueries.isEmpty())
            return EMPTY_QUERY;
        if(newQueries.size() == 1)
            return newQueries.get(0);
        return new QDisjunction(newQueries.toArray(new Query[newQueries.size()]));
    }
    
    @Override
    public void accept(QueryVisitor visitor) {
        visitor.visit(this);
    }
    
    @Override
    public Query accept(QueryTransformer transformer) {
        return transformer.transform(this);
    }

}
