package org.simantics.scl.compiler.internal.codegen.ssa;

import java.util.ArrayList;
import java.util.Arrays;

import org.cojen.classfile.TypeDesc;
import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.constants.NoRepConstant;
import org.simantics.scl.compiler.internal.codegen.continuations.Cont;
import org.simantics.scl.compiler.internal.codegen.continuations.ContRef;
import org.simantics.scl.compiler.internal.codegen.continuations.ReturnCont;
import org.simantics.scl.compiler.internal.codegen.references.BoundVar;
import org.simantics.scl.compiler.internal.codegen.references.Val;
import org.simantics.scl.compiler.internal.codegen.references.ValRef;
import org.simantics.scl.compiler.internal.codegen.ssa.exits.Jump;
import org.simantics.scl.compiler.internal.codegen.ssa.statements.LetApply;
import org.simantics.scl.compiler.internal.codegen.ssa.statements.LetFunctions;
import org.simantics.scl.compiler.internal.codegen.types.JavaTypeTranslator;
import org.simantics.scl.compiler.internal.codegen.utils.CopyContext;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.PrintingContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSALambdaLiftingContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSASimplificationContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSAValidationContext;
import org.simantics.scl.compiler.internal.codegen.utils.ValRefVisitor;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class SSAFunction extends SSAClosure {
    private static final Logger LOGGER = LoggerFactory.getLogger(SSAFunction.class);

    TVar[] typeParameters;
    Type effect;
    SSABlock firstBlock;
    SSABlock lastBlock;
    ReturnCont returnCont;
    
    public SSAFunction(TVar[] typeParameters, Type effect, Type returnType) {
        this(typeParameters, effect, new ReturnCont(returnType));
    }
        
    public SSAFunction(TVar[] typeParameters, Type effect, ReturnCont returnCont) {
        this.typeParameters = typeParameters;
        this.returnCont = returnCont;
        this.effect = Types.canonical(effect);
        returnCont.setParent(this);
    }
    
    public boolean hasEffect() {
        return effect != Types.NO_EFFECTS;
    }
    
    public void addBlock(SSABlock block) {
        block.parent = this;
        if(lastBlock == null) {
            firstBlock = lastBlock = block;
            block.next = block.prev = null;            
        }
        else {
            lastBlock.next = block;
            block.prev = lastBlock;
            block.next = null;
            lastBlock = block;
        }
    }
    
    public void addBlockInFront(SSABlock block) {
        block.parent = this;
        if(firstBlock == null) {
            firstBlock = lastBlock = block;
            block.next = block.prev = null;            
        }
        else {
            firstBlock.prev = block;
            block.next = firstBlock;
            block.prev = null;
            firstBlock = block;
        }
    }

    public ReturnCont getReturnCont() {
        return returnCont;
    }
    
    public TVar[] getTypeParameters() {
        return typeParameters;
    }
    
    public SSABlock getFirstBlock() {
        return firstBlock;
    }
    
    public void generateCode(MethodBuilder mb) {
        JavaTypeTranslator tt = mb.getJavaTypeTranslator();
        for(int i=0,j=0;i<firstBlock.parameters.length;++i)
            if(!tt.toTypeDesc(firstBlock.parameters[i].getType()).equals(TypeDesc.VOID))
                mb.setLocalVariable(firstBlock.parameters[i], mb.getParameter(j++));
        generateCodeWithAlreadyPreparedParameters(mb);
    }
    
    public void generateCodeWithAlreadyPreparedParameters(MethodBuilder mb) {
        for(SSABlock block = firstBlock; block != null; block = block.next)
            block.prepare(mb);
        firstBlock.generateCode(mb);
    }

    public void toString(PrintingContext context) {
        context.indentation();
        if(typeParameters.length > 0) {            
            context.append('<');
            boolean first = true;
            for(TVar var : typeParameters) {
                if(first)
                    first = false;
                else
                    context.append(',');
                context.append(var);
            }
            context.append("> ");
        }
        if(hasEffect()) {
            context.append(effect);
            context.append(" ");
        }
        context.append("RETURNS ");
        context.append(returnCont.getType());
        context.append('\n');
        
        // Print blocks
        context.pushBlockQueue();
        context.addBlock(getFirstBlock());
        while(true) {
            SSABlock block = context.pollBlock();
            if(block == null)
                break;
            block.toString(context);
        }
        context.popBlockQueue();
    }
    
    @Override
    public String toString() {
        PrintingContext context = new PrintingContext();
        toString(context);
        return context.toString();
    }
    
    public void validate(SSAValidationContext context) {
        if(target instanceof BoundVar && ((BoundVar)target).parent != this)
            throw new InternalCompilerError();
        
        // Add valid type variables
        for(TVar var : typeParameters)
            context.validTypeVariables.add(var);
        
        // Add valid variables and continuations
        context.validContinuations.add(returnCont);        
        for(SSABlock block = firstBlock; block != null; block = block.next) {
            context.validContinuations.add(block);  
            for(BoundVar parameter : block.parameters)
                context.validBoundVariables.add(parameter);
            for(SSAStatement stat = block.firstStatement; stat != null; stat = stat.next)
                stat.addBoundVariablesTo(context);
        }

        // Validate blocks
        for(SSABlock block = firstBlock; block != null; block = block.next)
            block.validate(context);
        context.validate(returnCont);
        
        //context.reset(); // FIXME not good when there are nested functions
    }
    
    @Override
    public void simplify(SSASimplificationContext context) {
        for(SSABlock block = firstBlock; block != null; block = block.next)
            block.simplify(context);
        if(firstBlock == lastBlock && firstBlock.firstStatement == firstBlock.lastStatement) {
            if(firstBlock.firstStatement instanceof LetApply)
                simplifySingleApply(context);
            else if(firstBlock.firstStatement instanceof LetFunctions)
                simplifySingleLambda(context);
        }
    }
    
    
    /**
     * Simplifies the following kind of function definition
     *     \x -> f x
     * to
     *     f
     */
    private void simplifySingleApply(SSASimplificationContext context) {
        if(!(parent instanceof LetFunctions) || parent.getFirstClosure().next != null)
            return;
        LetApply apply = (LetApply)firstBlock.firstStatement;
        if(!(firstBlock.exit instanceof Jump))
            return;
        Jump exit = (Jump)firstBlock.exit;
        if(exit.getTarget().getBinding() != returnCont)
            return;
        if(exit.getParameter(0).getBinding() != apply.getTarget())
            return;
        BoundVar[] functionParameters = getParameters();
        ValRef[] applyParameters = apply.getParameters();
        if(functionParameters.length > applyParameters.length)
            return;
        int extraApplyParameters = applyParameters.length - functionParameters.length;
        for(int i=0;i<functionParameters.length;++i)
            if(!representSameValues(functionParameters[i], applyParameters[extraApplyParameters+i]))
                return;
        for(int i=0;i<extraApplyParameters;++i) {
            Val b = applyParameters[i].getBinding();
            if(b instanceof BoundVar) {
                BoundVar bv = (BoundVar)b;
                if(bv == target || bv.getParent() == firstBlock)
                    return;
            }
        }
        if(!(target instanceof BoundVar))
            return;
        
        // Transform
        
        LetFunctions binder = (LetFunctions)parent;
        SSAFunction parentFunction = binder.getParentFunction();
        if(extraApplyParameters > 0) {
            //System.out.println("-------------------------------------------------------------");
            //System.out.println(parentFunction);
            //System.out.println("-------------------------------------------------------------");
            apply.setTarget((BoundVar)target);
            apply.setParameters(Arrays.copyOf(applyParameters, extraApplyParameters));
            apply.insertBefore(binder);
            binder.detach();
            //System.out.println(parentFunction);
            //System.out.println("-------------------------------------------------------------");
        }
        else {
            binder.detach();
            ((BoundVar)target).replaceBy(apply.getFunction());
        }
        context.markModified("SSAFunction.eta-reduce");
    }
    
    private boolean representSameValues(BoundVar boundVar, ValRef valRef) {
        Val val = valRef.getBinding(); 
        if(val == boundVar && valRef.getTypeParameters().length == 0)
            return true;
        if(val instanceof NoRepConstant && Types.equals(valRef.getType(), boundVar.getType()))
            return true;
        return false;
    }

    /**
     * Simplifies the following kind of function definition
     *     \x -> \y -> e
     * to
     *     \x y -> e
     */
    private void simplifySingleLambda(SSASimplificationContext context) {
        LetFunctions letF = (LetFunctions)firstBlock.firstStatement;
        if(!(letF.getFirstClosure() instanceof SSAFunction))
            return;
        SSAFunction f = (SSAFunction)letF.getFirstClosure();
        if(f.getNext() != null)
            return;
        Val fVal = f.getTarget();
        if(!firstBlock.exit.isJump(getReturnCont(), fVal))
            return;
        if(fVal.hasMoreThanOneOccurences())
            return; // Possible if function is recursive and refers to itself
        if(hasEffect())
            return; // Not probably possible (?)
                
        // Transform
        for(SSABlock block = f.firstBlock; block != null; block = block.next)
            block.parent = this;
        lastBlock.next = f.firstBlock;
        f.firstBlock.prev = lastBlock;
        lastBlock = f.lastBlock;
        
        firstBlock.firstStatement = firstBlock.lastStatement = null;
        setReturnCont(f.getReturnCont());
        effect = f.effect;
        BoundVar[] newParameters = BoundVar.copy(f.firstBlock.parameters);
        firstBlock.setParameters(BoundVar.concat(getParameters(), newParameters));
        firstBlock.setExit(new Jump(-1, f.firstBlock.createOccurrence(), ValRef.createOccurrences(newParameters)));
        context.markModified("SSAFunction.simplify-simple-lambda");
    }

    public void setReturnCont(ReturnCont returnCont) {
        this.returnCont = returnCont;
        returnCont.setParent(this);
    }

    public ValRef isEqualToConstant() {
        if(firstBlock.parameters.length > 0)
            return null;
        if(firstBlock != lastBlock)
            return null;
        if(firstBlock.firstStatement != null)
            return null;
        if(!(firstBlock.exit instanceof Jump))
            return null;
        Jump exit = (Jump)firstBlock.exit;
        if(exit.getTarget().getBinding() != returnCont)
            return null;
        return exit.getParameters()[0];
    }    

    public BoundVar[] getParameters() {
        return firstBlock.parameters;                
    }
    
    public Type[] getParameterTypes() {
        return Types.getTypes(firstBlock.parameters);                
    }

    public int getArity() {
        return firstBlock.parameters.length;
    }

    @Override
    public void markGenerateOnFly() {
        for(SSABlock block = firstBlock; block != null; block = block.next)
            block.markGenerateOnFly();
    }

    @Override
    public SSAClosure copy(CopyContext context) {
        TVar[] newTypeParameters = context.copyParameters(typeParameters);
        SSAFunction newFunction = new SSAFunction(newTypeParameters, effect, context.copy(returnCont));
        for(SSABlock block = firstBlock;
                block != null; block = block.next)
            newFunction.addBlock(context.copy(block));
        return newFunction;
    }
    
    @Override
    public void replace(TVar[] vars, Type[] replacements) {
        returnCont.replace(vars, replacements);
        for(SSABlock block = firstBlock;
                block != null; block = block.next)
            block.replace(vars, replacements);
    }

    public void setTypeParameters(TVar[] typeParameters) {
        this.typeParameters = typeParameters;
    }

    @Override
    public Type getType() {
        Type type = returnCont.getType();
        type = Types.functionE(
                Types.getTypes(firstBlock.parameters),
                effect,
                type);
        type = Types.forAll(typeParameters, type);
        return type;
    }

    public void mergeBlocks(SSAFunction function) {
        if(this == function)
            throw new InternalCompilerError();
        SSABlock block = function.firstBlock;
        while(block != null) {
            SSABlock next = block.next;
            addBlock(block);
            block = next;
        }
    }

    public Type getReturnType() {
        return returnCont.getType();
    }

    @Override
    public void destroy() {
        for(SSABlock block = firstBlock;
                block != null; block = block.next)
            block.destroy();
    }
    
    @Override    
    public void collectFreeVariables(ArrayList<ValRef> vars) {
        for(SSABlock block = firstBlock;
                block != null; block = block.next)
            block.collectFreeVariables(this, vars);
    }
    
    @Override    
    public void lambdaLift(SSALambdaLiftingContext context) {
        for(SSABlock block = firstBlock;
                block != null; block = block.next)
            block.lambdaLift(context);
    }

    @Override
    public void parametrize(BoundVar[] parameters) {
        Cont proxy = null;
        for(ContRef ref = firstBlock.getOccurrence(); ref != null; ref = ref.getNext())
            proxy = ref.addParametersInFront(parameters, firstBlock.parameters, proxy);
        
        firstBlock.parameters = BoundVar.concat(parameters, firstBlock.parameters);
        for(BoundVar parameter : parameters)
            parameter.parent = firstBlock;
    }
    
    public void apply(int lineNumber, ValRef[] parameters) {
        if(parameters.length == 0)
            return;
        if(firstBlock.hasNoOccurences()) {
            BoundVar[] vars = firstBlock.getParameters();
            for(int i=0;i<parameters.length;++i)
                vars[i].replaceBy(parameters[i]);
            firstBlock.setParameters(Arrays.copyOfRange(vars, parameters.length, vars.length));
        }
        else {
            BoundVar[] newVars = new BoundVar[getArity()-parameters.length];
            SSABlock block = new SSABlock(newVars);
            block.setExit(new Jump(lineNumber, firstBlock.createOccurrence(), 
                    ValRef.concat(ValRef.copy(parameters), ValRef.createOccurrences(newVars))));
            addBlockInFront(block);
        }
    }

    public void applyTypes(Type[] types) {
        if(types.length == 0)
            return;
        if(types.length == typeParameters.length) {
            replace(typeParameters, types);
            typeParameters = TVar.EMPTY_ARRAY;
        }
        else {
            replace(Arrays.copyOf(typeParameters, types.length), types);
            typeParameters = 
                    Arrays.copyOfRange(typeParameters, 
                            types.length, typeParameters.length);
        }            
    }
    
    public boolean isSimpleEnoughForInline() {
        return firstBlock == lastBlock && 
                (firstBlock.firstStatement == null || 
                (firstBlock.firstStatement == firstBlock.lastStatement
                && firstBlock.firstStatement instanceof LetApply));
    }

    public Type getEffect() {
        return effect;
    }

    @Override
    public boolean isValue() {
        return getArity() == 0;
    }

    @Override
    public void forValRefs(ValRefVisitor visitor) {
        for(SSABlock block = firstBlock;
                block != null; block = block.next)
            block.forValRefs(visitor);
    }

    @Override
    public void cleanup() {
        for(SSABlock block = firstBlock; block != null; block = block.next)
            block.cleanup();
    }
}
