package org.simantics.scl.compiler.elaboration.expressions;

import java.util.ArrayList;

import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext;
import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.environment.Environment;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.codegen.ssa.exits.Throw;
import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
import org.simantics.scl.compiler.internal.elaboration.matching.PatternMatchingCompiler;
import org.simantics.scl.compiler.internal.elaboration.matching.Row;
import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.kinds.Kinds;

import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.THashSet;
import gnu.trove.set.hash.TIntHashSet;

/**
 * Pattern-matching expression: the values of the {@link #scrutinee}
 * expressions are matched against the patterns of the {@link #cases}.
 * <p>
 * In generated code, a value that no case accepts causes a
 * {@code Throw.MatchingException} to be thrown (see {@link #toVal}).
 */
public class EMatch extends Expression {

    /** Expressions whose values are matched against the case patterns. */
    public Expression[] scrutinee;
    /**
     * Alternatives of the match. NOTE(review): presumably each case carries
     * one pattern per scrutinee expression — confirm against {@code Case}.
     */
    public Case[] cases;

    /**
     * Creates a match over several scrutinee expressions.
     *
     * @param scrutinee the expressions being matched
     * @param cases     the alternatives of the match
     */
    public EMatch(Expression[] scrutinee, Case ... cases) {
        this.scrutinee = scrutinee;
        this.cases = cases;
    }

    /**
     * Creates a match over a single scrutinee expression.
     *
     * @param scrutinee the expression being matched
     * @param cases     the alternatives of the match
     */
    public EMatch(Expression scrutinee, Case ... cases) {
        this(new Expression[] {scrutinee}, cases);
    }

    /**
     * Creates a match over several scrutinee expressions at the given
     * source location.
     *
     * @param loc       source location of this expression
     * @param scrutinee the expressions being matched
     * @param cases     the alternatives of the match
     */
    public EMatch(long loc, Expression[] scrutinee, Case ... cases) {
        super(loc);
        this.scrutinee = scrutinee;
        this.cases = cases;
    }

    /** Collects references from the scrutinee expressions and then from every case. */
    public void collectRefs(TObjectIntHashMap<Object> allRefs, TIntHashSet refs) {
        for(Expression s : scrutinee)
            s.collectRefs(allRefs, refs);
        for(Case case_ : cases)
            case_.collectRefs(allRefs, refs);
    }

    /** Collects variables from the scrutinee expressions and then from every case. */
    @Override
    public void collectVars(TObjectIntHashMap<Variable> allVars,
            TIntHashSet vars) {
        for(Expression s : scrutinee)
            s.collectVars(allVars, vars);
        for(Case case_ : cases)
            case_.collectVars(allVars, vars);
    }

    /**
     * Sets the type of this expression to the type of the first case's value.
     * NOTE(review): assumes {@link #cases} is non-empty; an empty array would
     * throw {@code ArrayIndexOutOfBoundsException} here.
     */
    @Override
    protected void updateType() {
        setType(cases[0].value.getType());
    }

    /**
     * Generates code for the match.
     * <p>
     * Every case becomes a {@link Row} (in declaration order), the scrutinee
     * expressions are evaluated once, and the pattern-matching compiler is
     * given a join block typed with this expression's type and a failure
     * block that throws a matching exception. Code generation continues at
     * the join block, whose first parameter is the result.
     *
     * @return the value produced by the match
     */
    @Override
    public IVal toVal(Environment env, CodeWriter w) {
        ArrayList<Row> rows = new ArrayList<Row>(cases.length);
        for(Case case_ : cases)
            rows.add(new Row(case_.patterns, case_.value));

        IVal[] scrutineeVals = new IVal[scrutinee.length];
        for(int i=0;i<scrutinee.length;++i)
            scrutineeVals[i] = scrutinee[i].toVal(env, w);

        CodeWriter joinPoint = w.createBlock(getType());
        CodeWriter failurePoint = w.createBlock(); // TODO generate only one failurePoint per function
        PatternMatchingCompiler.split(w, env, scrutineeVals, joinPoint.getContinuation(), failurePoint.getContinuation(), rows);
        failurePoint.throw_(location, Throw.MatchingException, "Matching failure at: " + toString());
        w.continueAs(joinPoint);
        return w.getParameters()[0];
    }

    /** Collects free variables from the scrutinee expressions and then from every case. */
    @Override
    public void collectFreeVariables(THashSet<Variable> vars) {
        for(Expression s : scrutinee)
            s.collectFreeVariables(vars);
        for(Case case_ : cases)
            case_.collectFreeVariables(vars);
    }

    /**
     * Simplifies the scrutinee expressions (in place) and every case.
     * <p>
     * A match with exactly one scrutinee and one case whose pattern is a
     * plain variable is rewritten as an {@link ESimpleLet} binding that
     * variable to the scrutinee. Cases whose value is a
     * {@link GuardedExpressionGroup} are left alone (presumably because a
     * guard may fail, which a let cannot express).
     *
     * @return the rewritten let expression, or {@code this}
     */
    @Override
    public Expression simplify(SimplificationContext context) {
        for(int i=0;i<scrutinee.length;++i)
            scrutinee[i] = scrutinee[i].simplify(context);
        for(Case case_ : cases)
            case_.simplify(context);
        if(cases.length == 1 && scrutinee.length == 1) {
            Case case_ = cases[0];
            Expression pattern = case_.patterns[0];
            if(pattern instanceof EVariable
                    && !(case_.value instanceof GuardedExpressionGroup)) {
                Variable var = ((EVariable)pattern).variable;
                return new ESimpleLet(var, scrutinee[0], case_.value);
            }
        }
        return this;
    }

    /**
     * Resolves the scrutinee expressions (in place) and every case.
     *
     * @return {@code this}
     */
    @Override
    public Expression resolve(TranslationContext context) {
        for(int i=0;i<scrutinee.length;++i)
            scrutinee[i] = scrutinee[i].resolve(context);
        for(Case case_ : cases)
            case_.resolve(context);
        return this;
    }

    /**
     * Assigns {@code loc} to this expression, its cases and its scrutinee
     * expressions, but only when this expression has no location yet.
     */
    @Override
    public void setLocationDeep(long loc) {
        if(location == Locations.NO_LOCATION) {
            location = loc;
            for(Case case_ : cases)
                case_.setLocationDeep(loc);
            for(Expression e : scrutinee)
                e.setLocationDeep(loc);
        }
    }

    /**
     * Returns a new match that keeps this expression's location and has
     * every scrutinee expression and case replaced through {@code context}.
     */
    @Override
    public Expression replace(ReplaceContext context) {
        Expression[] newScrutinee = new Expression[scrutinee.length];
        for(int i=0;i<scrutinee.length;++i)
            newScrutinee[i] = scrutinee[i].replace(context);
        Case[] newCases = new Case[cases.length];
        for(int i=0;i<cases.length;++i)
            newCases[i] = cases[i].replace(context);
        return new EMatch(getLocation(), newScrutinee, newCases);
    }

    /**
     * Type-checks the match against {@code requiredType}. The scrutinee
     * types are checked first, then each case is checked against them and
     * the required type.
     *
     * @return {@code this}, with its type set to {@code requiredType}
     */
    @Override
    public Expression checkBasicType(TypingContext context, Type requiredType) {
        Type[] scrutineeTypes = checkScrutineeTypes(context);
        for(Case case_ : cases)
            case_.checkType(context, scrutineeTypes, requiredType);
        setType(requiredType);
        return this;
    }

    /**
     * Type-checks the match in a context where its value is ignored.
     *
     * @return {@code this}, with its type set to {@link Types#UNIT}
     */
    @Override
    public Expression checkIgnoredType(TypingContext context) {
        Type[] scrutineeTypes = checkScrutineeTypes(context);
        for(Case case_ : cases)
            case_.checkIgnoredType(context, scrutineeTypes);
        setType(Types.UNIT);
        return this;
    }

    /**
     * Type-checks every scrutinee expression in place against a fresh
     * meta-variable of kind {@code Kinds.STAR}. This is shared by
     * {@link #checkBasicType} and {@link #checkIgnoredType}.
     *
     * @return the resulting type of each scrutinee expression, in order
     */
    private Type[] checkScrutineeTypes(TypingContext context) {
        Type[] scrutineeTypes = new Type[scrutinee.length];
        for(int i=0;i<scrutinee.length;++i) {
            scrutinee[i] = scrutinee[i].checkType(context, Types.metaVar(Kinds.STAR));
            scrutineeTypes[i] = scrutinee[i].getType();
        }
        return scrutineeTypes;
    }

    /**
     * Decorates the scrutinee expressions (in place) and every case, then
     * this expression itself.
     *
     * @return the result of {@code decorator.decorate(this)}
     */
    @Override
    public Expression decorate(ExpressionDecorator decorator) {
        for(int i=0;i<scrutinee.length;++i)
            scrutinee[i] = scrutinee[i].decorate(decorator);
        for(Case case_ : cases)
            case_.decorate(decorator);
        return decorator.decorate(this);
    }

    /** Collects effects from the scrutinee expressions and from every case's patterns and value. */
    @Override
    public void collectEffects(THashSet<Type> effects) {
        for(Expression s : scrutinee)
            s.collectEffects(effects);
        for(Case case_ : cases) {
            for(Expression pattern : case_.patterns)
                pattern.collectEffects(effects);
            case_.value.collectEffects(effects);
        }
    }

    @Override
    public void accept(ExpressionVisitor visitor) {
        visitor.visit(this);
    }

    /** @return the scrutinee array itself (not a copy) */
    public Expression[] getScrutinee() {
        return scrutinee;
    }

    /** @return the case array itself (not a copy) */
    public Case[] getCases() {
        return cases;
    }

    /** Applies {@code procedure} through the scrutinee expressions and then every case. */
    @Override
    public void forVariables(VariableProcedure procedure) {
        for(Expression s : scrutinee)
            s.forVariables(procedure);
        for(Case case_ : cases)
            case_.forVariables(procedure);
    }

    @Override
    public Expression accept(ExpressionTransformer transformer) {
        return transformer.transform(this);
    }

    /**
     * @return the largest syntactic function arity among the case values,
     *         or 0 when there are no cases
     */
    @Override
    public int getSyntacticFunctionArity() {
        int result = 0;
        for(Case case_ : cases)
            result = Math.max(result, case_.value.getSyntacticFunctionArity());
        return result;
    }
}
