package org.simantics.scl.compiler.internal.codegen.chr;

import java.util.ArrayList;

import org.cojen.classfile.TypeDesc;
import org.objectweb.asm.Label;
import org.objectweb.asm.Opcodes;
import org.simantics.scl.compiler.elaboration.chr.CHRRuleset;
import org.simantics.scl.compiler.elaboration.chr.relations.CHRConstraint;
import org.simantics.scl.compiler.elaboration.chr.relations.CHRConstraint.IndexInfo;
import org.simantics.scl.compiler.internal.codegen.types.JavaTypeTranslator;
import org.simantics.scl.compiler.internal.codegen.utils.ClassBuilder;
import org.simantics.scl.compiler.internal.codegen.utils.Constants;
import org.simantics.scl.compiler.internal.codegen.utils.LocalVariable;
import org.simantics.scl.compiler.internal.codegen.utils.MethodBuilderBase;
import org.simantics.scl.compiler.internal.codegen.utils.ModuleBuilder;

/**
 * Generates the fact class of a single CHR constraint together with the
 * members of the enclosing store class that index facts of that constraint.
 *
 * <p>For the constraint given to the constructor, {@link #generate(ArrayList)}
 * builds a class named {@code <storeClassName>$<constraintName>} (created with
 * {@code CHRFact_name} as the fourth {@code ClassBuilder} argument) containing:
 * <ul>
 *   <li>one public field per non-void constraint parameter,</li>
 *   <li>{@code <index>Next} (and, when facts may be removed, {@code <index>Prev})
 *       link fields for every index of the constraint,</li>
 *   <li>an {@code add} method, a {@code remove} method (only when facts may be
 *       removed) and a constructor taking the fact id and the parameters.</li>
 * </ul>
 * On the store class it adds one field and, for indices with bound parameters,
 * one lookup method per index.
 */
public class CHRFactCodeGenerator implements CHRCodeGenerationConstants {

    private final ModuleBuilder moduleBuilder;
    private final JavaTypeTranslator jtt;

    private final ClassBuilder storeClassBuilder;
    private final CHRRuleset ruleset;
    private final CHRConstraint constraint;

    private final String factClassName;
    private final TypeDesc factTypeDesc;
    private final ClassBuilder classBuilder;

    private final TypeDesc storeTypeDesc;
    private final TypeDesc[] storeTypeDescArray;

    private final TypeDesc[] parameterTypeDescs;
    private final boolean supportsRemoval;

    /**
     * Prepares a generator for one constraint of a ruleset.
     *
     * @param storeClassBuilder builder of the store class that will own the indices;
     *                          its module builder and type are reused here
     * @param ruleset           ruleset the constraint belongs to; queried for the
     *                          constraint's minimum priority
     * @param constraint        constraint whose fact class is generated
     */
    CHRFactCodeGenerator(ClassBuilder storeClassBuilder, CHRRuleset ruleset, CHRConstraint constraint) {
        this.storeClassBuilder = storeClassBuilder;
        this.ruleset = ruleset;
        this.constraint = constraint;

        this.moduleBuilder = storeClassBuilder.getModuleBuilder();
        this.jtt = moduleBuilder.getJavaTypeTranslator();
        this.storeTypeDesc = storeClassBuilder.getType();
        this.storeTypeDescArray = new TypeDesc[] { storeTypeDesc };

        this.factClassName = storeClassBuilder.getClassName() + "$" + constraint.name;
        this.factTypeDesc = TypeDesc.forClass(factClassName);
        this.classBuilder = new ClassBuilder(moduleBuilder, Opcodes.ACC_PUBLIC, factClassName, CHRFact_name);

        this.parameterTypeDescs = jtt.toTypeDescs(constraint.parameterTypes);
        this.supportsRemoval = constraint.mayBeRemoved();
    }

    /**
     * Generates the complete fact class and registers it with the module builder.
     *
     * @param hashIndexInitializations list that receives one entry for the
     *        {@code <constraintName>$temp} field and one entry for every hash
     *        index field of this constraint (presumably consumed by the store
     *        class generator -- verify against the caller)
     */
    public void generate(ArrayList<StoreInitialization> hashIndexInitializations) {
        generateFields(hashIndexInitializations);
        // Scratch fact instance used as the lookup key by the generated index methods.
        hashIndexInitializations.add(new StoreInitialization(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL, constraint.name + "$temp", factTypeDesc, factClassName));

        generateIndices();
        generateAdd();

        if(supportsRemoval)
            generateRemove();

        generateConstructor();
        classBuilder.addDefaultConstructor();

        moduleBuilder.addClass(classBuilder);
    }

    /**
     * Adds to the store class one lookup method per index that has at least one
     * bound parameter (non-zero index mask). The method copies its arguments
     * into the {@code $temp} fact and queries the hash index with it.
     */
    private void generateIndices() {
        // public ExampleFact ExampleFact$bf(int c0) {
        //     ExampleFact$temp.c0 = c0;
        //     return (ExampleFact)ExampleFact_bfIndex.getEqual(ExampleFact$temp);
        // }

        for(IndexInfo indexInfo : constraint.getIndices()) {
            if(indexInfo.indexMask != 0) {
                // Bit i of the mask selects constraint parameter i as a method parameter.
                // NOTE(review): void-typed parameters are included in the signature here
                // (the fact constructor skips them) -- confirm MethodBuilderBase handles this.
                ArrayList<TypeDesc> getParameterTypeDescs = new ArrayList<TypeDesc>(constraint.parameterTypes.length);
                for(int i=0;i<constraint.parameterTypes.length;++i)
                    if(((indexInfo.indexMask>>i)&1)==1)
                        getParameterTypeDescs.add(parameterTypeDescs[i]);
                MethodBuilderBase mb = storeClassBuilder.addMethodBase(Opcodes.ACC_PUBLIC, constraint.name + "$" + indexInfo.indexName, factTypeDesc,
                        getParameterTypeDescs.toArray(new TypeDesc[getParameterTypeDescs.size()]));

                // ExampleFact$temp.c0 = c0;
                mb.loadThis();
                mb.loadField(storeClassBuilder.getClassName(), constraint.name + "$temp", factTypeDesc);
                LocalVariable tempFactVar = mb.createLocalVariable("temp", factTypeDesc);
                mb.storeLocal(tempFactVar);
                int parameterId=0;
                for(int i=0;i<constraint.parameterTypes.length;++i)
                    if(((indexInfo.indexMask>>i)&1)==1) {
                        TypeDesc typeDesc = parameterTypeDescs[i];
                        if(!typeDesc.equals(TypeDesc.VOID)) {
                            mb.loadLocal(tempFactVar);
                            mb.loadLocal(mb.getParameter(parameterId));
                            mb.storeField(factClassName, CHRCodeGenerationConstants.fieldName(i), typeDesc);
                        }
                        // Advanced for void parameters too, matching the signature built above.
                        ++parameterId;
                    }

                // return (ExampleFact)ExampleFact_bfIndex.getEqual(ExampleFact$temp);
                mb.loadThis();
                mb.loadField(storeClassBuilder.getClassName(), constraint.name + "$" + indexInfo.indexName, CHRHashIndex);
                mb.loadLocal(tempFactVar);
                mb.invokeVirtual(CHRHashIndex_name, supportsRemoval ? "getEqual" : "getEqualNoRemovals", TypeDesc.OBJECT, Constants.OBJECTS[1]);
                mb.checkCast(factTypeDesc);
                mb.returnValue(factTypeDesc);
                mb.finish();
            }
        }
    }

    /**
     * Adds the public constructor of the fact class. Its parameters are the
     * fact id followed by every non-void constraint parameter, in order.
     */
    private void generateConstructor() {
        // public ExampleFact(int id, int c0, int c1) {
        //     this.id = id;
        //     this.c0 = c0;
        //     this.c1 = c1;
        // }

        ArrayList<TypeDesc> constructorParameters = new ArrayList<TypeDesc>(parameterTypeDescs.length+1);
        constructorParameters.add(FACT_ID_TYPE);
        for(TypeDesc typeDesc : parameterTypeDescs) {
            if(typeDesc.equals(TypeDesc.VOID))
                continue;
            constructorParameters.add(typeDesc);
        }
        MethodBuilderBase mb = classBuilder.addConstructorBase(Opcodes.ACC_PUBLIC, constructorParameters.toArray(new TypeDesc[constructorParameters.size()]));
        mb.loadThis();
        mb.invokeConstructor(classBuilder.getSuperClassName(), Constants.EMPTY_TYPEDESC_ARRAY);
        mb.loadThis();
        mb.loadLocal(mb.getParameter(0));
        mb.storeField(CHRFact_name, "id", FACT_ID_TYPE);
        // parameterId only advances for non-void parameters, because void ones
        // were left out of the constructor signature above.
        for(int i=0,parameterId=1;i<constraint.parameterTypes.length;++i) {
            TypeDesc typeDesc = parameterTypeDescs[i];
            if(typeDesc.equals(TypeDesc.VOID))
                continue;
            mb.loadThis();
            mb.loadLocal(mb.getParameter(parameterId++));
            mb.storeField(factClassName, CHRCodeGenerationConstants.fieldName(i), typeDesc);
        }
        mb.returnVoid();
        mb.finish();
    }

    /**
     * Adds the {@code add} method of the fact class. The generated method
     * inserts the fact at the head of every index and then hands it to a
     * priority fact container of the store, if one applies.
     */
    private void generateAdd() {
        MethodBuilderBase mb = classBuilder.addMethodBase(Opcodes.ACC_PUBLIC, "add", TypeDesc.VOID, new TypeDesc[] {storeTypeDesc, CHRContext});
        LocalVariable storeParameter = mb.getParameter(0);
        LocalVariable contextParameter = mb.getParameter(1);

        // Add fact to indices
        for(IndexInfo indexInfo : constraint.getIndices()) {
            String linkedListPrev = indexInfo.indexName + "Prev";
            String linkedListNext = indexInfo.indexName + "Next";
            String storeHashIndexName = constraint.name + "$" + indexInfo.indexName;

            // public void add(ExampleStore store, CHRContext context) {
            //     bfNext = (ExampleFact)store.ExampleFact_bfIndex.addFreshAndReturnOld(this);
            //     if(bfNext != null)
            //         bfNext.bfPrev = this;
            // }

            if(indexInfo.indexMask == 0) {
                // No bound parameters: the store field refers directly to the head fact.
                //     ffNext = store.ExampleFact_ff;
                mb.loadThis();
                mb.loadLocal(storeParameter);
                mb.loadField(storeClassBuilder.getClassName(), storeHashIndexName, factTypeDesc);
                generateLinkToOldHead(mb, linkedListPrev, linkedListNext);
                //     store.ExampleFact_ff = this;
                mb.loadLocal(storeParameter);
                mb.loadThis();
                mb.storeField(storeClassBuilder.getClassName(), storeHashIndexName, factTypeDesc);
            }
            else {
                // bfNext = (ExampleFact)store.ExampleFact_bfIndex.addFreshAndReturnOld(this);
                mb.loadThis();
                mb.loadLocal(storeParameter);
                mb.loadField(storeClassBuilder.getClassName(), storeHashIndexName, CHRHashIndex);
                mb.loadThis();
                // NOTE(review): the original selected this name with a ternary on
                // supportsRemoval whose two branches were identical; confirm that no
                // separate no-removal variant was intended.
                mb.invokeVirtual(CHRHashIndex_name, "addFreshAndReturnOld", TypeDesc.OBJECT, Constants.OBJECTS[1]);
                mb.checkCast(factTypeDesc);
                generateLinkToOldHead(mb, linkedListPrev, linkedListNext);
            }
        }

        // Add fact to priority queue
        int minimumPriority = ruleset.getMinimumPriority(constraint);
        if(minimumPriority != Integer.MAX_VALUE) {
            mb.loadLocal(storeParameter);
            mb.loadField(storeClassBuilder.getClassName(), CHRCodeGenerationConstants.priorityName(minimumPriority), CHRPriorityFactContainer);
            mb.loadLocal(contextParameter);
            mb.loadThis();
            mb.invokeVirtual(CHRPriorityFactContainer_name, "addFact", TypeDesc.VOID, new TypeDesc[] {CHRContext, CHRFact});
        }
        else if(constraint.nextContainerFieldName != null) {
            // The container field may be null at run time, so the generated
            // code skips the addFact call in that case.
            mb.loadLocal(storeParameter);
            mb.loadField(storeClassBuilder.getClassName(), constraint.nextContainerFieldName, CHRPriorityFactContainer);
            LocalVariable containerVar = mb.createLocalVariable("container", CHRPriorityFactContainer);
            mb.storeLocal(containerVar);

            mb.loadLocal(containerVar);
            Label finishLabel = mb.createLabel();
            mb.ifNullBranch(finishLabel, true);

            mb.loadLocal(containerVar);
            mb.loadLocal(contextParameter);
            mb.loadThis();
            mb.invokeVirtual(CHRPriorityFactContainer_name, "addFact", TypeDesc.VOID, new TypeDesc[] {CHRContext, CHRFact});
            mb.setLocation(finishLabel);
        }
        mb.returnVoid();
        mb.finish();
    }

    /**
     * Emits the code that links the new fact in front of the previous head of
     * an index. Expects {@code this} and the previous head (possibly null) to
     * have been pushed, in that order, by the caller.
     *
     * @param mb             builder of the {@code add} method
     * @param linkedListPrev name of the index's prev-link field
     * @param linkedListNext name of the index's next-link field
     */
    private void generateLinkToOldHead(MethodBuilderBase mb, String linkedListPrev, String linkedListNext) {
        // When removal is supported, a copy of the old head is kept for the
        // null test below.
        if(supportsRemoval)
            mb.dupX1();
        //     bfNext = <old head>;
        mb.storeField(factClassName, linkedListNext, factTypeDesc);

        //     if(bfNext != null)
        //         bfNext.bfPrev = this;
        if(supportsRemoval) {
            Label cont = new Label();
            mb.ifNullBranch(cont, true);
            mb.loadThis();
            mb.loadField(factClassName, linkedListNext, factTypeDesc);
            mb.loadThis();
            mb.storeField(factClassName, linkedListPrev, factTypeDesc);
            mb.setLocation(cont);
        }
    }

    /**
     * Declares the fields of the fact class and the per-index fields of the
     * store class, and generates a hash index class for every index that has
     * bound parameters.
     *
     * @param hashIndexInitializations list that receives one entry per
     *        generated hash index field
     */
    private void generateFields(ArrayList<StoreInitialization> hashIndexInitializations) {
        // public int c0; // key
        // public int c1;
        // public ExampleFact bfPrev;
        // public ExampleFact bfNext;
        //
        // The id field is not declared here; the generated constructor and
        // remove method store it through CHRFact_name.

        for(int i=0;i<constraint.parameterTypes.length;++i) {
            TypeDesc typeDesc = parameterTypeDescs[i];
            if(typeDesc.equals(TypeDesc.VOID))
                continue;
            classBuilder.addField(Opcodes.ACC_PUBLIC, CHRCodeGenerationConstants.fieldName(i), typeDesc);
        }

        for(IndexInfo indexInfo : constraint.getIndices()) {
            // The prev link is only read and written by code generated when removal is supported.
            if(supportsRemoval)
                classBuilder.addField(Opcodes.ACC_PUBLIC, indexInfo.indexName + "Prev", factTypeDesc);
            classBuilder.addField(Opcodes.ACC_PUBLIC, indexInfo.indexName + "Next", factTypeDesc);

            String hashIndexField = constraint.name + "$" + indexInfo.indexName;
            if(indexInfo.indexMask == 0) {
                // If there are no bound parameters, use just a direct reference to a fact
                storeClassBuilder.addField(Opcodes.ACC_PUBLIC, hashIndexField, factTypeDesc);
            }
            else {
                ClassBuilder hashClass = CHRHashIndexCodeGenerator.generateHashIndex(storeClassBuilder, constraint, indexInfo, factTypeDesc, factClassName);
                moduleBuilder.addClass(hashClass);
                hashIndexInitializations.add(new StoreInitialization(Opcodes.ACC_PUBLIC | Opcodes.ACC_FINAL, hashIndexField, CHRHashIndex, hashClass.getClassName()));
            }
        }
    }

    /**
     * Adds the {@code remove} method of the fact class. The generated method
     * unlinks the fact from every index and finally sets its id to -1.
     * Only called when the constraint may be removed, since it relies on the
     * prev-link fields.
     */
    private void generateRemove() {
        // public void remove(ExampleStore store) {
        //     if(bfPrev == null) {
        //         if(bfNext == null)
        //             store.ExampleFact_bfIndex.removeKnownToExistKey(this);
        //         else {
        //             bfNext.bfPrev = null;
        //             store.ExampleFact_bfIndex.replaceKnownToExistKey(this, bfNext);
        //         }
        //     }
        //     else {
        //         bfPrev.bfNext = bfNext;
        //         if(bfNext != null)
        //             bfNext.bfPrev = bfPrev;
        //     }
        //     id = -1;
        // }

        MethodBuilderBase mb = classBuilder.addMethodBase(Opcodes.ACC_PUBLIC, "remove", TypeDesc.VOID, storeTypeDescArray);
        LocalVariable storeParameter = mb.getParameter(0);
        for(IndexInfo indexInfo : constraint.getIndices()) {
            String linkedListPrev = indexInfo.indexName + "Prev";
            String linkedListNext = indexInfo.indexName + "Next";
            String storeHashIndexName = constraint.name + "$" + indexInfo.indexName;

            // Join point reached by all three cases of this index.
            Label nextIndex = mb.createLabel();

            // if(bfPrev == null) {
            mb.loadThis();
            mb.loadField(factClassName, linkedListPrev, factTypeDesc);
            Label else1 = new Label();
            mb.ifNullBranch(else1, false);

            //     if(bfNext == null)
            mb.loadThis();
            mb.loadField(factClassName, linkedListNext, factTypeDesc);
            Label else2 = new Label();
            mb.ifNullBranch(else2, false);

            //         store.ExampleFact_bfIndex.removeKnownToExistKey(this);
            if(indexInfo.indexMask == 0) {
                // Direct head reference: the fact was the only one, so clear the field.
                mb.loadLocal(storeParameter);
                mb.loadNull();
                mb.storeField(storeClassBuilder.getClassName(), storeHashIndexName, factTypeDesc);
            }
            else {
                mb.loadLocal(storeParameter);
                mb.loadField(storeClassBuilder.getClassName(), storeHashIndexName, CHRHashIndex);
                mb.loadThis();
                mb.invokeVirtual(CHRHashIndex_name, "removeKnownToExistKey", TypeDesc.VOID, Constants.OBJECTS[1]);
            }
            mb.branch(nextIndex);

            //     else {
            mb.setLocation(else2);
            //         bfNext.bfPrev = null;
            mb.loadThis();
            mb.loadField(factClassName, linkedListNext, factTypeDesc);
            mb.loadNull();
            mb.storeField(factClassName, linkedListPrev, factTypeDesc);
            //         store.ExampleFact_bfIndex.replaceKnownToExistKey(this, bfNext);
            if(indexInfo.indexMask == 0) {
                // Direct head reference: the successor becomes the new head.
                mb.loadLocal(storeParameter);
                mb.loadThis();
                mb.loadField(factClassName, linkedListNext, factTypeDesc);
                mb.storeField(storeClassBuilder.getClassName(), storeHashIndexName, factTypeDesc);
            }
            else {
                mb.loadLocal(storeParameter);
                mb.loadField(storeClassBuilder.getClassName(), storeHashIndexName, CHRHashIndex);
                mb.loadThis();
                mb.loadThis();
                mb.loadField(factClassName, linkedListNext, factTypeDesc);
                mb.invokeVirtual(CHRHashIndex_name, "replaceKnownToExistKey", TypeDesc.VOID, Constants.OBJECTS[2]);
            }
            mb.branch(nextIndex);
            //     }

            // else {
            mb.setLocation(else1);
            //     bfPrev.bfNext = bfNext;
            mb.loadThis();
            mb.loadField(factClassName, linkedListPrev, factTypeDesc);
            mb.loadThis();
            mb.loadField(factClassName, linkedListNext, factTypeDesc);
            mb.storeField(factClassName, linkedListNext, factTypeDesc);
            //     if(bfNext != null)
            mb.loadThis();
            mb.loadField(factClassName, linkedListNext, factTypeDesc);
            Label else3 = new Label();
            mb.ifNullBranch(else3, true);
            //         bfNext.bfPrev = bfPrev;
            mb.loadThis();
            mb.loadField(factClassName, linkedListNext, factTypeDesc);
            mb.loadThis();
            mb.loadField(factClassName, linkedListPrev, factTypeDesc);
            mb.storeField(factClassName, linkedListPrev, factTypeDesc);
            mb.setLocation(else3);
            // This branch targets the very next location; it is kept so that the
            // emitted bytecode stays identical to what was generated before.
            mb.branch(nextIndex);
            // }

            mb.setLocation(nextIndex);
        }
        // id = -1;
        mb.loadThis();
        mb.loadConstant(-1);
        mb.storeField(CHRFact_name, "id", FACT_ID_TYPE);
        mb.returnVoid();
        mb.finish();
    }
}
