package org.simantics.scl.compiler.elaboration.query;

import org.simantics.scl.compiler.common.names.Names;
import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
import org.simantics.scl.compiler.elaboration.expressions.EApply;
import org.simantics.scl.compiler.elaboration.expressions.QueryTransformer;
import org.simantics.scl.compiler.elaboration.query.compilation.ConstraintCollectionContext;
import org.simantics.scl.compiler.elaboration.query.compilation.DerivateException;
import org.simantics.scl.compiler.elaboration.query.compilation.QueryCompilationContext;
import org.simantics.scl.compiler.elaboration.query.compilation.QueryConstraint;
import org.simantics.scl.compiler.elaboration.query.compilation.UnsolvableQueryException;
import org.simantics.scl.compiler.elaboration.relations.LocalRelation;
import org.simantics.scl.compiler.types.Type;

import gnu.trove.map.hash.THashMap;
import gnu.trove.set.hash.TIntHashSet;


/**
 * Negation of a query.
 * <p>
 * The wrapped query is compiled into a separate check context, and the
 * constraint registered for this node emits a condition that applies
 * {@code Prelude.not} to the continuation of that check. The constraint is only
 * solvable once every variable collected from the wrapped query is bound.
 */
public class QNegation extends QAbstractModifier {

    /**
     * @param query the query whose negation this node represents
     */
    public QNegation(Query query) {
        super(query);
    }

    /**
     * Builds a bit mask with bit {@code index} set for each given variable index.
     * The shift is done on a {@code long}, so only indices below 64 map to
     * distinct bits.
     */
    private static long maskOf(int[] variableIndices) {
        long mask = 0L;
        for(int index : variableIndices)
            mask |= 1L << index;
        return mask;
    }

    /**
     * Generates the wrapped query into a fresh check context and registers one
     * constraint, over all variables collected from the wrapped query, whose
     * generated code is the negation of that check.
     */
    @Override
    public void collectConstraints(ConstraintCollectionContext context) throws UnsolvableQueryException {
        TIntHashSet referencedVars = new TIntHashSet();
        query.collectVars(context.getVariableMap(), referencedVars);

        // The inner query is generated once, up front, into its own check context;
        // the constraint below only reuses its cost, branching and continuation.
        final QueryCompilationContext checkContext =
                context.getQueryCompilationContext().createCheckContext();
        query.generate(checkContext);

        context.addConstraint(new QueryConstraint(referencedVars.toArray()) {
            // All referenced variables must already be bound before the negation can be evaluated.
            private final long variableMask = maskOf(variables);

            @Override
            public boolean canBeSolvedFrom(long boundVariables) {
                return (boundVariables & variableMask) == variableMask;
            }

            @Override
            public double getSolutionCost(long boundVariables) {
                return canBeSolvedFrom(boundVariables)
                        ? checkContext.getCost()
                        : Double.POSITIVE_INFINITY;
            }

            @Override
            public double getSolutionBranching(long boundVariables) {
                return canBeSolvedFrom(boundVariables)
                        ? checkContext.getBranching()
                        : Double.POSITIVE_INFINITY;
            }

            @Override
            public long getVariableMask() {
                return variableMask;
            }

            @Override
            public void generate(QueryCompilationContext generationContext) {
                // Evaluate the constant before the continuation, as callers may rely on this order.
                EApply negatedCheck = new EApply(
                        generationContext.getConstant(Names.Prelude_not, Type.EMPTY_ARRAY),
                        checkContext.getContinuation());
                generationContext.condition(negatedCheck);
            }
        });
    }

    /**
     * Returns a negation of the wrapped query after applying the replacement to it.
     */
    @Override
    public Query replace(ReplaceContext context) {
        Query replacedInner = query.replace(context);
        return new QNegation(replacedInner);
    }

    /**
     * A negation can only be derivated when the wrapped query has no diffs.
     *
     * @return {@code NO_DIFF} when the wrapped query produces no diffs
     * @throws DerivateException if the wrapped query produces any diff
     */
    @Override
    public Diff[] derivate(THashMap<LocalRelation, Diffable> diffables) throws DerivateException {
        Diff[] innerDiffs = query.derivate(diffables);
        if(innerDiffs.length != 0)
            throw new DerivateException(location);
        return NO_DIFF;
    }

    /** Dispatches to {@link QueryVisitor#visit} for this node. */
    @Override
    public void accept(QueryVisitor visitor) {
        visitor.visit(this);
    }

    /** Dispatches to {@link QueryTransformer#transform} for this node and returns its result. */
    @Override
    public Query accept(QueryTransformer transformer) {
        return transformer.transform(this);
    }

}
