package org.simantics.scl.compiler.constants;

import java.util.Arrays;

import org.cojen.classfile.TypeDesc;
import org.objectweb.asm.Opcodes;
import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.internal.codegen.references.Val;
import org.simantics.scl.compiler.internal.codegen.types.BTypes;
import org.simantics.scl.compiler.internal.codegen.utils.ClassBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.CodeBuilderUtils;
import org.simantics.scl.compiler.internal.codegen.utils.Constants;
import org.simantics.scl.compiler.internal.codegen.utils.JavaNamingPolicy;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilderBase;
import org.simantics.scl.compiler.internal.codegen.utils.ModuleBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.TransientClassBuilder;
import org.simantics.scl.compiler.runtime.MutableClassLoader;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.types.TVar;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import gnu.trove.map.hash.THashMap;

/**
 * Base class for constants that denote functions with a fixed list of parameter types.
 * Subclasses generate the code for a saturated call in {@link #applyExact(MethodBuilder, Val[])};
 * this class handles partial (closure-building) and over-saturated application, and realizes
 * the function as a runtime object by generating and loading a class for it.
 */
public abstract class FunctionValue extends Constant {

    private static final Logger LOGGER = LoggerFactory.getLogger(FunctionValue.class);

    /** Source file name recorded in the classes generated by {@link #realizeValue}. */
    private static final String GENERATED_SOURCE_FILE = "_SCL_FunctionValue";

    // Type variables quantified over in the function type.
    TVar[] typeParameters;
    Type returnType;
    protected Type[] parameterTypes;
    // Effect annotation of the function type.
    Type effect;

    /**
     * Creates a function value whose type is
     * {@code forall typeParameters. parameterTypes -> <effect> returnType}.
     */
    public FunctionValue(TVar[] typeParameters, Type effect, Type returnType, Type ... parameterTypes) {
        super(Types.forAll(typeParameters, 
                Types.functionE(parameterTypes, effect, returnType)));
        this.typeParameters = typeParameters;
        this.returnType = returnType;
        this.parameterTypes = parameterTypes;
        this.effect = effect;
    }

    public Type getReturnType() {
        return returnType;
    }

    /** Returns the parameter types (the internal array, not a copy). */
    public Type[] getParameterTypes() {
        return parameterTypes;
    }

    /** Returns the quantified type variables (the internal array, not a copy). */
    public TVar[] getTypeParameters() {
        return typeParameters;
    }

    /** Returns the number of declared parameters. */
    @Override
    public int getArity() {
        return parameterTypes.length;
    }

    /** Pushes the function itself, i.e. an application to zero parameters. */
    @Override
    public void push(MethodBuilder mb) {
        apply(mb, Type.EMPTY_ARRAY);
    }

    /**
     * Generates code that applies this function to the given parameters.
     * <ul>
     * <li>Fewer parameters than the arity: a closure record capturing the given
     *     parameters is constructed.</li>
     * <li>Exactly the arity: code generation is delegated to {@link #applyExact}.</li>
     * <li>More parameters than the arity: the exact call is generated first, then the
     *     resulting function value is applied generically to the remaining boxed
     *     parameters and the final result is unboxed.</li>
     * </ul>
     *
     * @param mb             the method under generation
     * @param typeParameters types substituted for this function's type variables
     *                       (only used in the over-saturated case)
     * @param parameters     the actual parameters
     * @return the type of the generated result
     * @throws InternalCompilerError if, in the over-saturated case, the result of the
     *         exact call cannot be matched as a function of the remaining arity
     */
    @Override
    public Type apply(MethodBuilder mb, Type[] typeParameters, Val... parameters) {
        int arity = getArity();

        if(parameters.length < arity) {
            // Partial application: build a closure holding the parameters given so far.
            ModuleBuilder moduleBuilder = mb.getModuleBuilder();
            TypeDesc closureClass = moduleBuilder.getClosure(this, parameters.length);
            CodeBuilderUtils.constructRecord(closureClass, mb, parameterTypes, parameters);
            return Types.function(Arrays.copyOfRange(parameterTypes, parameters.length, parameterTypes.length), returnType);
        }
        else if(parameters.length > arity) {
            // Over-saturated application: call with exactly `arity` parameters, then apply
            // the resulting function value generically to the rest.
            Type resultType = applyExact(mb, Arrays.copyOf(parameters, arity));
            mb.pushBoxed(Arrays.copyOfRange(parameters, arity, parameters.length));
            int remainingArity = parameters.length - arity;
            mb.genericApply(remainingArity);

            if(typeParameters.length > 0)
                resultType = resultType.replace(this.typeParameters, typeParameters);
            try {
                resultType = BTypes.matchFunction(resultType, remainingArity)[remainingArity];
            } catch (MatchException e) {
                throw new InternalCompilerError("Tried to apply value of type " + resultType + " with " + remainingArity + " parameters.");
            }
            mb.unbox(resultType);
            return resultType;
        }
        else {
            return applyExact(mb, parameters);
        }
    }

    /**
     * Generates code for a saturated call of this function and returns the type
     * of the generated result.
     */
    public abstract Type applyExact(MethodBuilder mb, Val[] parameters);

    /** Returns the number of parameters used when realizing this function as a runtime object. */
    @Override
    public int getEffectiveArity() {
        return parameterTypes.length;
    }

    /**
     * Returns a runtime object implementing this function. For a positive effective
     * arity, a class is generated, loaded into the builder's class loader and
     * instantiated. The result is stored in the class loader's constant cache when
     * one is available, and a cached result is returned if present.
     *
     * @throws InternalCompilerError if the generated class cannot be loaded or instantiated
     */
    @Override
    public Object realizeValue(TransientClassBuilder builder) {
        THashMap<Constant, Object> valueCache = builder.classLoader.getConstantCache();
        if(valueCache != null) {
            Object cachedResult = valueCache.get(this);
            if(cachedResult != null)
                return cachedResult;
        }

        int arity = getEffectiveArity();
        if(arity == 0)
            return super.realizeValue(builder);

        String packageName = builder.classLoader.getFreshPackageName();
        String moduleName = packageName + "/Temp";
        JavaNamingPolicy policy = new JavaNamingPolicy(moduleName);
        ModuleBuilder moduleBuilder = new ModuleBuilder(policy, builder.javaTypeTranslator);
        String className = policy.getModuleClassName();

        if(SCLCompilerConfiguration.TRACE_METHOD_CREATION)
            LOGGER.info("Create class " + className);

        ClassBuilder classFile;
        if(arity <= Constants.MAX_FUNCTION_PARAMETER_COUNT)
            classFile = createFixedArityClass(moduleBuilder, className, arity);
        else
            classFile = createVariableArityClass(moduleBuilder, className, arity);

        addToStringMethod(classFile);
        moduleBuilder.addClass(classFile);

        MutableClassLoader classLoader = builder.classLoader;
        classLoader.addClasses(moduleBuilder.getClasses());
        try {
            Object result = classLoader.loadClass(className.replace('/', '.')).newInstance();
            if(valueCache != null) {
                valueCache.put(this, result);
                if(TRACE_REALIZATION)
                    LOGGER.info("/REALIZED/ " + this + " " + getClass().getSimpleName());
            }
            return result;
        } catch (InstantiationException e) {
            throw new InternalCompilerError(e);
        } catch (IllegalAccessException e) {
            throw new InternalCompilerError(e);
        } catch (ClassNotFoundException e) {
            throw new InternalCompilerError(e);
        }
    }

    /**
     * Creates a class extending {@code Constants.FUNCTION_IMPL[arity]} with a default
     * constructor and an {@code apply} method of {@code arity} parameters that
     * returns the boxed result of {@link #applyExact}.
     */
    private ClassBuilder createFixedArityClass(ModuleBuilder moduleBuilder, String className, int arity) {
        ClassBuilder classFile = new ClassBuilder(moduleBuilder, Opcodes.ACC_PUBLIC, className, 
                MethodBuilderBase.getClassName(Constants.FUNCTION_IMPL[arity]));
        classFile.setSourceFile(GENERATED_SOURCE_FILE);
        classFile.addDefaultConstructor();

        MethodBuilder mb = classFile.addMethod(Opcodes.ACC_PUBLIC, "apply", TypeDesc.OBJECT, Constants.OBJECTS[arity]);
        Val[] parameters = new Val[arity];
        for(int i=0;i<arity;++i)
            parameters[i] = new LocalVariableConstant(parameterTypes[i], mb.getParameter(i));
        prepare(mb);
        mb.box(applyExact(mb, parameters));
        mb.returnValue(TypeDesc.OBJECT);
        mb.finish();
        return classFile;
    }

    /**
     * Creates a class extending {@code Constants.FUNCTION_N_IMPL} whose constructor
     * passes {@code arity} to the superclass constructor and whose {@code doApply}
     * method reads the parameters from its {@code Object[]} argument and returns the
     * boxed result of {@link #applyExact}.
     */
    private ClassBuilder createVariableArityClass(ModuleBuilder moduleBuilder, String className, int arity) {
        String superClassName = MethodBuilderBase.getClassName(Constants.FUNCTION_N_IMPL);
        ClassBuilder classFile = new ClassBuilder(moduleBuilder, Opcodes.ACC_PUBLIC, className, superClassName);
        classFile.setSourceFile(GENERATED_SOURCE_FILE);

        // Constructor: this(...) { super(arity); }
        MethodBuilderBase constructor = classFile.addConstructorBase(Opcodes.ACC_PUBLIC, Constants.EMPTY_TYPEDESC_ARRAY);
        constructor.loadThis();
        constructor.loadConstant(arity);
        constructor.invokeConstructor(superClassName, new TypeDesc[] {TypeDesc.INT});
        constructor.returnVoid();
        constructor.finish();

        // doApply(Object[]): each parameter is an element of the single array argument.
        MethodBuilder mb = classFile.addMethod(Opcodes.ACC_PUBLIC, "doApply", TypeDesc.OBJECT, 
                new TypeDesc[] {TypeDesc.forClass(Object[].class)});
        Val[] parameters = new Val[arity];
        for(int i=0;i<arity;++i)
            parameters[i] = new LocalBoxedArrayElementConstant(parameterTypes[i], 
                    mb.getParameter(0), i);
        // Same sequence as the fixed-arity case: prepare, then box the type that
        // applyExact actually produced (not the declared return type).
        prepare(mb);
        mb.box(applyExact(mb, parameters));
        mb.returnValue(TypeDesc.OBJECT);
        mb.finish();
        return classFile;
    }

    /** Adds a {@code toString()} method returning this constant's {@code toString()} text. */
    private void addToStringMethod(ClassBuilder classFile) {
        MethodBuilder mb = classFile.addMethod(Opcodes.ACC_PUBLIC, "toString", TypeDesc.STRING, Constants.OBJECTS[0]);
        mb.loadConstant(this.toString());
        mb.returnValue(TypeDesc.STRING);
        mb.finish();
    }
}
