package org.simantics.scl.compiler.elaboration.expressions;

import java.util.ArrayList;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.elaboration.contexts.ReplaceContext;
import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext;
import org.simantics.scl.compiler.elaboration.contexts.TranslationContext;
import org.simantics.scl.compiler.elaboration.contexts.TypingContext;
import org.simantics.scl.compiler.environment.Environment;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.internal.codegen.references.BoundVar;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
import org.simantics.scl.compiler.internal.codegen.writer.RecursiveDefinitionWriter;
import org.simantics.scl.compiler.internal.elaboration.decomposed.DecomposedExpression;
import org.simantics.scl.compiler.internal.elaboration.utils.ExpressionDecorator;
import org.simantics.scl.compiler.internal.elaboration.utils.StronglyConnectedComponents;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;
import org.simantics.scl.compiler.types.kinds.Kinds;

import gnu.trove.map.hash.TObjectIntHashMap;
import gnu.trove.set.hash.THashSet;
import gnu.trove.set.hash.TIntHashSet;

public class ELet extends Expression {
    public Assignment[] assignments;
    public Expression in;
    
    public ELet(long loc, Assignment[] assignments, Expression in) {
        super(loc);
        this.assignments = assignments;
        this.in = in;
    }

    @Override
    public void collectRefs(final TObjectIntHashMap<Object> allRefs, final TIntHashSet refs) {
        for(Assignment assign : assignments)
            assign.value.collectRefs(allRefs, refs);
        in.collectRefs(allRefs, refs);
    }
    
    @Override
    public void collectVars(TObjectIntHashMap<Variable> allVars,
            TIntHashSet vars) {
        for(Assignment assign : assignments)
            assign.value.collectVars(allVars, vars);
        in.collectVars(allVars, vars);
    }
    
    @Override
    protected void updateType() throws MatchException {
        setType(in.getType());
    }
   
    /**
     * Splits let 
     */
    @Override
    public Expression simplify(SimplificationContext context) {
        
        // Simplify assignments
        for(Assignment assignment : assignments) {
            assignment.value = assignment.value.simplify(context);
        }
        
        // Find strongly connected components
        final TObjectIntHashMap<Variable> allVars = new TObjectIntHashMap<Variable>(
                2*assignments.length, 0.5f, -1);

        for(int i=0;i<assignments.length;++i)
            for(Variable var : assignments[i].pattern.getFreeVariables())
                allVars.put(var, i);
        final boolean isRecursive[] = new boolean[assignments.length];
        final ArrayList<int[]> components = new ArrayList<int[]>(Math.max(10, assignments.length)); 
        new StronglyConnectedComponents(assignments.length) {
            @Override
            protected int[] findDependencies(int u) {
                TIntHashSet vars = new TIntHashSet();
                assignments[u].value.collectVars(allVars, vars);
                if(vars.contains(u))
                    isRecursive[u] = true;
                return vars.toArray();
            }

            @Override
            protected void reportComponent(int[] component) {
                components.add(component);
            }

        }.findComponents();

        // Simplify in
        Expression result = in.simplify(context);
        
        // Handle each component
        for(int j=components.size()-1;j>=0;--j) {
            int[] component = components.get(j);
            boolean recursive = component.length > 1 || isRecursive[component[0]];
            if(recursive) {
                Assignment[] cAssignments = new Assignment[component.length];
                for(int i=0;i<component.length;++i)
                    cAssignments[i] = assignments[component[i]];
                result = new ELet(location, cAssignments, result);
            }
            else {
                Assignment assignment = assignments[component[0]];
                Expression pattern = assignment.pattern;
                
                if(pattern instanceof EVariable) {
                    EVariable pvar = (EVariable)pattern;
                    result = new ESimpleLet(location, pvar.variable, assignment.value, result);
                }
                else {
                    result = new EMatch(location, new Expression[] {assignment.value},
                                    new Case(new Expression[] {pattern}, result));
                }
            }
        }
        
        return result;
    }

    @Override
    public void collectFreeVariables(THashSet<Variable> vars) {
        in.collectFreeVariables(vars);
        for(Assignment assign : assignments)
            assign.value.collectFreeVariables(vars);
        for(Assignment assign : assignments) 
            assign.pattern.removeFreeVariables(vars);
    }

    @Override
    public Expression resolve(TranslationContext context) {
        throw new InternalCompilerError("ELet should be already resolved.");
    }
    
    @Override
    public Expression replace(ReplaceContext context) {
        Assignment[] newAssignments = new Assignment[assignments.length];
        for(int i=0;i<assignments.length;++i)
            newAssignments[i] = assignments[i].replace(context);            
        Expression newIn = in.replace(context);
        return new ELet(getLocation(), newAssignments, newIn);
    }
    
    @Override
    public IVal toVal(Environment env, CodeWriter w) {
        // Create bound variables
        BoundVar[] vars = new BoundVar[assignments.length];
        for(int i=0;i<assignments.length;++i) {
            Expression pattern = assignments[i].pattern;
            if(!(pattern instanceof EVariable))
                throw new InternalCompilerError("Cannot handle pattern targets in recursive assignments.");
            vars[i] = new BoundVar(pattern.getType());
            ((EVariable)pattern).getVariable().setVal(vars[i]);
        }
        
        // Create values
        RecursiveDefinitionWriter rdw = w.createRecursiveDefinition();
        long range = Locations.NO_LOCATION;
        for(Assignment assign2 : assignments) {
            range = Locations.combine(range, assign2.pattern.location);
            range = Locations.combine(range, assign2.value.location);
        }
        rdw.setLocation(range);
        for(int i=0;i<assignments.length;++i) {
            DecomposedExpression decomposed = 
                    DecomposedExpression.decompose(assignments[i].value);
            CodeWriter newW = rdw.createFunction(vars[i], 
                    decomposed.typeParameters,
                    decomposed.effect,
                    decomposed.returnType, 
                    decomposed.parameterTypes);
            IVal[] parameters = newW.getParameters();
            for(int j=0;j<parameters.length;++j)
                decomposed.parameters[j].setVal(parameters[j]);
            newW.return_(decomposed.body.toVal(env, newW));
        }
        return in.toVal(env, w);
    }
        
    private void checkAssignments(TypingContext context) {
        for(Assignment assign : assignments)
            assign.pattern = assign.pattern.checkTypeAsPattern(context, Types.metaVar(Kinds.STAR));
        for(Assignment assign : assignments)
            assign.value = assign.value.checkType(context, assign.pattern.getType());
    }
    
    @Override
    public Expression inferType(TypingContext context) {
        checkAssignments(context);
        in = in.inferType(context);
        return this;
    }
    
    @Override
    public Expression checkBasicType(TypingContext context, Type requiredType) {
        checkAssignments(context);
        in = in.checkType(context, requiredType);
        return this;
    }
    
    @Override
    public Expression checkIgnoredType(TypingContext context) {
        checkAssignments(context);
        in = in.checkIgnoredType(context);
        return this;
    }

    @Override
    public Expression decorate(ExpressionDecorator decorator) {
        in = in.decorate(decorator);
        for(Assignment assignment : assignments)
            assignment.decorate(decorator);
        return decorator.decorate(this);
    }

    @Override
    public void collectEffects(THashSet<Type> effects) {
        for(Assignment assignment : assignments) {
            assignment.pattern.collectEffects(effects);
            assignment.value.collectEffects(effects);
        }
        in.collectEffects(effects);
    }
    
    @Override
    public void setLocationDeep(long loc) {
        if(location == Locations.NO_LOCATION) {
            location = loc;
            for(Assignment assignment : assignments)
                assignment.setLocationDeep(loc);
            in.setLocationDeep(loc);
        }
    }
    
    @Override
    public void accept(ExpressionVisitor visitor) {
        visitor.visit(this);
    }

    @Override
    public void forVariables(VariableProcedure procedure) {
        for(Assignment assignment : assignments)
            assignment.forVariables(procedure);
        in.forVariables(procedure);
    }
    
    @Override
    public Expression accept(ExpressionTransformer transformer) {
        return transformer.transform(this);
    }
    
    @Override
    public int getSyntacticFunctionArity() {
        return in.getSyntacticFunctionArity();
    }

}
