package org.simantics.scl.compiler.internal.codegen.ssa.statements;

import java.util.ArrayList;
import java.util.Arrays;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.constants.Constant;
import org.simantics.scl.compiler.constants.SCLConstant;
import org.simantics.scl.compiler.internal.codegen.continuations.ContRef;
import org.simantics.scl.compiler.internal.codegen.references.BoundVar;
import org.simantics.scl.compiler.internal.codegen.references.Val;
import org.simantics.scl.compiler.internal.codegen.references.ValRef;
import org.simantics.scl.compiler.internal.codegen.ssa.SSABlock;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAExit;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAFunction;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAStatement;
import org.simantics.scl.compiler.internal.codegen.ssa.binders.BoundVarBinder;
import org.simantics.scl.compiler.internal.codegen.ssa.binders.ClosureBinder;
import org.simantics.scl.compiler.internal.codegen.ssa.binders.ValRefBinder;
import org.simantics.scl.compiler.internal.codegen.ssa.exits.Jump;
import org.simantics.scl.compiler.internal.codegen.ssa.exits.Switch;
import org.simantics.scl.compiler.internal.codegen.utils.CopyContext;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.PrintingContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSASimplificationContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSAValidationContext;
import org.simantics.scl.compiler.internal.codegen.utils.ValRefVisitor;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;
import org.simantics.scl.compiler.types.util.MultiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * SSA let-statement of the form {@code target = function p1 ... pn}: binds
 * {@code target} to the result of applying {@code function} to
 * {@code parameters}. The {@link #effect} field records the effect of the
 * application; {@link #hasEffect()} is false only when it is
 * {@link Types#NO_EFFECTS}.
 */
public class LetApply extends LetStatement implements ValRefBinder {
    private static final Logger LOGGER = LoggerFactory.getLogger(LetApply.class);
    
    private ValRef function;
    private ValRef[] parameters;
    /**
     * Effect of the application. Stored in canonical form because
     * {@link #hasEffect()} compares it against {@link Types#NO_EFFECTS}
     * by reference.
     */
    Type effect;
    
    /**
     * Creates an application statement and makes it the parent of
     * {@code function} and of every element of {@code parameters}.
     *
     * @param target     variable bound by this statement
     * @param effect     effect of the application (canonicalized before storing)
     * @param function   reference to the applied function
     * @param parameters references to the arguments, in application order
     */
    public LetApply(BoundVar target, Type effect, ValRef function, ValRef ... parameters) {
        super(target);
        if(SCLCompilerConfiguration.DEBUG) {
            if(effect == null)
                throw new InternalCompilerError();
            if(function.getBinding() == null)
                throw new InternalCompilerError();
        }
        this.setFunction(function);
        this.setParameters(parameters);
        this.effect = Types.canonical(effect);
    }

    /**
     * Emits code that computes the value of this application. {@link #generateCode}
     * stores the result into {@code target} right after calling this.
     * <p>
     * A {@link Constant} function generates its own application code, and the
     * result is unboxed to the target type when the returned type is boxed. Any
     * other function value is pushed, followed by the boxed arguments and a generic
     * apply of {@code parameters.length} arguments, and the result is unboxed to
     * the target type. The builder's previous line number is restored afterwards.
     */
    public void push(MethodBuilder mb) {
        int oldLineNumber = mb.lineNumber(lineNumber);
        Val f = getFunction().getBinding();
        Val[] ps = ValRef.getBindings(getParameters());
        if(f instanceof Constant) {
            Constant cf = (Constant)f;
            Type returnType = cf.apply(mb, getFunction().getTypeParameters(), ps);
            if(Types.isBoxed(returnType))
                mb.unbox(target.getType());
        }
        else {            
            mb.push(f, f.getType());
            mb.pushBoxed(ps);
            mb.genericApply(ps.length);
            mb.unbox(target.getType());
        }
        mb.lineNumber(oldLineNumber);
    }
    
    /**
     * Emits the application and stores its result into {@code target}, unless
     * the target is generated on the fly at its single use site.
     */
    @Override
    public void generateCode(MethodBuilder mb) {
        if(!target.generateOnFly) {
            mb.lineNumber(lineNumber);
            push(mb);
            mb.store(target);
        }
    }

    /**
     * Prints this statement. If it would be generated on the fly, it is
     * registered as an inline expression for {@code target} instead of being
     * printed as its own line.
     */
    @Override
    public void toString(PrintingContext context) {
        if(determineGenerateOnFly())
            context.addInlineExpression(target, this);
        else
            toStringAux(context);
    }
    
    /** Prints {@code target(occurrenceCount) = body} on its own indented line. */
    private void toStringAux(PrintingContext context) {
        context.indentation();
        context.append(target);
        context.append("(" + target.occurrenceCount() + ")");
        context.append(" = ");
        bodyToString(context);
        context.append('\n');
    }
    
    /**
     * Prints the right-hand side: an optional {@code "!> "} error marker, the
     * line number, the effect in angle brackets (only if effectful), the function
     * and the space-separated parameters.
     */
    public void bodyToString(PrintingContext context) {
        if(context.getErrorMarker() == this)
            context.append("!> ");
        context.append("L" + lineNumber + ": ");
        if(hasEffect()) {
            context.append("<");
            context.append(effect);
            context.append("> ");
        }
        context.append(getFunction());
        for(ValRef parameter : getParameters()) {
            context.append(' ');
            context.append(parameter);
        }
    }
    
    @Override
    public String toString() {
        PrintingContext context = new PrintingContext();
        toStringAux(context);
        return context.toString();
    }

    /**
     * Checks structural and type invariants: this statement is the parent of its
     * target, function and parameter references; the function type accepts
     * {@code parameters.length} arguments; argument and result types are
     * subsumed; and the stored effect equals the function's effect.
     *
     * @throws InternalCompilerError if any invariant is violated
     */
    @Override
    public void validate(SSAValidationContext context) {
        context.validate(target);
        if(target.getParent() != this)
            throw new InternalCompilerError();
        context.validate(function);
        if(function.getParent() != this)
            throw new InternalCompilerError();
        for(ValRef parameter : parameters) {
            context.validate(parameter);
            if(parameter.getParent() != this)
                throw new InternalCompilerError();
        }
        
        MultiFunction mFun;
        try {
            mFun = Types.matchFunction(getFunction().getType(), getParameters().length);
        } catch (MatchException e) {
            context.setErrorMarker(this);
            // Keep the match failure as the cause for easier diagnosis.
            throw new InternalCompilerError(e);
        }
        for(int i=0;i<getParameters().length;++i)
            context.assertSubsumes(this, getParameters()[i].getType(), mFun.parameterTypes[i]);
        context.assertSubsumes(this, target.getType(), mFun.returnType);
        context.assertEqualsEffect(this, effect, mFun.effect);
    }

    /** Removes the function and parameter references from their bindings. */
    @Override
    public void destroy() {
        getFunction().remove();
        for(ValRef parameter : getParameters())
            parameter.remove();
    }
    
    /**
     * Applies local rewrites to this application and reports each one to
     * {@code context}:
     * <ul>
     * <li>{@code dead-let-statement}: removes an unused, effect-free application;</li>
     * <li>replaces parameters bound to nullary constants whose definition just
     *     returns a value with that value;</li>
     * <li>{@code beta-lambda}: inlines a local function referenced only here;</li>
     * <li>{@code merge-applications}: merges with an effect-free application that
     *     produced the function;</li>
     * <li>{@code hoist-apply}: moves the application into the jumps of a block
     *     parameter's predecessors;</li>
     * <li>lets a {@link Constant} function inline itself.</li>
     * </ul>
     */
    @Override
    public void simplify(SSASimplificationContext context) {
        if(target.hasNoOccurences() && !hasEffect()) {
            remove();
            context.markModified("LetApply.dead-let-statement");
            return;
        }
        // TODO this is quite heavy way for inlining constants
        for(int i=0;i<parameters.length;++i) {
            ValRef parameter = parameters[i];
            Val value = parameter.getBinding();
            if(!(value instanceof SCLConstant))
                continue;
            SCLConstant constant = (SCLConstant)value;
            if(constant.inlineArity != 0)
                continue;
            SSAFunction definition = constant.definition;
            if(definition == null)
                continue; // defensive: nothing to inline without a definition
            SSABlock block = definition.getFirstBlock();
            // Only handle definitions whose first block is just "return <value>".
            if(block.getFirstStatement() != null || !(block.getExit() instanceof Jump))
                continue;
            Jump jump = (Jump)block.getExit();
            if(jump.getTarget().getBinding() != definition.getReturnCont())
                continue;
            if(jump.getParameter(0).getTypeParameters().length > 0)
                continue;
            parameter.replaceBy(jump.getParameter(0).getBinding());
        }
        Val functionVal = getFunction().getBinding();
        if(functionVal instanceof BoundVar) {
            BoundVarBinder parent_ = ((BoundVar)functionVal).parent;
            if(parent_ instanceof SSAFunction) {
                SSAFunction function = (SSAFunction)parent_;
                if(functionVal.hasMoreThanOneOccurences())
                    return;
                if(getParameters().length < function.getArity())
                    return;
                if(getParameters().length > function.getArity())
                    split(function.getArity());
                inline(function);
                function.detach();
                context.markModified("LetApply.beta-lambda");
            }
            else if(parent_ instanceof LetApply) {
                LetApply apply = (LetApply)parent_;
                if(apply.hasEffect())
                    return;
                boolean hasJustOneOccurence = !functionVal.hasMoreThanOneOccurences();
                if((hasJustOneOccurence && apply.getParent() == getParent()) ||
                        apply.isPartial()) {
                    if(hasJustOneOccurence) {
                        // The inner application is used only here: absorb it.
                        apply.detach();
                        setFunction(apply.getFunction());
                        setParameters(ValRef.concat(apply.getParameters(), getParameters()));
                    }
                    else {
                        // The inner application is shared: merge with copies of its references.
                        setFunction(apply.getFunction().copy());
                        setParameters(ValRef.concat(ValRef.copy(apply.getParameters()), getParameters()));
                    }
                    context.markModified("LetApply.merge-applications");
                }
            }
            else if(parent_ instanceof SSABlock) {
                SSABlock parent = getParent();
                if(parent_ != parent)
                    return;
                if(parent.getFirstStatement() != this)
                    return;
                if(!parent.hasMoreThanOneOccurences())
                    return; // We stop here, because situation can be handled by better transformations
                if(functionVal.hasMoreThanOneOccurences())
                    return;
                // We have now the following situation:
                //    [c] ... f ... =
                //        x = f ... 
                // * this application is the only reference to f
                // * there are multiple references to [c]
                for(ContRef ref = parent.getOccurrence();ref != null; ref = ref.getNext())
                    if(!(ref.getParent() instanceof Jump))
                        return;
                
                // Finds the position of the functionVal in the parameter list of 
                // the parent block.
                int position;
                for(position=0;position<parent.getParameters().length;++position)
                    if(parent.getParameters()[position] == functionVal)
                        break;
                if(position == parent.getParameters().length)
                    throw new InternalCompilerError();
                
                // Do the transformation: apply the function in each predecessor
                // and pass the result to the block instead of the function.
                for(ContRef ref = parent.getOccurrence();ref != null; ref = ref.getNext()) {
                    Jump jump = (Jump)ref.getParent();
                    SSABlock block = jump.getParent();
                    
                    BoundVar newTarget = new BoundVar(target.getType());
                    LetApply hoisted = new LetApply(newTarget, effect, jump.getParameter(position), ValRef.copy(parameters));
                    // Keep the source position of the original application.
                    hoisted.lineNumber = lineNumber;
                    block.addStatement(hoisted);
                    jump.setParameter(position, newTarget.createOccurrence());
                }
                
                parent.setParameter(position, target);
                remove();
                context.markModified("LetApply.hoist-apply");
            }
        }
        else if(functionVal instanceof Constant) {
            ((Constant)functionVal).inline(context, this);
        }
    }
    
    /**
     * True if fewer arguments are supplied than the function's effective arity.
     */
    public boolean isPartial() {
        return parameters.length < function.getBinding().getEffectiveArity();
    }

    /**
     * Removes apply if it does not have parameters. Occurrences of the target are
     * replaced by the function itself.
     */
    public void removeDegenerated() {
        if(getParameters().length == 0) {
            target.replaceBy(getFunction());
            getFunction().remove();
            detach();
        }
    }

    /**
     * Decides whether the application can be generated at its use site instead
     * of being stored in a variable. This requires that the application is
     * effect-free, that the target has exactly one occurrence, and that this
     * occurrence is in a statement or a non-{@link Switch} exit of the same block.
     */
    public boolean determineGenerateOnFly() {
        if(hasEffect())
            return false;
        ValRef ref = target.getOccurrence();
        if(ref == null || ref.getNext() != null)
            return false;
        Object parent = ref.getParent();
        if(parent instanceof SSAStatement) {
            if(((SSAStatement)parent).getParent() != getParent())
                return false;
        }
        else if(parent instanceof SSAExit) {
            if(((SSAExit)parent).getParent() != getParent())
                return false;
            if(parent instanceof Switch)
                return false;
        }
        else
            return false;
        return true;
    }
    
    @Override
    public void markGenerateOnFly() {        
        target.generateOnFly = determineGenerateOnFly();
    }

    public ValRef getFunction() {
        return function;
    }

    /** Sets the function reference and makes this statement its parent. */
    public void setFunction(ValRef function) {
        this.function = function;
        function.setParent(this);
    }

    public ValRef[] getParameters() {
        return parameters;
    }

    /** Sets the parameter references and makes this statement their parent. */
    public void setParameters(ValRef[] parameters) {
        this.parameters = parameters;
        for(ValRef parameter : parameters)
            parameter.setParent(this);
    }
    
    /**
     * Copies this statement through {@code context}, including its line number.
     */
    @Override
    public SSAStatement copy(CopyContext context) {
        LetApply result = new LetApply(context.copy(target), 
                context.copyType(effect),
                context.copy(function), 
                context.copy(parameters));
        // Preserve the source position so copied code keeps its line info.
        result.lineNumber = lineNumber;
        return result;
    }
    
    /**
     * Substitutes the type variables {@code vars} by {@code replacements} in the
     * target, the function, the effect and the parameters.
     */
    @Override
    public void replace(TVar[] vars, Type[] replacements) {
        target.replace(vars, replacements);
        function.replace(vars, replacements);
        // Keep the effect canonical, as hasEffect() compares by reference.
        effect = Types.canonical(effect.replace(vars, replacements));
        for(ValRef parameter : parameters)
            parameter.replace(vars, replacements);
    }
    
    /**
     * Inlines the application by the given function.
     * It is assumed that the function has the same number
     * of parameters as this one and there are no other 
     * references to the function (copy function before
     * inlining if necessary).
     * <p>
     * Does nothing if the function encloses the function containing this
     * statement. Otherwise the current block is cut at this statement. The
     * statements after it move to a new tail block whose parameter is
     * {@code target}. The function's blocks are merged into the enclosing
     * function, the head block jumps into the function's first block with the
     * arguments, and the function's return continuation is redirected to the
     * tail block.
     *
     * @throws InternalCompilerError if the arities differ
     */
    public void inline(SSAFunction function) {
        if(function.getArity() != parameters.length)
            throw new InternalCompilerError();        
               
        SSABlock headBlock = getParent();
        SSAFunction thisFunction = headBlock.getParent();
        // Refuse to inline a function into itself or into one of its nested closures.
        {
            SSAFunction curParent=thisFunction;
            while(true) {
                if(curParent == function)
                    return;
                ClosureBinder binder = curParent.getParent();
                if(binder == null)
                    break;
                curParent = binder.getParentFunction();
            }
        }
        
        // Instantiate the function's type parameters with those of this application.
        if(this.function.getTypeParameters().length > 0)
            function.replace(function.getTypeParameters(), this.function.getTypeParameters());

        // Cut the head block just before this statement.
        if(getPrev() != null)
            getPrev().setAsLastStatement();
        else
            headBlock.removeStatements();
        
        // Create tail block
        SSABlock tailBlock = new SSABlock(new BoundVar[] {target});
        thisFunction.addBlock(tailBlock);
        {
            SSAStatement stat = getNext();
            while(stat != null) {
                SSAStatement temp = stat.getNext();
                tailBlock.addStatement(stat);
                stat = temp;
            }
        }
        tailBlock.setExit(headBlock.getExit());
        
        // Merge blocks        
        thisFunction.mergeBlocks(function);
        
        headBlock.setExit(new Jump(lineNumber, function.getFirstBlock().createOccurrence(), parameters));
        function.getReturnCont().replaceWith(tailBlock);

        this.function.remove();
        // Note: No need to remove or detach this statement anymore
        
        // TODO remove function
    }

    @Override
    public void collectFreeVariables(SSAFunction parentFunction,
            ArrayList<ValRef> vars) {
        function.collectFreeVariables(parentFunction, vars);
        for(ValRef parameter : parameters)
            parameter.collectFreeVariables(parentFunction, vars);
    }
    
    /**
     * If {@code valRef} is this statement's function reference, replaces it by
     * {@code newFunction} (with {@code typeParameters}) and prepends
     * {@code parameters} to the existing arguments. Otherwise delegates to the
     * superclass.
     */
    @Override
    public void replaceByApply(ValRef valRef, Val newFunction, Type[] typeParameters, Val[] parameters) {
        if(function == valRef) {
            valRef.remove();
            setFunction(newFunction.createOccurrence(typeParameters));
            setParameters(ValRef.concat(ValRef.createOccurrences(parameters), this.parameters));
        }
        else
            super.replaceByApply(valRef, newFunction, typeParameters, parameters);
    }

    /**
     * Splits this application into two applications where the first has
     * the arity given as a parameter and the new application inserted 
     * after this statement has the rest of the parameters.
     * <p>
     * The new application takes over {@code target} and the effect. This
     * statement binds a fresh intermediate variable and is marked effect-free.
     *
     * @throws InternalCompilerError if {@code arity} exceeds the number of
     *         parameters or the function type does not accept {@code arity} arguments
     */
    public void split(int arity) {
        if(arity == parameters.length)
            return;
        if(arity > parameters.length)
            throw new InternalCompilerError();
        ValRef[] firstHalf = arity == 0 ? ValRef.EMPTY_ARRAY : Arrays.copyOf(parameters, arity);
        // arity < parameters.length here, so the second half is never empty.
        ValRef[] secondHalf = Arrays.copyOfRange(parameters, arity, parameters.length);
        BoundVar newVar;
        try {
            MultiFunction mfun = Types.matchFunction(function.getType(), arity);
            newVar = new BoundVar(mfun.returnType);
        } catch (MatchException e) {
            throw new InternalCompilerError(e);
        }
        LetApply newApply = new LetApply(target, effect, 
                newVar.createOccurrence(), secondHalf);
        // The second half originates from the same source position.
        newApply.lineNumber = lineNumber;
        newApply.insertAfter(this);
        effect = Types.NO_EFFECTS;
        setTarget(newVar);
        setParameters(firstHalf);
    }
    
    /**
     * True, if the application may have side effects.
     */
    public boolean hasEffect() {
        return effect != Types.NO_EFFECTS;
    }

    /**
     * Recomputes {@link #effect} from the function type and the current number
     * of parameters.
     *
     * @throws InternalCompilerError if the function type does not accept that many arguments
     */
    public void updateEffect() {
        try {
            MultiFunction mFun = Types.matchFunction(function.getType(), parameters.length);
            // Canonicalize like the constructor does; hasEffect() compares by reference.
            this.effect = Types.canonical(mFun.effect);
        } catch (MatchException e) {
            throw new InternalCompilerError(e);
        }
    }
    
    @Override
    public void prepare(MethodBuilder mb) {
        function.getBinding().prepare(mb);
    }

    /** Visits the function reference and then each parameter reference. */
    @Override
    public void forValRefs(ValRefVisitor visitor) {
        visitor.visit(function);
        for(ValRef parameter : parameters)
            visitor.visit(parameter);
    }

    @Override
    public void cleanup() {
        function.remove();
        for(ValRef parameter : parameters)
            parameter.remove();
    }
}
