package org.simantics.scl.compiler.internal.elaboration.matching;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.simantics.scl.compiler.common.exceptions.InternalCompilerError;
import org.simantics.scl.compiler.compilation.CompilationContext;
import org.simantics.scl.compiler.constants.Constant;
import org.simantics.scl.compiler.elaboration.expressions.EApply;
import org.simantics.scl.compiler.elaboration.expressions.EApplyType;
import org.simantics.scl.compiler.elaboration.expressions.EAsPattern;
import org.simantics.scl.compiler.elaboration.expressions.EConstant;
import org.simantics.scl.compiler.elaboration.expressions.EExternalConstant;
import org.simantics.scl.compiler.elaboration.expressions.ELiteral;
import org.simantics.scl.compiler.elaboration.expressions.EVariable;
import org.simantics.scl.compiler.elaboration.expressions.EViewPattern;
import org.simantics.scl.compiler.elaboration.expressions.Expression;
import org.simantics.scl.compiler.elaboration.expressions.GuardedExpressionGroup;
import org.simantics.scl.compiler.elaboration.java.DynamicConstructor;
import org.simantics.scl.compiler.elaboration.modules.SCLValue;
import org.simantics.scl.compiler.elaboration.modules.TypeConstructor;
import org.simantics.scl.compiler.internal.codegen.continuations.Branch;
import org.simantics.scl.compiler.internal.codegen.continuations.ICont;
import org.simantics.scl.compiler.internal.codegen.references.IVal;
import org.simantics.scl.compiler.internal.codegen.writer.CodeWriter;
import org.simantics.scl.compiler.types.TCon;
import org.simantics.scl.compiler.types.Types;
import org.simantics.scl.compiler.types.exceptions.MatchException;

import gnu.trove.map.hash.THashMap;

/**
 * Compiles a pattern-matching problem into code emitted through a {@link CodeWriter}.
 * <p>
 * The problem is a list of {@link Row}s. Each row holds one pattern per scrutinee value and the
 * expression to produce when all of its patterns match. {@link #split} repeatedly picks a column,
 * emits either a {@code switch_} on it (constructor and literal patterns) or an application of a
 * view function (view patterns), and recurses on the resulting smaller problems. The first row
 * whose patterns match wins.
 */
public class PatternMatchingCompiler {

    /**
     * A sub-problem produced by a split: the writer of a freshly created block, the scrutinee
     * values visible in that block and the rows routed to it.
     */
    private static class ExpressionMatrix {
        final CodeWriter w;
        final IVal[] scrutinee;
        final List<Row> rows = new ArrayList<Row>();

        public ExpressionMatrix(CodeWriter w, IVal[] scrutinee) {
            this.w = w;
            this.scrutinee = scrutinee;
        }
    }

    /**
     * Accumulates the switch branches created while splitting rows on the patterns of one column.
     * Rows whose patterns share a key (constructor name or literal constant) are routed to the
     * same {@link ExpressionMatrix}, and therefore to the same switch branch. Branches are kept in
     * creation order.
     */
    private static class ConstructorSplit {
        final CodeWriter w;
        final IVal[] scrutinee;
        final int columnId;
        final THashMap<Object, ExpressionMatrix> matrixMap = new THashMap<Object, ExpressionMatrix>();
        final ArrayList<Branch> branches = new ArrayList<Branch>();
        final ArrayList<ExpressionMatrix> matrices = new ArrayList<ExpressionMatrix>();

        ConstructorSplit(CodeWriter w, IVal[] scrutinee, int columnId) {
            this.w = w;
            this.scrutinee = scrutinee;
            this.columnId = columnId;
        }

        /**
         * Routes {@code row} to the branch for {@code constructor}, creating the branch on first use.
         *
         * @param key            map key identifying the branch
         * @param reuseExisting  if {@code false}, a new branch is always created, even when one with
         *                       the same key exists (used for {@code DynamicConstructor.INSTANCE})
         * @param constructor    constant the switch compares the scrutinee value against
         * @param row            row whose pattern in {@code columnId} is being consumed
         * @param parameters     sub-patterns of the pattern; they replace the column in the routed
         *                       row, and the branch block takes one parameter per sub-pattern
         */
        void add(Object key, boolean reuseExisting, Constant constructor, Row row, Expression[] parameters) {
            ExpressionMatrix matrix = reuseExisting ? matrixMap.get(key) : null;
            if(matrix == null) {
                // Nullary patterns get a parameterless block. Applications get one block
                // parameter per sub-pattern, and these parameters replace the scrutinee column.
                CodeWriter newW = parameters.length == 0
                        ? w.createBlock()
                        : w.createBlock(Types.getTypes(parameters));
                branches.add(new Branch(constructor, newW.getContinuation()));
                matrix = new ExpressionMatrix(newW, replace(scrutinee, columnId, newW.getParameters()));
                matrices.add(matrix);
                matrixMap.put(key, matrix);
            }
            matrix.rows.add(row.replace(columnId, parameters));
        }
    }

    /**
     * Returns a copy of {@code vals} in which the element at {@code columnToReplace} is replaced by
     * the (possibly empty) sequence {@code substitution}.
     *
     * @param vals             original values; not modified
     * @param columnToReplace  index of the element to replace, {@code 0 <= columnToReplace < vals.length}
     * @param substitution     values inserted in place of the replaced element
     * @return a new array of length {@code vals.length - 1 + substitution.length}
     */
    public static IVal[] replace(IVal[] vals, int columnToReplace, IVal ... substitution) {
        int tailLength = vals.length - columnToReplace - 1;
        IVal[] newVals = new IVal[columnToReplace + substitution.length + tailLength];
        System.arraycopy(vals, 0, newVals, 0, columnToReplace);
        System.arraycopy(substitution, 0, newVals, columnToReplace, substitution.length);
        System.arraycopy(vals, columnToReplace + 1, newVals, columnToReplace + substitution.length, tailLength);
        return newVals;
    }

    /**
     * Strips {@link EApplyType} and {@link EAsPattern} wrappers from the pattern in {@code column}
     * of {@code row}. Each intermediate result is written back into {@code row.patterns}. Every
     * as-pattern variable encountered is bound to {@code value}.
     *
     * @return the unwrapped pattern
     */
    private static Expression unwrapPattern(Row row, int column, IVal value) {
        Expression pattern = row.patterns[column];
        while(true) {
            if(pattern instanceof EApplyType)
                pattern = ((EApplyType)pattern).getExpression();
            else if(pattern instanceof EAsPattern) {
                EAsPattern asPattern = (EAsPattern)pattern;
                pattern = asPattern.getPattern();
                asPattern.getVariable().setVal(value);
            }
            else
                return pattern;
            row.patterns[column] = pattern;
        }
    }

    /**
     * Splits {@code rows} on the constructor and literal patterns in column {@code columnId} and
     * emits a {@code switch_} on {@code scrutinee[columnId]}.
     * <p>
     * Rows are consumed in order up to the first row whose (unwrapped) pattern in the column is a
     * variable or a view pattern. Consumed rows are grouped by constructor or literal into one
     * branch each.
     * <p>
     * The remaining rows, if any, are compiled in a default branch ({@code Branch} with a
     * {@code null} constructor). That block also becomes the failure continuation of the
     * constructor branches. If no rows remain, a default branch to {@code failure} is added,
     * unless the branch count already reaches the constructor count of a non-open type.
     *
     * @throws InternalCompilerError for a pattern form this method cannot compile
     */
    private static void splitByConstructors(long location, CodeWriter w, final CompilationContext context, IVal[] scrutinee, final ICont success, ICont failure, List<Row> rows, int columnId) {
        ConstructorSplit columnSplit = new ConstructorSplit(w, scrutinee, columnId);
        ArrayList<Branch> branches = columnSplit.branches;

        int i;
        for(i=0;i<rows.size();++i) {
            Row row = rows.get(i);
            Expression pattern = unwrapPattern(row, columnId, scrutinee[columnId]);
            if(pattern instanceof EVariable || pattern instanceof EViewPattern) {
                // Not a constructor test: this row and every later row are compiled in the
                // default branch, which the constructor branches also fall back to.
                break;
            }
            else if(pattern instanceof EApply) {
                EApply applyConstructor = (EApply)pattern;
                Expression constructor_ = applyConstructor.getFunction();
                while(constructor_ instanceof EApplyType)
                    constructor_ = ((EApplyType)constructor_).getExpression();
                Expression[] parameters = applyConstructor.getParameters();
                // TODO How type parameters are handled???
                if(constructor_ instanceof EConstant) {
                    SCLValue constructor = ((EConstant)constructor_).getValue();
                    // Rows using the dynamic constructor are never merged into an existing branch.
                    columnSplit.add(constructor.getName(),
                            constructor.getValue() != DynamicConstructor.INSTANCE,
                            (Constant)constructor.getValue(), row, parameters);
                }
                else if(constructor_ instanceof ELiteral) {
                    Constant constructor = ((ELiteral)constructor_).getValue();
                    columnSplit.add(constructor, true, constructor, row, parameters);
                }
                else
                    // Previously such rows were silently dropped, which miscompiled the match.
                    throw new InternalCompilerError("Cannot handle an instance of " + constructor_.getClass().getSimpleName() + " as a constructor in a pattern.");
            }
            else if(pattern instanceof EConstant) {
                SCLValue constructor = ((EConstant)pattern).getValue();
                columnSplit.add(constructor.getName(), true, (Constant)constructor.getValue(), row, Expression.EMPTY_ARRAY);
            }
            else if(pattern instanceof ELiteral) {
                Constant constructor = ((ELiteral)pattern).getValue();
                columnSplit.add(constructor, true, constructor, row, Expression.EMPTY_ARRAY);
            }
            else if(pattern instanceof EExternalConstant) {
                EExternalConstant constant = (EExternalConstant)pattern;
                Constant constructor = w.getModuleWriter().getExternalConstant(constant.getValue(), constant.getType());
                columnSplit.add(constructor, true, constructor, row, Expression.EMPTY_ARRAY);
            }
            else
                throw new InternalCompilerError("Cannot handle an instance of " + pattern.getClass().getSimpleName() + " in a pattern.");
        }
        if(i < rows.size()) {
            CodeWriter newW = w.createBlock();
            ICont cont = newW.getContinuation();
            branches.add(new Branch(null, cont));
            split(location, newW, context, scrutinee, success, failure, rows.subList(i, rows.size()));
            failure = cont;
        }
        else {
            TCon con;
            try {
                con = Types.getConstructor(scrutinee[columnId].getType());
            } catch (MatchException e) {
                throw new InternalCompilerError("Cannot determine the type constructor of the scrutinee type " + scrutinee[columnId].getType() + ".");
            }
            TypeConstructor cons = (TypeConstructor)context.environment.getTypeDescriptor(con);
            int maxBranchCount = cons.isOpen ? Integer.MAX_VALUE
                    : cons.constructors.length;
            if(branches.size() < maxBranchCount)
                branches.add(new Branch(null, failure));
        }

        for(ExpressionMatrix mx : columnSplit.matrices)
            split(location, mx.w, context, mx.scrutinee, success, failure, mx.rows);
        w.switch_(location, scrutinee[columnId], branches.toArray(new Branch[branches.size()]));
    }

    /**
     * Splits {@code rows} on the view pattern in column {@code viewPatternColumn} of the first row.
     * <p>
     * The first row forms a group together with the directly following rows whose (unwrapped)
     * pattern in the column is a view pattern with an equal view expression, as decided by
     * {@code equalsExpression}. The view expression is applied once to the scrutinee value. The
     * group's view patterns are replaced by their inner patterns and matched against the result.
     * Rows after the group are compiled in a separate block against the original scrutinee, and
     * that block is the group's failure continuation.
     */
    private static void splitByViewPattern(long location, CodeWriter w, CompilationContext context, IVal[] scrutinee, ICont success,
            ICont failure, List<Row> rows, int viewPatternColumn) {
        Row firstRow = rows.get(0);
        EViewPattern firstViewPattern = (EViewPattern)firstRow.patterns[viewPatternColumn];
        firstRow.patterns[viewPatternColumn] = firstViewPattern.pattern;
        int i;
        for(i=1;i<rows.size();++i) {
            Row row = rows.get(i);
            Expression pattern = unwrapPattern(row, viewPatternColumn, scrutinee[viewPatternColumn]);
            if(!(pattern instanceof EViewPattern))
                break;
            EViewPattern otherViewPattern = (EViewPattern)pattern;
            if(!otherViewPattern.expression.equalsExpression(firstViewPattern.expression))
                break;
            row.patterns[viewPatternColumn] = otherViewPattern.pattern;
        }

        IVal[] newScrutinee = Arrays.copyOf(scrutinee, scrutinee.length);
        newScrutinee[viewPatternColumn] =
                w.apply(firstViewPattern.location,
                        firstViewPattern.expression.toVal(context, w),
                        scrutinee[viewPatternColumn]);
        if(i == rows.size()) {
            split(location, w, context, newScrutinee, success, failure, rows);
        }
        else {
            CodeWriter cont = w.createBlock();
            split(location, w, context, newScrutinee, success, cont.getContinuation(), rows.subList(0, i));
            split(location, cont, context, scrutinee, success, failure, rows.subList(i, rows.size()));
        }
    }

    /**
     * Compiles the pattern-matching problem given by {@code rows}.
     * <p>
     * If the first row has a pattern that is neither a variable nor a view pattern, the rows are
     * split on the first such column. Otherwise, if the first row has a view pattern, they are
     * split on the first view-pattern column. Otherwise the first row matches unconditionally:
     * its variables are bound to the scrutinee values and its value is compiled. If that value is
     * a {@link GuardedExpressionGroup}, the group gets the remaining rows (or {@code failure}) as
     * its failure continuation.
     *
     * @param location   source location passed to the emitted jumps and switches
     * @param w          writer of the block in which the matching code is emitted
     * @param context    compilation context
     * @param scrutinee  values being matched, one per pattern column
     * @param success    continuation that receives the value of the matching row
     * @param failure    continuation used when no row matches
     * @param rows       rows to try in order; must be non-empty (only the first row's pattern
     *                   count is checked against the scrutinee length)
     * @throws InternalCompilerError if the first row's pattern count differs from the scrutinee length
     */
    public static void split(long location, CodeWriter w, CompilationContext context, IVal[] scrutinee, ICont success, ICont failure, List<Row> rows) {
        Row firstRow = rows.get(0);
        Expression[] patterns = firstRow.patterns;
        if(scrutinee.length != patterns.length)
            throw new InternalCompilerError("Scrutinee and patterns have a different length");

        // Find a non-variable pattern and split by it
        int viewPatternColumn = -1;
        for(int i=0;i<patterns.length;++i) {
            Expression pattern = patterns[i];
            if(pattern instanceof EViewPattern) {
                if(viewPatternColumn == -1)
                    viewPatternColumn = i;
            }
            else if(!(pattern instanceof EVariable)) {
                splitByConstructors(location, w, context, scrutinee, success, failure, rows, i);
                return;
            }
        }

        if(viewPatternColumn >= 0) {
            splitByViewPattern(location, w, context, scrutinee, success, failure, rows, viewPatternColumn);
            return;
        }

        // The first row has only variable patterns: no matching needed
        for(int i=0;i<patterns.length;++i)
            ((EVariable)patterns[i]).getVariable().setVal(scrutinee[i]);
        if(firstRow.value instanceof GuardedExpressionGroup) {
            GuardedExpressionGroup group = (GuardedExpressionGroup)firstRow.value;
            if(rows.size() == 1) {
                group.compile(context, w, success, failure);
            }
            else {
                // When every guard fails, continue with the remaining rows.
                CodeWriter newW = w.createBlock();
                ICont cont = newW.getContinuation();
                group.compile(context, w, success, cont);
                split(location, newW, context, scrutinee, success, failure, rows.subList(1, rows.size()));
            }
        }
        else
            w.jump(location, success, firstRow.value.toVal(context, w));
    }
}
