package org.simantics.scl.compiler.internal.codegen.ssa.statements;

import java.util.ArrayList;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.constants.Constant;
import org.simantics.scl.compiler.constants.SCLConstant;
import org.simantics.scl.compiler.internal.codegen.references.BoundVar;
import org.simantics.scl.compiler.internal.codegen.references.ValRef;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAClosure;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAFunction;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAStatement;
import org.simantics.scl.compiler.internal.codegen.ssa.binders.ClosureBinder;
import org.simantics.scl.compiler.internal.codegen.utils.CopyContext;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.PrintingContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSALambdaLiftingContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSASimplificationContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSAValidationContext;
import org.simantics.scl.compiler.internal.codegen.utils.ValRefVisitor;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;

import gnu.trove.map.hash.THashMap;
import gnu.trove.set.hash.THashSet;

/**
 * SSA statement that binds a group of closures (functions or values) whose
 * targets may refer to each other.
 * <p>
 * The closures form a doubly linked list starting at {@link #firstClosure}.
 * This statement only exists before lambda lifting: {@link #lambdaLift}
 * replaces it by top-level constants and detaches it, and
 * {@link #generateCode} refuses to run.
 */
public class LetFunctions extends SSAStatement implements ClosureBinder {
    /** Location used when reporting a recursive group that contains values. */
    long recursiveGroupLocation;
    /** Head of the linked list of closures bound by this statement, or {@code null}. */
    SSAClosure firstClosure;

    /** Creates an empty group; closures are added with {@link #addClosure}. */
    public LetFunctions() {
    }
    
    /**
     * Creates a group containing the single closure {@code closure}.
     *
     * @param closure the closure to bind; its parent is set to this statement
     */
    public LetFunctions(SSAClosure closure) {
        firstClosure = closure;
        closure.setParent(this);
    }

    /**
     * Prints every closure as {@code target(occurrenceCount) :: type = }
     * followed by the indented closure body.
     */
    @Override
    public void toString(PrintingContext context) {
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext()) {
            context.indentation();
            context.append(closure.getTarget());
            context.append("(" + closure.getTarget().occurrenceCount() + ")");
            context.append(" :: ");
            context.append(closure.getTarget().getType());
            context.append(" = \n");
            context.indent();
            closure.toString(context);
            context.dedent();
        }
    }
    
    /**
     * Prepends {@code closure} to the group and makes this statement its parent.
     *
     * @param closure the closure to add
     */
    public void addClosure(SSAClosure closure) {
        closure.setParent(this);        
        closure.setNext(firstClosure);
        if(firstClosure != null)
            firstClosure.setPrev(closure);
        firstClosure = closure;
    }

    /**
     * Always fails: local function groups must be removed by
     * {@link #lambdaLift} before code generation.
     *
     * @throws InternalCompilerError always
     */
    @Override
    public void generateCode(MethodBuilder mb) {
        throw new InternalCompilerError("Functions should be lambda lifted before code generation");        
    }

    /**
     * Checks that every closure is bound to a {@link BoundVar} and validates
     * each closure.
     *
     * @throws InternalCompilerError if a closure target is not a BoundVar
     */
    @Override
    public void validate(SSAValidationContext context) {
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext()) {
            if(!(closure.getTarget() instanceof BoundVar))
                throw new InternalCompilerError("Target of a closure bound by LetFunctions must be a BoundVar, got "
                        + closure.getTarget() + ".");
            closure.validate(context);
        }
    }

    /** Destroys every closure of the group. */
    @Override
    public void destroy() {
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext())
            closure.destroy();
    }

    /**
     * Copies the group, copying each closure and its target through {@code context}.
     * Because {@link #addClosure} prepends, the closures of the copy are in
     * reverse order compared to this statement.
     */
    @Override
    public SSAStatement copy(CopyContext context) {
        LetFunctions result = new LetFunctions();
        // Keep the location so errors reported by the copy point to the original source.
        result.recursiveGroupLocation = recursiveGroupLocation;
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext()) {
            SSAClosure newFunction = closure.copy(context);
            newFunction.setTarget(context.copy(closure.getTarget()));
            result.addClosure(newFunction);
        }
        return result;        
    }

    /** Substitutes type variables in every closure target and closure body. */
    @Override
    public void replace(TVar[] vars, Type[] replacements) {
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext()) {
            ((BoundVar)closure.getTarget()).replace(vars, replacements);
            closure.replace(vars, replacements);
        }
    }

    /** Registers every closure target as a valid bound variable. */
    @Override
    public void addBoundVariablesTo(SSAValidationContext context) {
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext())
            context.validBoundVariables.add((BoundVar)closure.getTarget());        
    }

    /** @return the head of the closure list, or {@code null} */
    @Override
    public SSAClosure getFirstClosure() {
        return firstClosure;
    }

    /**
     * Replaces the head of the closure list. Setting it to {@code null}
     * detaches this statement, since an empty group binds nothing.
     */
    @Override
    public void setFirstClosure(SSAClosure function) {
        this.firstClosure = function;     
        if(function == null)
            detach();
    }

    /**
     * Not supported for local function groups.
     *
     * @throws InternalCompilerError always
     */
    @Override
    public void collectFreeVariables(SSAFunction parentFunction,
            ArrayList<ValRef> vars) {
        throw new InternalCompilerError("Should not be called for non-lambda-lifted functions.");
    }

    /**
     * Lifts the closures of this group to top-level definitions and removes
     * this statement.
     * <p>
     * A single closure that is not an {@link SSAFunction} is lifted by the
     * closure itself; its target is then defined by a {@link LetApply} of the
     * lifted constant to the closure's free variables. Otherwise every closure
     * is parametrized by the free variables of the whole group and turned into
     * a private {@link SCLConstant}. Each occurrence of an old target is
     * rewritten to apply that constant to the parameters in scope at the
     * occurrence.
     */
    @Override
    public void lambdaLift(SSALambdaLiftingContext context) {
        // An empty group binds nothing; remove it instead of dereferencing a null head below.
        if(firstClosure == null) {
            detach();
            return;
        }

        boolean hasValues = false;
        boolean isRecursive = false;
        
        // Lambda lift substructure and collect free variables
        THashSet<BoundVar> targets = new THashSet<BoundVar>();
        ArrayList<ValRef> freeVars = new ArrayList<ValRef>();        
        for(SSAClosure closure = firstClosure; 
                closure != null; 
                closure = closure.getNext()) {
            hasValues |= closure.isValue();
            closure.lambdaLift(context);
            targets.add((BoundVar)closure.getTarget());
            closure.collectFreeVariables(freeVars);
        }
        
        // Single non-function closure: rename its free variables to fresh
        // parameters and let the closure lift itself.
        if(!(firstClosure instanceof SSAFunction) && firstClosure.getNext() == null) {
            THashMap<BoundVar, BoundVar> varMap = new THashMap<BoundVar, BoundVar>(); 
            ArrayList<BoundVar> oldVarsList = new ArrayList<BoundVar>(4);
            ArrayList<BoundVar> newVarsList = new ArrayList<BoundVar>(4);
            BoundVar newTarget = null;
            for(ValRef ref : freeVars) {
                BoundVar var = (BoundVar)ref.getBinding();
                if(targets.contains(var)) {
                    // Self reference: redirect to the target passed to liftClosure.
                    if(newTarget == null)
                        newTarget = new BoundVar(var.getType());
                    ref.replaceBy(newTarget);
                    continue;
                }
                // One fresh parameter per distinct free variable.
                BoundVar newVar = varMap.get(var);
                if(newVar == null) {
                    newVar = new BoundVar(var.getType());
                    newVar.setLabel(var.getLabel());
                    oldVarsList.add(var);
                    newVarsList.add(newVar);
                    varMap.put(var, newVar);
                }
                ref.replaceBy(newVar);
            }
            Constant constant = firstClosure.liftClosure(newTarget, newVarsList.toArray(new BoundVar[newVarsList.size()]));
            // Define the original target by applying the lifted constant to the original free variables.
            new LetApply(targets.iterator().next(), Types.PROC, constant.createOccurrence(), ValRef.createOccurrences(oldVarsList))
            .insertBefore(this);
            detach();
            context.addClosure(firstClosure);
            return;
        }
                
        // Classify by BoundVars: references to group targets mark the group
        // recursive; the remaining distinct variables become extra parameters.
        THashSet<BoundVar> boundVars = new THashSet<BoundVar>(); 
        ArrayList<BoundVar> boundVarsList = new ArrayList<BoundVar>(4);
        ArrayList<ValRef> newFreeVars = new ArrayList<ValRef>(freeVars.size()); 
        for(ValRef ref : freeVars) {
            BoundVar var = (BoundVar)ref.getBinding();
            if(targets.contains(var)) {
                isRecursive = true;
                continue;
            }
            if(boundVars.add(var))
                boundVarsList.add(var);
            newFreeVars.add(ref);
        }
        BoundVar[] outVars = boundVarsList.toArray(new BoundVar[boundVarsList.size()]);
        freeVars = newFreeVars;
        
        // Modify functions: give each closure its own parameters (inVars)
        // mirroring outVars, and bind it to a fresh private constant.
        THashMap<SSAClosure, THashMap<BoundVar, BoundVar>> varMap = new THashMap<SSAClosure, THashMap<BoundVar, BoundVar>>();
        THashMap<SSAClosure, BoundVar[]> inVarsMap = new THashMap<SSAClosure, BoundVar[]>();
        THashMap<SSAClosure, BoundVar> oldTargets = new THashMap<SSAClosure, BoundVar>();
        for(SSAClosure closure = firstClosure; 
                closure != null; 
                closure = closure.getNext()) {
            THashMap<BoundVar, BoundVar> map = new THashMap<BoundVar, BoundVar>(outVars.length);
            BoundVar[] inVars = new BoundVar[outVars.length];            
            for(int i=0;i<inVars.length;++i) {
                inVars[i] = new BoundVar(outVars[i].getType());
                map.put(outVars[i], inVars[i]);
            }
            inVarsMap.put(closure, inVars);
            varMap.put(closure, map);
            
            closure.parametrize(inVars);
            SCLConstant functionConstant = new SCLConstant(context.createName(), closure.getType());
            context.addConstant(functionConstant);   
            oldTargets.put(closure, (BoundVar)closure.getTarget());
            closure.setTarget(functionConstant);
            // NOTE(review): a non-SSAFunction closure in a multi-closure group
            // throws ClassCastException here, before the recursive-value error
            // below can be logged — confirm such groups cannot reach this point.
            functionConstant.setDefinition((SSAFunction)closure);   
            functionConstant.setPrivate(true);
            // TODO handle type parameters
        }
        
        // Rewrite occurrences of the old targets: inside a closure of this
        // group use that closure's inVars, elsewhere use the outer outVars.
        for(SSAClosure closure = firstClosure; 
                closure != null; 
                closure = closure.getNext()) {
            BoundVar oldTarget = oldTargets.get(closure);
            for(ValRef ref : oldTarget.getOccurences()) {
                SSAFunction parent = ref.getParentFunction();
                BoundVar[] vars = inVarsMap.get(parent);
                if(vars == null)
                    vars = outVars;
                if(vars.length > 0)
                    ref.replaceByApply(closure.getTarget(), vars);
                else
                    ref.replaceBy(closure.getTarget());
            }
        }
            
        // Fix references: rebind free variables to the parameters of the
        // closure that contains them.
        for(ValRef ref : freeVars) {
            BoundVar inVar = (BoundVar)ref.getBinding();
            if(targets.contains(inVar))
                continue;
            BoundVar outVar = varMap.get(ref.getParentFunction()).get(inVar);
            ref.replaceBy(outVar);
        }
        
        detach();
        
        if(hasValues && isRecursive)
            context.getErrorLog().log(recursiveGroupLocation, "Variables defined recursively must all be functions.");
    }
    
    /** Simplifies every closure of the group. */
    @Override
    public void simplify(SSASimplificationContext context) {
        for(SSAClosure function = firstClosure; 
                function != null; 
                function = function.getNext())
            function.simplify(context);
    }
    
    /**
     * Sets the location used when reporting that a recursive group contains values.
     *
     * @param recursiveGroupLocation the location to report
     */
    public void setRecursiveGroupLocation(long recursiveGroupLocation) {
        this.recursiveGroupLocation = recursiveGroupLocation;
    }
    
    /** Visits the value references of every closure. */
    @Override
    public void forValRefs(ValRefVisitor visitor) {
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext())
            closure.forValRefs(visitor);    
    }

    /** Runs cleanup on every closure of the group. */
    @Override
    public void cleanup() {
        for(SSAClosure closure = firstClosure; closure != null; closure = closure.getNext())
            closure.cleanup();
    }
}
