package org.simantics.scl.compiler.elaboration.expressions;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext;
import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.elaboration.modules.SCLValue;
import org.simantics.scl.compiler.environment.Environment;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;
import org.simantics.scl.compiler.types.exceptions.UnificationException;
import org.simantics.scl.compiler.types.kinds.Kinds;

import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.THashSet;
import gnu.trove.set.hash.TIntHashSet;

public class EBind extends SimplifiableExpression {
    public Expression pattern;
    public Expression value;
    public Expression in;
    private EVariable monadEvidence;
    SCLValue bindFunction;
    Type monadType;
    Type valueContentType;
    Type inContentType;
    
    public EBind(long loc, Expression pattern, Expression value, Expression in) {
        super(loc);
        this.pattern = pattern;
        this.value = value;
        this.in = in;
    }

    public EBind(long loc, Expression pattern, Expression value, Expression in,
            SCLValue bindFunction) {
        super(loc);
        this.pattern = pattern;
        this.value = value;
        this.in = in;
    }

    @Override
    public void collectRefs(final TObjectIntHashMap<Object> allRefs, final TIntHashSet refs) {
        value.collectRefs(allRefs, refs);
        in.collectRefs(allRefs, refs);
    }
    
    @Override
    public void collectVars(TObjectIntHashMap<Variable> allVars,
            TIntHashSet vars) {
        value.collectVars(allVars, vars);
        in.collectVars(allVars, vars);
    }
    
    @Override
    protected void updateType() throws MatchException {
        setType(in.getType());
    }
    
    @Override
    public Expression checkBasicType(TypingContext context, Type requiredType) {
        monadType = Types.metaVar(Kinds.STAR_TO_STAR);
        inContentType = Types.metaVar(Kinds.STAR);
        Type monadContent = Types.apply(monadType, inContentType);
        try {
            Types.unify(requiredType, monadContent);
        } catch (UnificationException e) {
            context.typeError(location, requiredType, monadContent);
            return this;
        }
        
        Variable variable = new Variable("monadEvidence");
        variable.setType(Types.pred(Types.MONAD, monadType));
        monadEvidence = new EVariable(getLocation(), variable);
        monadEvidence.setType(variable.getType());
        context.addConstraintDemand(monadEvidence);
        
        pattern = pattern.checkTypeAsPattern(context, Types.metaVar(Kinds.STAR));
        valueContentType = pattern.getType();
        value = value.checkType(context, Types.apply(monadType, valueContentType));
        in = in.checkType(context, requiredType);
        Type inType = in.getType();
        setType(inType);
        return this;
    }

    @Override
    public IVal toVal(Environment env, CodeWriter w) {
        throw new InternalCompilerError("EBind should be eliminated.");
    }

    /**
     * Splits let 
     */
    @Override
    public Expression simplify(SimplificationContext context) {    
        value = value.simplify(context);
        in = in.simplify(context);
        pattern = pattern.simplify(context);
        
        long loc = getLocation();
        Expression simplified = new EApply(loc,
                new EConstant(loc, bindFunction, Types.canonical(monadType), Types.canonical(valueContentType), Types.canonical(inContentType)),
                monadEvidence, 
                value,
                new ELambda(loc, new Case[] {
                    new Case(new Expression[] { pattern }, in)
                }));
        simplified.setType(getType());
        
        return simplified.simplify(context);
    }

    @Override
    public void collectFreeVariables(THashSet<Variable> vars) {
        in.collectFreeVariables(vars);
        value.collectFreeVariables(vars);
        pattern.removeFreeVariables(vars);
    }

    @Override
    public Expression resolve(TranslationContext context) {
        value = value.resolve(context);
        
        context.pushFrame();
        pattern = pattern.resolveAsPattern(context);        
        in = in.resolve(context);
        context.popFrame();
        
        bindFunction = context.getBindFunction();
        
        return this; 
    }
    
    @Override
    public Expression decorate(ExpressionDecorator decorator) {
        pattern = pattern.decorate(decorator);
        value = value.decorate(decorator);
        in = in.decorate(decorator);
        return decorator.decorate(this);
    }

    @Override
    public void collectEffects(THashSet<Type> effects) {
        pattern.collectEffects(effects);
        value.collectEffects(effects);
        in.collectEffects(effects);
    }
    
    @Override
    public void setLocationDeep(long loc) {
        if(location == Locations.NO_LOCATION) {
            location = loc;
            pattern.setLocationDeep(loc);
            value.setLocationDeep(loc);
            in.setLocationDeep(loc);
        }
    }
    
    @Override
    public void accept(ExpressionVisitor visitor) {
        visitor.visit(this);
    }

    @Override
    public void forVariables(VariableProcedure procedure) {
        pattern.forVariables(procedure);
        value.forVariables(procedure);
        if(monadEvidence != null)
            monadEvidence.forVariables(procedure);
    }
    
    @Override
    public Expression accept(ExpressionTransformer transformer) {
        return transformer.transform(this);
    }

}
