package org.simantics.scl.compiler.compilation;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Map;

import org.cojen.classfile.TypeDesc;
import org.objectweb.asm.Opcodes;
import org.simantics.scl.compiler.common.datatypes.Constructor;
import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.common.names.Name;
import org.simantics.scl.compiler.constants.LocalFieldConstant;
import org.simantics.scl.compiler.constants.LocalVariableConstant;
import org.simantics.scl.compiler.constants.NoRepConstant;
import org.simantics.scl.compiler.constants.SCLConstant;
import org.simantics.scl.compiler.constants.ThisConstant;
import org.simantics.scl.compiler.elaboration.contexts.SimplificationContext;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.macros.StandardMacroRule;
import org.simantics.scl.compiler.elaboration.modules.DerivedProperty;
import org.simantics.scl.compiler.elaboration.modules.InlineProperty;
import org.simantics.scl.compiler.elaboration.modules.MethodImplementation;
import org.simantics.scl.compiler.elaboration.modules.PrivateProperty;
import org.simantics.scl.compiler.elaboration.modules.SCLValue;
import org.simantics.scl.compiler.elaboration.modules.SCLValueProperty;
import org.simantics.scl.compiler.elaboration.modules.TypeClass;
import org.simantics.scl.compiler.elaboration.modules.TypeClassInstance;
import org.simantics.scl.compiler.elaboration.modules.TypeClassMethod;
import org.simantics.scl.compiler.errors.ErrorLog;
import org.simantics.scl.compiler.errors.Locations;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.codegen.references.Val;
import org.simantics.scl.compiler.internal.codegen.ssa.SSAModule;
import org.simantics.scl.compiler.internal.codegen.types.JavaReferenceValidator;
import org.simantics.scl.compiler.internal.codegen.types.JavaTypeTranslator;
import org.simantics.scl.compiler.internal.codegen.types.StandardTypeConstructor;
import org.simantics.scl.compiler.internal.codegen.utils.ClassBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.CodeBuilderUtils;
import org.simantics.scl.compiler.internal.codegen.utils.CodeBuildingException;
import org.simantics.scl.compiler.internal.codegen.utils.Constants;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilderBase;
import org.simantics.scl.compiler.internal.codegen.utils.ModuleBuilder;
import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
import org.simantics.scl.compiler.internal.codegen.writer.ExternalConstant;
import org.simantics.scl.compiler.internal.codegen.writer.ModuleWriter;
import org.simantics.scl.compiler.internal.elaboration.decomposed.DecomposedExpression;
import org.simantics.scl.compiler.module.ConcreteModule;
import org.simantics.scl.compiler.top.SCLCompilerConfiguration;
import org.simantics.scl.compiler.types.TCon;
import org.simantics.scl.compiler.types.TPred;
import org.simantics.scl.compiler.types.Type;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;
import org.simantics.scl.compiler.types.util.MultiFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import gnu.trove.map.hash.THashMap;

/**
 * Back end of SCL module compilation. Separate steps simplify value
 * definitions, convert them to an SSA module, optimize that module, and emit
 * JVM classes for module values, data types, type classes and type class
 * instances.
 *
 * <p>This class does not enforce the order of the steps.
 * NOTE(review): callers presumably invoke {@link #simplifyValues()},
 * {@link #convertToSSA()}, {@link #optimizeSSA()} and {@link #generateCode()}
 * in that order. Confirm against the caller.
 */
public class CodeGeneration {

    private static final Logger LOGGER = LoggerFactory.getLogger(CodeGeneration.class);

    /** Number of SSA simplification phases run by {@link #optimizeSSA()}. */
    public static final int OPTIMIZATION_PHASES = 2;

    /** Maximum number of simplification rounds, shared by all optimization phases. */
    private static final int MAX_SIMPLIFICATION_ROUNDS = 100;

    CompilationContext compilationContext;
    ErrorLog errorLog;
    JavaReferenceValidator<Object, Object, Object, Object> validator;
    ConcreteModule module;
    ModuleBuilder moduleBuilder;

    // Results produced by the compilation steps.
    /** SSA form of the module; set by {@link #convertToSSA()}. */
    SSAModule ssaModule;
    /** External constants collected while writing the SSA module; set by {@link #convertToSSA()}. */
    ExternalConstant[] externalConstants;
    /** Generated class files (presumably keyed by class name); set by {@link #generateCode()}. */
    Map<String, byte[]> classes;

    /**
     * Creates a code generator for one module.
     *
     * @param compilationContext shared compilation state; its error log, naming
     *        policy and Java type translator are used by this class
     * @param javaReferenceValidator validator handed to value simplification
     * @param module the module whose values and types are compiled
     */
    // The validator's type parameters are opaque to this class, so it is viewed as <Object, ...>.
    @SuppressWarnings("unchecked")
    public CodeGeneration(CompilationContext compilationContext,
            JavaReferenceValidator<?, ?, ?, ?> javaReferenceValidator,
            ConcreteModule module) {
        this.compilationContext = compilationContext;
        this.errorLog = compilationContext.errorLog;
        this.module = module;
        this.validator = (JavaReferenceValidator<Object, Object, Object, Object>) javaReferenceValidator;
        this.moduleBuilder = new ModuleBuilder(compilationContext.namingPolicy, compilationContext.javaTypeTranslator);
    }

    /**
     * Simplifies the expression of every value in the module.
     *
     * <p>Before any simplification happens, each value whose macro rule is a
     * {@link StandardMacroRule} gets a copy of its current expression stored as
     * the rule's base expression.
     */
    public void simplifyValues() {
        Collection<SCLValue> values = module.getValues();
        SimplificationContext simplificationContext = new SimplificationContext(compilationContext, validator);
        // Iterate over a snapshot of the values.
        // NOTE(review): presumably because simplification may modify the module's value collection.
        SCLValue[] valueArray = values.toArray(new SCLValue[values.size()]);

        for(SCLValue value : valueArray) {
            if(value.getMacroRule() instanceof StandardMacroRule) {
                StandardMacroRule rule = (StandardMacroRule)value.getMacroRule();
                rule.setBaseExpression(value.getExpression().copy());
            }
        }

        for(SCLValue value : valueArray) {
            value.getSimplifiedExpression(simplificationContext);
        }
    }

    /**
     * Converts every value that has an expression into an SSA function.
     *
     * <p>The first pass assigns each such value a fresh {@link SCLConstant} and
     * applies its properties. The second pass writes the function bodies. This
     * presumably lets a body reference any value regardless of declaration
     * order. Sets {@link #ssaModule} and {@link #externalConstants}.
     *
     * @throws RuntimeException rethrown unchanged from a failing value, after the
     *         error log has been given the value's source position
     */
    public void convertToSSA() {
        ModuleWriter mw = new ModuleWriter(compilationContext.namingPolicy.getModuleClassName(), compilationContext.lineLocator);

        // Pass 1: create a constant for every value with a definition.
        for(SCLValue value : module.getValues()) {
            if(value.getExpression() == null)
                continue;
            Name name = value.getName();
            SCLConstant constant = new SCLConstant(name, value.getType());
            value.setValue(constant);
            applyValueProperties(value, constant);
        }

        // Pass 2: write the body of each constant into the SSA module.
        for(SCLValue value : module.getValues()) {
            try {
                Expression expression = value.getExpression();
                if(expression == null)
                    continue;

                DecomposedExpression decomposed =
                        DecomposedExpression.decompose(errorLog, expression);

                // Pass 1 gave every value with a non-null expression an SCLConstant.
                SCLConstant constant = (SCLConstant)value.getValue();
                CodeWriter w = mw.createFunction(constant,
                        decomposed.typeParameters,
                        decomposed.effect,
                        decomposed.returnType,
                        decomposed.parameterTypes);
                constant.setDefinition(w.getFunction());
                IVal[] parameterVals = w.getParameters();
                for(int i=0;i<decomposed.parameters.length;++i)
                    decomposed.parameters[i].setVal(parameterVals[i]);
                w.return_(expression.location, decomposed.body.toVal(compilationContext, w));
            } catch(RuntimeException e) {
                // Point the error log at the expression, or at the definition if the expression has no location.
                long location = value.getExpression().location;
                if(location == Locations.NO_LOCATION)
                    location = value.definitionLocation;
                errorLog.setExceptionPosition(location);
                throw e;
            }
        }
        ssaModule = mw.getModule();
        if(SCLCompilerConfiguration.DEBUG)
            ssaModule.validate();

        this.externalConstants = mw.getExternalConstants();
    }

    /**
     * Copies the inline, private and derived properties of {@code value} onto
     * {@code constant}. A derived value ends up non-private whatever the order
     * of its properties.
     */
    private static void applyValueProperties(SCLValue value, SCLConstant constant) {
        boolean isDerived = false;
        for(SCLValueProperty prop : value.getProperties()) {
            if(prop instanceof InlineProperty) {
                InlineProperty inlineProperty = (InlineProperty)prop;
                constant.setInlineArity(inlineProperty.arity, inlineProperty.phaseMask);
            }
            else if(prop == PrivateProperty.INSTANCE) {
                constant.setPrivate(!isDerived);
            }
            else if(prop == DerivedProperty.INSTANCE) {
                constant.setPrivate(false);
                isDerived = true;
            }
        }
    }

    /**
     * Optimizes {@link #ssaModule} in place.
     *
     * <p>Runs {@link #OPTIMIZATION_PHASES} simplification phases. Inlinable
     * definitions are saved after phase 0. The module is then lambda-lifted and
     * marked for on-the-fly generation.
     */
    public void optimizeSSA() {
        if(SCLCompilerConfiguration.SHOW_SSA_BEFORE_OPTIMIZATION && SCLCompilerConfiguration.debugFilter(module.getName())) {
            LOGGER.info("=== SSA before optimization ====================================");
            LOGGER.info("{}", ssaModule);
        }
        if(SCLCompilerConfiguration.DEBUG) {
            ssaModule.validate();
        }
        // The round budget is shared by all phases: rounds spent in an earlier
        // phase are not available to later ones.
        int optCount = 0;
        for(int phase=0;phase<OPTIMIZATION_PHASES;++phase) {
            while(optCount++ < MAX_SIMPLIFICATION_ROUNDS
                    && ssaModule.simplify(compilationContext.environment, phase)) {
                // Keep simplifying while simplify returns true (presumably: something changed).
            }
            if(phase == 0)
                ssaModule.saveInlinableDefinitions();
        }
        if(SCLCompilerConfiguration.SHOW_SSA_BEFORE_LAMBDA_LIFTING && SCLCompilerConfiguration.debugFilter(module.getName())) {
            LOGGER.info("=== SSA before lambda lifting ==================================");
            LOGGER.info("{}", ssaModule);
        }
        ssaModule.lambdaLift(errorLog);
        // TODO prevent creating more lambdas here
        ssaModule.markGenerateOnFly();
    }

    /**
     * Emits JVM code for {@link #ssaModule} into the module builder and stores
     * the resulting classes in {@link #classes}. A {@link CodeBuildingException}
     * is reported to the error log instead of being thrown.
     */
    public void generateCode() {
        if(SCLCompilerConfiguration.SHOW_FINAL_SSA && SCLCompilerConfiguration.debugFilter(module.getName())) {
            LOGGER.info("=== Final SSA ==================================================");
            LOGGER.info("{}", ssaModule);
        }
        try {
            ssaModule.generateCode(moduleBuilder);
        } catch (CodeBuildingException e) {
            errorLog.log(e.getMessage());
        }
        if(SCLCompilerConfiguration.TRACE_MAX_METHOD_SIZE && moduleBuilder.getMethodSizeCounter() != null)
            LOGGER.info("[Max method size] {}: {}", module.getName(), moduleBuilder.getMethodSizeCounter());
        classes = moduleBuilder.getClasses();
    }

    /**
     * Emits classes for the given non-external data types.
     *
     * <p>A type with a single constructor gets one record class, unless that
     * constructor has exactly one parameter, in which case no class is emitted.
     * A type with any other number of constructors gets an abstract supertype
     * plus one record subclass per constructor.
     *
     * @param dataTypes the data types to emit; entries marked {@code external} are skipped
     */
    public void generateDataTypes(ArrayList<StandardTypeConstructor> dataTypes) {
        for(StandardTypeConstructor dataType : dataTypes) {
            if(dataType.external)
                continue;
            if(dataType.constructors.length == 1) {
                Constructor constructor = dataType.constructors[0];
                if(constructor.parameterTypes.length != 1) {
                    generateConstructorClass(MethodBuilderBase.getClassName(dataType.getTypeDesc()),
                            "java/lang/Object", constructor);
                }
            }
            else {
                String javaName = MethodBuilderBase.getClassName(dataType.getTypeDesc());

                // Abstract supertype shared by the constructor classes.
                if(SCLCompilerConfiguration.TRACE_METHOD_CREATION)
                    LOGGER.info("Create class {}", javaName);
                ClassBuilder cf = new ClassBuilder(moduleBuilder,
                        Opcodes.ACC_ABSTRACT | Opcodes.ACC_PUBLIC,
                        javaName, "java/lang/Object");
                cf.setSourceFile("_SCL_DataType");
                cf.addDefaultConstructor();
                moduleBuilder.addClass(cf);

                for(Constructor constructor : dataType.constructors)
                    generateConstructorClass(constructor.javaName, javaName, constructor);
            }
        }
    }

    /**
     * Emits a public record class for one data type constructor via
     * {@link CodeBuilderUtils#makeRecord}. The fields use prefix "c" and are
     * {@code ACC_PUBLIC | ACC_FINAL}.
     *
     * @param className internal JVM name of the class to create
     * @param superClassName internal JVM name of its superclass
     * @param constructor the constructor whose parameter types become the fields
     */
    private void generateConstructorClass(String className, String superClassName, Constructor constructor) {
        if(SCLCompilerConfiguration.TRACE_METHOD_CREATION)
            LOGGER.info("Create class {}", className);
        ClassBuilder cf = new ClassBuilder(moduleBuilder, Opcodes.ACC_PUBLIC, className, superClassName);
        cf.setSourceFile("_SCL_DataType");
        CodeBuilderUtils.makeRecord(cf, constructor.name.name,
                Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL, "c",
                compilationContext.javaTypeTranslator.toTypeDescs(constructor.parameterTypes),
                true);
        moduleBuilder.addClass(cf);
    }

    /**
     * Emits one public Java interface per type class in the module. The
     * interface declares an abstract {@code super<i>} accessor for each context
     * constraint and an abstract method per type class method.
     */
    public void generateTypeClasses() {
        JavaTypeTranslator javaTypeTranslator = moduleBuilder.getJavaTypeTranslator();
        for(TypeClass typeClass : module.getTypeClasses()) {
            if(SCLCompilerConfiguration.TRACE_METHOD_CREATION)
                LOGGER.info("Create class {}", typeClass.javaName);
            ClassBuilder cf = new ClassBuilder(moduleBuilder,
                    Opcodes.ACC_INTERFACE | Opcodes.ACC_PUBLIC,
                    typeClass.javaName, "java/lang/Object");

            for(int i=0;i<typeClass.context.length;++i) {
                TPred sup = typeClass.context[i];
                cf.addAbstractMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_ABSTRACT, "super" + i,
                        javaTypeTranslator.toTypeDesc(sup),
                        Constants.EMPTY_TYPEDESC_ARRAY);
            }

            for(TypeClassMethod method : typeClass.methods.values()) {
                MultiFunction mfun = matchMethodType(method);
                cf.addAbstractMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_ABSTRACT, method.getJavaName(),
                        javaTypeTranslator.toTypeDesc(mfun.returnType),
                        JavaTypeTranslator.filterVoid(javaTypeTranslator.toTypeDescs(mfun.parameterTypes)));
            }

            moduleBuilder.addClass(cf);
        }
    }

    /**
     * Emits a class for every type class instance in the module. Instances are
     * processed in order of their type constructor's name, so the output order
     * does not depend on hash map iteration order.
     */
    public void generateTypeClassInstances() {
        THashMap<TCon, ArrayList<TypeClassInstance>> typeInstances = module.getTypeInstances();
        ArrayList<TCon> cons = new ArrayList<>(typeInstances.keySet());
        Collections.sort(cons, new Comparator<TCon>() {
            @Override
            public int compare(TCon o1, TCon o2) {
                return o1.toName().compareTo(o2.toName());
            }
        });
        for(TCon con : cons) {
            ArrayList<TypeClassInstance> instances = typeInstances.get(con);
            for(TypeClassInstance instance : instances)
                generateTypeClassInstance(instance);
        }
    }

    /**
     * Emits the class implementing one type class instance. The instance
     * context is stored in private record fields with prefix "cx". The class
     * gets one {@code super<i>} method per superclass expression and one method
     * per type class method, each delegating to its implementation.
     */
    private void generateTypeClassInstance(final TypeClassInstance instance) {
        final JavaTypeTranslator javaTypeTranslator = moduleBuilder.getJavaTypeTranslator();

        if(SCLCompilerConfiguration.TRACE_METHOD_CREATION)
            LOGGER.info("Create class {}", instance.javaName);
        final ClassBuilder cb = new ClassBuilder(moduleBuilder, Opcodes.ACC_PUBLIC, instance.javaName, "java/lang/Object",
                instance.typeClass.javaName);
        cb.setSourceFile("_SCL_TypeClassInstance");

        CodeBuilderUtils.makeRecord(cb, instance.javaName, Opcodes.ACC_PRIVATE, "cx",
                javaTypeTranslator.toTypeDescs(instance.context), false);

        for(int i=0;i<instance.superExpressions.length;++i) {
            TypeDesc returnTypeDesc = javaTypeTranslator.toTypeDesc(instance.typeClass.context[i]);
            MethodBuilder mb = cb.addMethod(Opcodes.ACC_PUBLIC, "super" + i,
                    returnTypeDesc,
                    Constants.EMPTY_TYPEDESC_ARRAY);
            Val[] parameters = contextFieldVals(instance);
            instance.superExpressions[i].getValue().apply(mb, Type.EMPTY_ARRAY, parameters);
            mb.returnValue(returnTypeDesc);
            mb.finish();
        }

        for(TypeClassMethod method : instance.typeClass.methods.values()) {
            MultiFunction mfun = matchMethodType(method);
            TypeDesc[] parameterTypeDescs = javaTypeTranslator.toTypeDescs(mfun.parameterTypes);
            TypeDesc returnTypeDesc = javaTypeTranslator.toTypeDesc(mfun.returnType);
            MethodBuilder mb = cb.addMethod(Opcodes.ACC_PUBLIC, method.getJavaName(),
                    returnTypeDesc,
                    JavaTypeTranslator.filterVoid(parameterTypeDescs));

            MethodImplementation implementation =
                    instance.methodImplementations.get(method.getName());

            // The implementation receives some leading arguments before the method's own ones:
            // the instance itself for a default implementation, the context values otherwise.
            IVal function;
            int prefixLength;
            if(implementation.isDefault) {
                function = compilationContext.environment.getValue(implementation.name).getValue();
                prefixLength = 1;
            }
            else {
                function = module.getValue(implementation.name.name).getValue();
                prefixLength = instance.context.length;
            }

            Val[] parameters = new Val[prefixLength + method.getArity()];
            MultiFunction mfun2 = matchImplementationType(function, parameters.length);
            if(implementation.isDefault) {
                parameters[0] = new ThisConstant(instance.instance);
            }
            else {
                for(int i=0;i<prefixLength;++i)
                    parameters[i] = new LocalFieldConstant(instance.context[i], "cx"+i);
            }

            // Arguments whose type translates to void get a NoRepConstant and consume no JVM
            // parameter of mb (presumably mirroring filterVoid above).
            for(int i=0,j=0;i<method.getArity();++i) {
                Type parameterType = mfun2.parameterTypes[prefixLength + i];
                if(javaTypeTranslator.toTypeDesc(parameterType).equals(TypeDesc.VOID))
                    parameters[prefixLength + i] = new NoRepConstant(parameterType);
                else
                    parameters[prefixLength + i] = new LocalVariableConstant(parameterType, mb.getParameter(j++));
            }
            Type returnType = function.apply(mb, Type.EMPTY_ARRAY, parameters);
            if(returnTypeDesc == TypeDesc.OBJECT)
                mb.box(returnType);
            mb.returnValue(returnTypeDesc);
            mb.finish();
        }

        moduleBuilder.addClass(cb);
    }

    /** Builds one {@link LocalFieldConstant} per instance context entry, reading field "cx" + index. */
    private static Val[] contextFieldVals(TypeClassInstance instance) {
        Val[] vals = new Val[instance.context.length];
        for(int i=0;i<vals.length;++i)
            vals[i] = new LocalFieldConstant(instance.context[i], "cx"+i);
        return vals;
    }

    /**
     * Splits a type class method's base type into parameter and return types
     * according to its arity.
     *
     * @throws InternalCompilerError if the base type has fewer arguments than the arity
     */
    private static MultiFunction matchMethodType(TypeClassMethod method) {
        try {
            return Types.matchFunction(method.getBaseType(), method.getArity());
        } catch (MatchException e) {
            throw new InternalCompilerError("Method " + method.getName() + " has too high arity.");
        }
    }

    /**
     * Splits the type of an implementing function (with its forall quantifiers
     * removed) into {@code arity} parameter types and a return type.
     *
     * @throws InternalCompilerError wrapping the match failure
     */
    private static MultiFunction matchImplementationType(IVal function, int arity) {
        try {
            return Types.matchFunction(Types.removeForAll(function.getType()), arity);
        } catch (MatchException e) {
            throw new InternalCompilerError(e);
        }
    }
}
