package org.simantics.scl.compiler.internal.codegen.ssa;

import java.util.ArrayList;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.constants.Constant;
import org.simantics.scl.compiler.constants.NoRepConstant;
import org.simantics.scl.compiler.constants.SCLConstant;
import org.simantics.scl.compiler.internal.codegen.continuations.BranchRef;
import org.simantics.scl.compiler.internal.codegen.continuations.Cont;
import org.simantics.scl.compiler.internal.codegen.continuations.ContRef;
import org.simantics.scl.compiler.internal.codegen.references.BoundVar;
import org.simantics.scl.compiler.internal.codegen.references.Val;
import org.simantics.scl.compiler.internal.codegen.references.ValRef;
import org.simantics.scl.compiler.internal.codegen.ssa.binders.BoundVarBinder;
import org.simantics.scl.compiler.internal.codegen.ssa.exits.Jump;
import org.simantics.scl.compiler.internal.codegen.ssa.exits.Switch;
import org.simantics.scl.compiler.internal.codegen.ssa.statements.LetApply;
import org.simantics.scl.compiler.internal.codegen.utils.CopyContext;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.Printable;
import org.simantics.scl.compiler.internal.codegen.utils.PrintingContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSALambdaLiftingContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSASimplificationContext;
import org.simantics.scl.compiler.internal.codegen.utils.SSAUtils;
import org.simantics.scl.compiler.internal.codegen.utils.SSAValidationContext;
import org.simantics.scl.compiler.internal.codegen.utils.ValRefVisitor;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;

/**
 * A basic block of an {@link SSAFunction}.
 * <p>
 * A block is a continuation ({@link Cont}) that has typed parameters, a
 * doubly-linked list of {@link SSAStatement}s and a single terminating
 * {@link SSAExit}. The blocks of one function are chained through the
 * {@code prev}/{@code next} fields, and the function stores the ends of that
 * chain in {@code firstBlock}/{@code lastBlock}. The first block of the chain
 * is the function's entry block.
 */
public final class SSABlock extends Cont implements Printable, BoundVarBinder {
    /** Shared zero-length array of blocks. */
    public static final SSABlock[] EMPTY_ARRAY = new SSABlock[0];
    
    /** Block parameters; every element has this block as its {@code parent}. */
    BoundVar[] parameters;
    
    /** The function that owns this block. */
    SSAFunction parent;
    /** Previous block in the parent's block list, or null if this is the first block. */
    SSABlock prev;
    /** Next block in the parent's block list, or null if this is the last block. */
    SSABlock next;
    /** Head of the statement list (linked through SSAStatement.prev/next), or null if empty. */
    SSAStatement firstStatement;
    /** Tail of the statement list, or null if empty. */
    SSAStatement lastStatement;
    /** The instruction that terminates this block. */
    SSAExit exit;
    
    /**
     * Creates a block with fresh parameter variables of the given types.
     *
     * @param parameterTypes types of the block parameters, in order
     */
    public SSABlock(Type ... parameterTypes) {
        parameters = new BoundVar[parameterTypes.length];
        for(int i=0;i<parameterTypes.length;++i) {
            BoundVar parameter = new BoundVar(parameterTypes[i]);
            parameters[i] = parameter;
            parameter.parent = this;
        }
    }
    
    /**
     * Creates a block that takes ownership of the given parameter variables.
     * The array is stored as is (not copied), and the {@code parent} of each
     * variable is set to the new block.
     *
     * @param parameters the block parameters, in order
     */
    public SSABlock(BoundVar[] parameters) {
        this.parameters = parameters;
        for(BoundVar parameter : parameters)
            parameter.parent = this;
    }

    /** @return the function that owns this block */
    public SSAFunction getParent() {
        return parent;
    }
    
    /** @return the next block of the parent function, or null if this is the last one */
    public SSABlock getNext() {
        return next;
    }
    
    /** @return the previous block of the parent function, or null if this is the first one */
    public SSABlock getPrev() {
        return prev;
    }
    
    /** @return the terminating instruction of this block */
    public SSAExit getExit() {
        return exit;
    }
    
    /**
     * Forgets all statements of this block. Only the list ends are cleared;
     * the statements themselves are neither modified nor destroyed.
     */
    public void removeStatements() {
        this.firstStatement = null;
        this.lastStatement = null;
    }
    
    /** @return the number of block parameters */
    @Override
    public int getArity() {
        return parameters.length;
    }
    
    /** @return the number of statements, computed by walking the statement list */
    public int getStatementCount() {
        int count = 0;
        for(SSAStatement stat = firstStatement;stat != null;stat = stat.next)
            ++count;
        return count;
    }
    
    /** @return the type of the parameter at index {@code parameterId} */
    @Override
    public Type getParameterType(int parameterId) {
        return parameters[parameterId].getType();
    }
    
    /** @return the types of all block parameters, in order */
    public Type[] getParameterTypes() {
        return Types.getTypes(parameters);
    }
    
    /** @return the internal parameter array (not a copy) */
    public BoundVar[] getParameters() {
        return parameters;
    }
    
    /**
     * Installs {@code exit} as the terminator of this block and makes this
     * block its parent.
     */
    public void setExit(SSAExit exit) {
        this.exit = exit;
        exit.parent = this;
    }

    /**
     * Unlinks this block from the parent function's block list.
     * <p>
     * This block's own {@code prev}/{@code next} fields are left untouched.
     * Presumably this lets a caller that is iterating over the block list
     * still advance from a detached block (TODO confirm).
     *
     * @throws InternalCompilerError if the function is left without blocks
     */
    public void detach() {
        if(prev == null)
            parent.firstBlock = next;
        else
            prev.next = next;
        if(next == null)
            parent.lastBlock = prev;
        else
            next.prev = prev;
        if(parent.firstBlock == null)
            throw new InternalCompilerError();
    }

    /** Detaches this block from its function and destroys its contents. */
    public void remove() {
        detach();
        destroy();
    }
    
    /** Destroys every statement of this block and then its exit. */
    public void destroy() {
        for(SSAStatement statement = firstStatement; statement != null; statement = statement.next)
            statement.destroy();
        exit.destroy();
    }

    /**
     * Appends {@code stat} to the end of the statement list and makes this
     * block its parent.
     */
    public void addStatement(SSAStatement stat) {
        if(SCLCompilerConfiguration.DEBUG) {
            if((firstStatement == null) != (lastStatement == null))
                throw new InternalCompilerError();
        }
        stat.parent = this;
        if(lastStatement == null) {
            firstStatement = lastStatement = stat;
            stat.next = stat.prev = null;
        }
        else {
            lastStatement.next = stat;
            stat.prev = lastStatement;
            stat.next = null;
            lastStatement = stat;
        }
    }
    
    /** Emits bytecode for this block: its location, then statements in order, then the exit. */
    public void generateCode(MethodBuilder mb) {
        mb.setLocation(this);
        for(SSAStatement stat = firstStatement; stat != null; stat = stat.next)
            stat.generateCode(mb);
        exit.generateCode(mb);
    }

    /**
     * Appends a textual form of this block. The header shows the block,
     * its occurrence count and its parameters; the indented body follows.
     */
    public void toString(PrintingContext context) {
        context.indentation();
        context.append(this);
        context.append("(" + occurrenceCount() + ")");
        parametersToString(context);
        context.append(" =\n");
        bodyToString(context);
    }
    
    /**
     * Appends each parameter as {@code " (name :: Type)"}. A parameter with
     * no occurrences is printed as {@code _} instead of its name.
     */
    public void parametersToString(PrintingContext context) {
        for(BoundVar parameter : parameters) {
            context.append(' ');
            context.append('(');
            if(parameter.hasNoOccurences())
                context.append('_');
            else
                context.append(parameter);
            context.append(" :: ");
            context.append(parameter.getType());
            context.append(')');
        }
    }
    
    /** Appends the statements and the exit, indented one level. */
    public void bodyToString(PrintingContext context) {
        context.indent();
        for(SSAStatement statement = firstStatement; statement != null; statement = statement.next)
            statement.toString(context);
        context.indentation();
        exit.toString(context);
        context.dedent();
    }

    /**
     * Checks the structural invariants of this block, then validates its
     * parameters, statements and exit.
     *
     * @throws InternalCompilerError if a parent link is wrong or
     *         {@code lastStatement} is not the tail of the statement list
     */
    public void validate(SSAValidationContext context) {
        if(exit.getParent() != this)
            throw new InternalCompilerError();
        for(BoundVar parameter : parameters) {
            context.validate(parameter);
            if(parameter.parent != this)
                throw new InternalCompilerError();
        }
        for(SSAStatement statement = firstStatement; statement != null; statement = statement.next) {
            if(statement.getParent() != this)
                throw new InternalCompilerError();
            statement.validate(context);
        }
        exit.validate(context);
        checkLastStatement();
    }

    /**
     * Verifies that {@code lastStatement} is the tail reached by walking the
     * list from {@code firstStatement}, which is null for an empty list.
     */
    private void checkLastStatement() {
        SSAStatement last = firstStatement;
        if(last != null) {
            while(last.next != null)
                last = last.next;
        }
        if(last != lastStatement)
            throw new InternalCompilerError();
    }

    /**
     * Runs one round of local simplifications on this block and reports each
     * applied rewrite through {@code context.markModified}.
     * <p>
     * An unreferenced block that is not the entry block is removed. Otherwise
     * the method first improves the parameters, then simplifies the
     * statements and the exit. Next it applies switch simplification, which
     * may turn the exit into a Jump. Finally it applies the Jump rewrites:
     * eta-reduction, tail self-call or inlining.
     */
    public void simplify(SSASimplificationContext context) {
        if(hasNoOccurences() && parent.firstBlock != this) {
            remove();
            context.markModified("dead-block");
            return;
        }
        
        tryToImproveParameters(context);
        
        // Simplify statements and exit
        for(SSAStatement statement = firstStatement; statement != null; statement = statement.next)
            statement.simplify(context);
        exit.simplify(context);
        
        // Simplifications to this block
        if(exit instanceof Switch) {
            if(simplifySwitch()) {
                context.markModified("beta-switch");
            }
        }
        if(exit instanceof Jump) {
            if(firstStatement == null && parent.firstBlock != this) {
                if(etaBlock(context)) {
                    context.markModified("eta-block");
                    return;
                }
                else if(inlineJump()) {
                    context.markModified("beta-block");
                    return;
                }
            }
            else {
                if(optimizeTailSelfCall()) {
                    context.markModified("simplify-tail-call");
                    return;
                }
                else if(inlineJump()) {
                    context.markModified("beta-block");
                    return;
                }
            }
        }
    }

    /**
     * Tries to remove redundant parameters.
     * <p>
     * This only applies to a block that is not the entry block, has
     * parameters, and is referenced only from Jumps. Under those conditions
     * the matching argument can be dropped from every jump.
     */
    private void tryToImproveParameters(SSASimplificationContext context) {
        if(parent.firstBlock == this)
            return;
        if(parameters.length == 0)
            return;
        for(ContRef ref = getOccurrence(); ref != null; ref = ref.getNext())
            if(!(ref.getParent() instanceof Jump))
                return;
        boolean modified = false;
        for(int i=0;i<parameters.length;++i)
            if(tryToImproveParameter(i)) {
                // The remaining parameters shifted left; retry the same index.
                --i;
                modified = true;
            }
        if(modified)
            context.markModified("improve-parameters");
    }

    /**
     * Returns the single value that inhabits {@code type}, or null if the type
     * is not known to have exactly one value. Only UNIT and PUNIT are
     * recognized.
     */
    private static Constant getOnlyPossibleValue(Type type) {
        type = Types.canonical(type);
        if(type == Types.UNIT)
            return NoRepConstant.UNIT;
        else if(type == Types.PUNIT)
            return NoRepConstant.PUNIT;
        return null;
    }
    
    /**
     * Tries to eliminate the parameter at {@code position}.
     * <p>
     * The parameter can be eliminated in two cases. Either its type has only
     * one possible value, or every jump to this block passes the same value
     * (ignoring jumps that pass the parameter back to itself). On success, all
     * uses of the parameter are replaced, the argument is removed from every
     * jump, and the parameter is removed from this block.
     * <p>
     * Assumes every occurrence of this block is a {@link Jump}.
     *
     * @return true if the parameter was eliminated
     */
    private boolean tryToImproveParameter(int position) {
        BoundVar parameter = parameters[position];
        Constant onlyPossibleValue = getOnlyPossibleValue(parameter.getType());
        if(onlyPossibleValue == null) {
            Val constant = null;
            ValRef constantRef = null;
            for(ContRef ref = getOccurrence(); ref != null; ref = ref.getNext()) {
                Jump jump = (Jump)ref.getParent();
                ValRef valRef = jump.getParameters()[position];
                Val val = valRef.getBinding();
                if(val == parameter)
                    continue;
                if(constant == null) {
                    constant = val;
                    constantRef = valRef;
                    continue;
                }
                if(val != constant)
                    return false;
            }
            if(constant == null)
                return false; // Every jump passes the parameter back to itself: there is no value to substitute.
            parameter.replaceBy(constantRef);
        }
        else {
            parameter.replaceBy(onlyPossibleValue);
        }
        
        for(ContRef ref = getOccurrence(); ref != null; ref = ref.getNext()) {
            Jump jump = (Jump)ref.getParent();
            jump.setParameters(removeAt(jump.getParameters(), position));
        }
        
        parameters = removeAt(parameters, position);
        return true;
    }
    
    /** Returns a copy of {@code vars} without the element at {@code pos}. */
    private static BoundVar[] removeAt(BoundVar[] vars, int pos) {
        BoundVar[] result = new BoundVar[vars.length-1];
        System.arraycopy(vars, 0, result, 0, pos);
        System.arraycopy(vars, pos+1, result, pos, vars.length-pos-1);
        return result;
    }

    /**
     * Returns a copy of {@code vars} without the element at {@code pos}. The
     * dropped reference is also removed by calling {@link ValRef#remove()}.
     */
    private static ValRef[] removeAt(ValRef[] vars, int pos) {
        ValRef[] result = new ValRef[vars.length-1];
        System.arraycopy(vars, 0, result, 0, pos);
        System.arraycopy(vars, pos+1, result, pos, vars.length-pos-1);
        vars[pos].remove();
        return result;
    }
    
    /**
     * Replaces a returned self-call in tail position with a jump to the
     * entry block of the function.
     * <p>
     * Assumes that the exit of this block is a {@link Jump}.
     *
     * @return true if the rewrite was applied
     */
    private boolean optimizeTailSelfCall() {
        // The last statement of the block must be a LetApply that calls the
        // parent function with as many arguments as the entry block has parameters.
        if(!(lastStatement instanceof LetApply))
            return false;
        LetApply apply = (LetApply)lastStatement;
        Val function = apply.getFunction().getBinding();
        if(function != parent.target)
            return false;
        SSABlock initialBlock = parent.firstBlock;
        if(initialBlock.parameters.length != apply.getParameters().length)
            return false;

        // The result of the call must reach the return continuation. This can
        // happen directly, or through an empty block that only forwards a single
        // value to the return continuation.
        Jump jump = (Jump)exit;
        Cont targetCont = jump.getTarget().getBinding();
        if(targetCont != parent.returnCont) {
            if(!(targetCont instanceof SSABlock))
                return false;
            SSABlock targetBlock = (SSABlock)targetCont;
            if(targetBlock.firstStatement != null)
                return false;
            if(!(targetBlock.exit instanceof Jump))
                return false;
            Jump targetJump = (Jump)targetBlock.exit;
            if(targetJump.getTarget().getBinding() != parent.returnCont)
                return false;
            if(targetJump.getParameters().length != 1)
                return false;
            
            // Either the forwarded value is the call result itself, or it is a
            // parameter of the target block that this jump fills with the call result.
            BoundVar applyTarget = apply.getTarget();
            ValRef targetJumpParameter = targetJump.getParameter(0);
            isSameParam: if(!SSAUtils.representSameValue(applyTarget, targetJumpParameter)) {
                BoundVar[] targetBlockParameters = targetBlock.getParameters();
                for(int i=0;i<targetBlockParameters.length;++i) {
                    if(targetJumpParameter.getBinding() == targetBlockParameters[i]
                            && jump.getParameter(i).getBinding() == applyTarget)
                        break isSameParam;
                }
                return false;
            }
        }
        else {
            if(jump.getParameters().length != 1)
                return false;
            if(!SSAUtils.representSameValue(apply.getTarget(), jump.getParameter(0)))
                return false;
        }
        
        // Do modifications: drop the call and turn the exit into a jump to the
        // entry block, reusing the call's argument references.
        apply.detach();
        apply.getFunction().remove();
        jump.getTarget().remove();
        jump.setTarget(initialBlock.createOccurrence());
        for(ValRef parameter : jump.getParameters())
            parameter.remove();
        jump.setParameters(apply.getParameters());
        
        return true;
    }

    /**
     * Eta-reduces a block that only forwards its parameters, in the same
     * order, to another continuation. All references to this block are
     * redirected to that continuation, and this block is removed.
     * <p>
     * Assumes that this block has no statements, is not the first block, and
     * that its exit is a {@link Jump}.
     *
     * @return true if the block was eliminated
     */
    private boolean etaBlock(SSASimplificationContext context) {
        Jump jump = (Jump)exit;
        if(parameters.length != jump.getParameters().length)
            return false;
        for(int i=0;i<parameters.length;++i)
            if(parameters[i] != jump.getParameters()[i].getBinding() ||
               parameters[i].hasMoreThanOneOccurences())
                return false;
        
        replaceWith(jump.getTarget().getBinding());
        remove();
        return true;
    }
    
    /**
     * Replaces the {@link Switch} exit with a direct {@link Jump} when the
     * matching branch is statically known. There are two cases:
     * <ul>
     * <li>the scrutinee is the result of applying a non-SCL constructor
     *     constant; the matching branch is then entered with the
     *     constructor's arguments;</li>
     * <li>the scrutinee is a constant and the switch has a single branch for
     *     exactly that constant.</li>
     * </ul>
     * Assumes that the exit of this block is a {@link Switch}.
     *
     * @return true if the exit was replaced
     */
    private boolean simplifySwitch() {
        Switch sw = (Switch)exit;
        ValRef scrutineeRef = sw.getScrutinee();
        Val scrutinee = scrutineeRef.getBinding();
        if(scrutinee instanceof BoundVar) {
            BoundVarBinder binder = ((BoundVar)scrutinee).parent;
            if(!(binder instanceof LetApply))
                return false;
            LetApply apply = (LetApply)binder;
            Val function = apply.getFunction().getBinding();
            if(!(function instanceof Constant) || function instanceof SCLConstant)
                return false;
            for(BranchRef branch : sw.getBranches()) {
                if(branch.constructor == function) {
                    sw.destroy();
                    setExit(new Jump(sw.lineNumber, branch.cont.getBinding().createOccurrence(), 
                            ValRef.copy(apply.getParameters())));
                    return true;
                }
            }
        }
        else if(scrutinee instanceof Constant) {
            if(sw.getBranches().length == 1) {
                BranchRef branch = sw.getBranches()[0];
                if(branch.constructor == scrutinee) {
                    /*
                     * Optimizes for example
                     *      switch ()
                     *          () -> [a]
                     * ===>
                     *      [a]
                     */
                    sw.destroy();
                    setExit(new Jump(sw.lineNumber, branch.cont.getBinding().createOccurrence()));
                    // The exit was rewritten, so the caller must record the modification.
                    return true;
                }
            }
        }
        return false;
    }
    
    /**
     * Inlines the target of this block's {@link Jump} when that target is an
     * {@link SSABlock} referenced only once and is neither the entry block
     * nor this block. The target's statements are appended here, its
     * parameters are replaced by the jump arguments, and its exit becomes the
     * exit of this block.
     *
     * @return true if the target was inlined
     */
    private boolean inlineJump() {
        Jump jump = (Jump)exit;
        Cont target = jump.getTarget().getBinding();
        if(!(target instanceof SSABlock))
            return false;
        if(target.hasMoreThanOneOccurences())
            return false;
        SSABlock block = (SSABlock)target;
        if(block == parent.firstBlock || block == this)
            return false;
        
        mergeStatements(block);

        for(int i=0;i<jump.getParameters().length;++i)
            block.parameters[i].replaceBy(jump.getParameters()[i]);
        block.detach();
        
        jump.destroy();
        setExit(block.exit);
        return true;
    }
    
    /** Moves all statements of {@code block}, in order, to the end of this block. */
    private void mergeStatements(SSABlock block) {
        if(SCLCompilerConfiguration.DEBUG)
            checkLastStatement();
        SSAStatement stat = block.firstStatement;
        while(stat != null) {
            // addStatement overwrites stat.next, so read it first.
            SSAStatement next = stat.next;
            addStatement(stat);
            stat = next;
        }
    }

    @Override
    public String toString() {
        PrintingContext context = new PrintingContext();
        toString(context);
        return context.toString();
    }

    /** Calls {@code markGenerateOnFly} on every statement of this block. */
    public void markGenerateOnFly() {
        for(SSAStatement stat = firstStatement; stat != null; stat = stat.next)
            stat.markGenerateOnFly();
    }
    
    /**
     * Creates a copy of this block, including its parameters, statements and
     * exit, using {@code context} to map copied entities.
     * <p>
     * The mapping from this block to the copy is registered before the body is
     * copied. Presumably this lets references to this block inside its own
     * body resolve to the copy (TODO confirm).
     */
    public SSABlock copy(CopyContext context) {
        SSABlock newBlock = new SSABlock(context.copy(parameters));
        context.put(this, newBlock);
        for(SSAStatement statement = firstStatement;
                statement != null; statement = statement.next)
            newBlock.addStatement(statement.copy(context));
        newBlock.setExit(exit.copy(context));
        return newBlock;
    }

    /** Replaces the parameter array and makes this block the parent of every element. */
    public void setParameters(BoundVar[] parameters) {
        for(BoundVar parameter : parameters)
            parameter.parent = this;
        this.parameters = parameters;
    }

    /** Substitutes {@code replacements} for the type variables {@code vars} throughout this block. */
    @Override
    public void replace(TVar[] vars, Type[] replacements) {
        for(BoundVar parameter : parameters)
            parameter.replace(vars, replacements);
        for(SSAStatement statement = firstStatement;
                statement != null; statement = statement.next)
            statement.replace(vars, replacements);
        exit.replace(vars, replacements);
    }

    /** Delegates free-variable collection to every statement and then to the exit. */
    public void collectFreeVariables(SSAFunction function, ArrayList<ValRef> vars) {
        for(SSAStatement statement = firstStatement;
                statement != null; statement = statement.next)
            statement.collectFreeVariables(function, vars);
        exit.collectFreeVariables(function, vars);
    }

    @Override
    public SSAFunction getParentFunction() {
        return parent;
    }

    /** Delegates lambda lifting to every statement; the exit is not visited. */
    public void lambdaLift(SSALambdaLiftingContext context) {
        for(SSAStatement statement = firstStatement;
                statement != null; statement = statement.next)
            statement.lambdaLift(context);
    }
    
    /** @return the first statement of this block, or null if there are none */
    public SSAStatement getFirstStatement() {
        return firstStatement;
    }

    /** Replaces the parameter at {@code position} and makes this block its parent. */
    public void setParameter(int position, BoundVar target) {
        parameters[position] = target;
        target.parent = this;
    }

    /** Calls {@code prepare} on every statement and then on the exit. */
    public void prepare(MethodBuilder mb) {
        for(SSAStatement stat = firstStatement; stat != null; stat = stat.next)
            stat.prepare(mb);
        exit.prepare(mb);
    }

    /** Passes {@code visitor} to every statement and then to the exit. */
    public void forValRefs(ValRefVisitor visitor) {
        for(SSAStatement statement = firstStatement;
                statement != null; statement = statement.next)
            statement.forValRefs(visitor);
        exit.forValRefs(visitor);
    }

    /** Calls {@code cleanup} on every statement and then on the exit. */
    public void cleanup() {
        for(SSAStatement statement = firstStatement;
                statement != null; statement = statement.next)
            statement.cleanup();
        exit.cleanup();
    }
    
}
